1use crate::protocol::{ClientBackend, ProtocolUnitInfo, Resolved, ServerBackend};
4use crate::types::RelayError;
5use crate::utils::NetworkPeerHandle;
6use crate::LOG_TARGET;
7use derive_more::From;
8use parity_scale_codec::{Decode, Encode};
9use std::collections::BTreeMap;
10use tracing::{trace, warn};
11
12#[derive(From, Encode, Decode)]
15pub(crate) enum CompactBlockInitialRequest {
16 #[codec(index = 0)]
17 V0,
18 }
21
22#[derive(Encode, Decode)]
24pub(crate) struct CompactBlockInitialResponse<DownloadUnitId, ProtocolUnitId, ProtocolUnit> {
25 download_unit_id: DownloadUnitId,
27
28 protocol_units: Vec<ProtocolUnitInfo<ProtocolUnitId, ProtocolUnit>>,
30}
31
32#[derive(From, Encode, Decode)]
34pub(crate) enum CompactBlockHandshake<DownloadUnitId, ProtocolUnitId> {
35 #[codec(index = 0)]
37 MissingEntriesV0(MissingEntriesRequest<DownloadUnitId, ProtocolUnitId>),
38 }
41
42#[derive(From, Encode, Decode)]
44pub(crate) enum CompactBlockHandshakeResponse<ProtocolUnit> {
45 #[codec(index = 0)]
47 MissingEntriesV0(MissingEntriesResponse<ProtocolUnit>),
48 }
51
52#[derive(Encode, Decode)]
54pub(crate) struct MissingEntriesRequest<DownloadUnitId, ProtocolUnitId> {
55 download_unit_id: DownloadUnitId,
57
58 protocol_unit_ids: BTreeMap<u64, ProtocolUnitId>,
62}
63
64#[derive(Encode, Decode)]
66pub(crate) struct MissingEntriesResponse<ProtocolUnit> {
67 protocol_units: BTreeMap<u64, ProtocolUnit>,
69}
70
71struct ResolveContext<ProtocolUnitId, ProtocolUnit> {
72 resolved: BTreeMap<u64, Resolved<ProtocolUnitId, ProtocolUnit>>,
73 local_miss: BTreeMap<u64, ProtocolUnitId>,
74}
75
/// Client side of the compact block relay protocol.
///
/// Holds no state; the type parameters only fix the id and unit types it handles.
pub(crate) struct CompactBlockClient<DownloadUnitId, ProtocolUnitId, ProtocolUnit> {
    _phantom_data: std::marker::PhantomData<(DownloadUnitId, ProtocolUnitId, ProtocolUnit)>,
}
79
80impl<DownloadUnitId, ProtocolUnitId, ProtocolUnit>
81 CompactBlockClient<DownloadUnitId, ProtocolUnitId, ProtocolUnit>
82where
83 DownloadUnitId: Send + Sync + Encode + Decode + Clone + std::fmt::Debug,
84 ProtocolUnitId: Send + Sync + Encode + Decode + Clone,
85 ProtocolUnit: Send + Sync + Encode + Decode + Clone,
86{
87 pub(crate) fn new() -> Self {
89 Self {
90 _phantom_data: Default::default(),
91 }
92 }
93
94 pub(crate) fn build_initial_request(
96 &self,
97 _backend: &dyn ClientBackend<ProtocolUnitId, ProtocolUnit>,
98 ) -> CompactBlockInitialRequest {
99 CompactBlockInitialRequest::V0
100 }
101
102 pub(crate) async fn resolve_initial_response<Request>(
104 &self,
105 compact_response: CompactBlockInitialResponse<DownloadUnitId, ProtocolUnitId, ProtocolUnit>,
106 network_peer_handle: &NetworkPeerHandle,
107 backend: &dyn ClientBackend<ProtocolUnitId, ProtocolUnit>,
108 ) -> Result<(DownloadUnitId, Vec<Resolved<ProtocolUnitId, ProtocolUnit>>), RelayError>
109 where
110 Request: From<CompactBlockHandshake<DownloadUnitId, ProtocolUnitId>> + Encode + Send + Sync,
111 {
112 let context = self.resolve_local(&compact_response, backend)?;
114 if context.resolved.len() == compact_response.protocol_units.len() {
115 trace!(
116 target: LOG_TARGET,
117 "relay::resolve: {:?}: resolved locally[{}]",
118 compact_response.download_unit_id,
119 compact_response.protocol_units.len()
120 );
121 return Ok((
122 compact_response.download_unit_id,
123 context.resolved.into_values().collect(),
124 ));
125 }
126
127 let misses = context.local_miss.len();
129 let download_unit_id = compact_response.download_unit_id.clone();
130 let resolved = self
131 .resolve_misses::<Request>(compact_response, context, network_peer_handle)
132 .await?;
133 trace!(
134 target: LOG_TARGET,
135 "relay::resolve: {:?}: resolved by server[{},{}]",
136 download_unit_id,
137 resolved.len(),
138 misses,
139 );
140 Ok((download_unit_id, resolved))
141 }
142
143 fn resolve_local(
145 &self,
146 compact_response: &CompactBlockInitialResponse<
147 DownloadUnitId,
148 ProtocolUnitId,
149 ProtocolUnit,
150 >,
151 backend: &dyn ClientBackend<ProtocolUnitId, ProtocolUnit>,
152 ) -> Result<ResolveContext<ProtocolUnitId, ProtocolUnit>, RelayError> {
153 let mut context = ResolveContext {
154 resolved: BTreeMap::new(),
155 local_miss: BTreeMap::new(),
156 };
157
158 for (index, entry) in compact_response.protocol_units.iter().enumerate() {
159 let ProtocolUnitInfo { id, unit } = entry;
160 if let Some(unit) = unit {
161 context.resolved.insert(
163 index as u64,
164 Resolved {
165 protocol_unit_id: id.clone(),
166 protocol_unit: unit.clone(),
167 locally_resolved: true,
168 },
169 );
170 continue;
171 }
172
173 match backend.protocol_unit(id) {
174 Some(ret) => {
175 context.resolved.insert(
176 index as u64,
177 Resolved {
178 protocol_unit_id: id.clone(),
179 protocol_unit: ret,
180 locally_resolved: true,
181 },
182 );
183 }
184 None => {
185 context.local_miss.insert(index as u64, id.clone());
186 }
187 }
188 }
189
190 Ok(context)
191 }
192
193 async fn resolve_misses<Request>(
195 &self,
196 compact_response: CompactBlockInitialResponse<DownloadUnitId, ProtocolUnitId, ProtocolUnit>,
197 context: ResolveContext<ProtocolUnitId, ProtocolUnit>,
198 network_peer_handle: &NetworkPeerHandle,
199 ) -> Result<Vec<Resolved<ProtocolUnitId, ProtocolUnit>>, RelayError>
200 where
201 Request: From<CompactBlockHandshake<DownloadUnitId, ProtocolUnitId>> + Encode + Send + Sync,
202 {
203 let ResolveContext {
204 mut resolved,
205 local_miss,
206 } = context;
207 let missing = local_miss.len();
208 let request = CompactBlockHandshake::from(MissingEntriesRequest {
210 download_unit_id: compact_response.download_unit_id.clone(),
211 protocol_unit_ids: local_miss.clone(),
212 });
213
214 let response: CompactBlockHandshakeResponse<ProtocolUnit> =
215 network_peer_handle.request(Request::from(request)).await?;
216 let CompactBlockHandshakeResponse::MissingEntriesV0(missing_entries_response) = response;
217
218 if missing_entries_response.protocol_units.len() != missing {
219 return Err(RelayError::ResolveMismatch {
220 expected: missing,
221 actual: missing_entries_response.protocol_units.len(),
222 });
223 }
224
225 for (missing_key, protocol_unit_id) in local_miss.into_iter() {
227 if let Some(protocol_unit) = missing_entries_response.protocol_units.get(&missing_key) {
228 resolved.insert(
229 missing_key,
230 Resolved {
231 protocol_unit_id,
232 protocol_unit: protocol_unit.clone(),
233 locally_resolved: false,
234 },
235 );
236 } else {
237 return Err(RelayError::ResolvedNotFound(missing));
238 }
239 }
240
241 Ok(resolved.into_values().collect())
242 }
243}
244
/// Server side of the compact block relay protocol.
///
/// Holds no state; the type parameters only fix the id and unit types it serves.
pub(crate) struct CompactBlockServer<DownloadUnitId, ProtocolUnitId, ProtocolUnit> {
    _phantom_data: std::marker::PhantomData<(DownloadUnitId, ProtocolUnitId, ProtocolUnit)>,
}
248
249impl<DownloadUnitId, ProtocolUnitId, ProtocolUnit>
250 CompactBlockServer<DownloadUnitId, ProtocolUnitId, ProtocolUnit>
251where
252 DownloadUnitId: Encode + Decode + Clone,
253 ProtocolUnitId: Encode + Decode + Clone,
254 ProtocolUnit: Encode + Decode,
255{
256 pub(crate) fn new() -> Self {
258 Self {
259 _phantom_data: Default::default(),
260 }
261 }
262
263 pub(crate) fn build_initial_response(
265 &self,
266 download_unit_id: &DownloadUnitId,
267 _initial_request: CompactBlockInitialRequest,
268 backend: &dyn ServerBackend<DownloadUnitId, ProtocolUnitId, ProtocolUnit>,
269 ) -> Result<CompactBlockInitialResponse<DownloadUnitId, ProtocolUnitId, ProtocolUnit>, RelayError>
270 {
271 Ok(CompactBlockInitialResponse {
273 download_unit_id: download_unit_id.clone(),
274 protocol_units: backend.download_unit_members(download_unit_id)?,
275 })
276 }
277
278 pub(crate) fn on_protocol_message(
280 &self,
281 message: CompactBlockHandshake<DownloadUnitId, ProtocolUnitId>,
282 backend: &dyn ServerBackend<DownloadUnitId, ProtocolUnitId, ProtocolUnit>,
283 ) -> Result<CompactBlockHandshakeResponse<ProtocolUnit>, RelayError> {
284 let CompactBlockHandshake::MissingEntriesV0(request) = message;
285
286 let mut protocol_units = BTreeMap::new();
287 let total_len = request.protocol_unit_ids.len();
288 for (missing_id, protocol_unit_id) in request.protocol_unit_ids {
289 if let Some(protocol_unit) =
290 backend.protocol_unit(&request.download_unit_id, &protocol_unit_id)
291 {
292 protocol_units.insert(missing_id, protocol_unit);
293 } else {
294 warn!(
295 target: LOG_TARGET,
296 "relay::on_request: missing entry not found"
297 );
298 }
299 }
300 if total_len != protocol_units.len() {
301 warn!(
302 target: LOG_TARGET,
303 "relay::compact_blocks::on_request: could not resolve all entries: {total_len}/{}",
304 protocol_units.len()
305 );
306 }
307 Ok(CompactBlockHandshakeResponse::from(
308 MissingEntriesResponse { protocol_units },
309 ))
310 }
311}