1#[cfg(test)]
2mod tests;
3
4use crate::protocols::request_response::handlers::generic_request_handler::GenericRequest;
5use crate::protocols::request_response::request_response_factory;
6use crate::shared::{Command, CreatedSubscription, PeerDiscovered, Shared};
7use crate::utils::multihash::Multihash;
8use crate::utils::HandlerFn;
9use bytes::Bytes;
10use event_listener_primitives::HandlerId;
11use futures::channel::{mpsc, oneshot};
12use futures::{SinkExt, Stream, StreamExt};
13use libp2p::gossipsub::{Sha256Topic, SubscriptionError};
14use libp2p::kad::{PeerRecord, RecordKey};
15use libp2p::{Multiaddr, PeerId};
16use parity_scale_codec::Decode;
17use std::pin::Pin;
18use std::sync::{Arc, Weak};
19use std::task::{Context, Poll};
20use thiserror::Error;
21use tokio::sync::OwnedSemaphorePermit;
22use tracing::{debug, error, trace};
23
24#[derive(Debug)]
26#[pin_project::pin_project(PinnedDrop)]
27pub struct TopicSubscription {
28 topic: Option<Sha256Topic>,
29 subscription_id: usize,
30 command_sender: Option<mpsc::Sender<Command>>,
31 #[pin]
32 receiver: mpsc::UnboundedReceiver<Bytes>,
33 _permit: OwnedSemaphorePermit,
34}
35
36impl Stream for TopicSubscription {
37 type Item = Bytes;
38 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
39 self.project().receiver.poll_next(cx)
40 }
41 fn size_hint(&self) -> (usize, Option<usize>) {
42 self.receiver.size_hint()
43 }
44}
45
46#[pin_project::pinned_drop]
47impl PinnedDrop for TopicSubscription {
48 fn drop(mut self: Pin<&mut Self>) {
49 let topic = self
50 .topic
51 .take()
52 .expect("Always specified on creation and only removed on drop; qed");
53 let subscription_id = self.subscription_id;
54 let mut command_sender = self
55 .command_sender
56 .take()
57 .expect("Always specified on creation and only removed on drop; qed");
58
59 tokio::spawn(async move {
60 let _ = command_sender
62 .send(Command::Unsubscribe {
63 topic,
64 subscription_id,
65 })
66 .await;
67 });
68 }
69}
70
71#[derive(Debug, Error)]
72pub enum GetValueError {
73 #[error("Failed to send command to the node runner: {0}")]
75 SendCommand(#[from] mpsc::SendError),
76 #[error("Node runner was dropped")]
78 NodeRunnerDropped,
79}
80
81impl From<oneshot::Canceled> for GetValueError {
82 #[inline]
83 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
84 Self::NodeRunnerDropped
85 }
86}
87
88#[derive(Debug, Error)]
89pub enum PutValueError {
90 #[error("Failed to send command to the node runner: {0}")]
92 SendCommand(#[from] mpsc::SendError),
93 #[error("Node runner was dropped")]
95 NodeRunnerDropped,
96}
97
98impl From<oneshot::Canceled> for PutValueError {
99 #[inline]
100 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
101 Self::NodeRunnerDropped
102 }
103}
104
105#[derive(Debug, Error)]
107pub enum GetClosestPeersError {
108 #[error("Failed to send command to the node runner: {0}")]
110 SendCommand(#[from] mpsc::SendError),
111 #[error("Node runner was dropped")]
113 NodeRunnerDropped,
114}
115
116impl From<oneshot::Canceled> for GetClosestPeersError {
117 #[inline]
118 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
119 Self::NodeRunnerDropped
120 }
121}
122
123#[derive(Debug, Error)]
125pub enum GetClosestLocalPeersError {
126 #[error("Failed to send command to the node runner: {0}")]
128 SendCommand(#[from] mpsc::SendError),
129 #[error("Node runner was dropped")]
131 NodeRunnerDropped,
132}
133
134impl From<oneshot::Canceled> for GetClosestLocalPeersError {
135 #[inline]
136 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
137 Self::NodeRunnerDropped
138 }
139}
140
141#[derive(Debug, Error)]
143pub enum SubscribeError {
144 #[error("Failed to send command to the node runner: {0}")]
146 SendCommand(#[from] mpsc::SendError),
147 #[error("Node runner was dropped")]
149 NodeRunnerDropped,
150 #[error("Failed to create subscription: {0}")]
152 Subscription(#[from] SubscriptionError),
153}
154
155impl From<oneshot::Canceled> for SubscribeError {
156 #[inline]
157 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
158 Self::NodeRunnerDropped
159 }
160}
161
162#[derive(Debug, Error)]
163pub enum PublishError {
164 #[error("Failed to send command to the node runner: {0}")]
166 SendCommand(#[from] mpsc::SendError),
167 #[error("Node runner was dropped")]
169 NodeRunnerDropped,
170 #[error("Failed to publish message: {0}")]
172 Publish(#[from] libp2p::gossipsub::PublishError),
173}
174
175impl From<oneshot::Canceled> for PublishError {
176 #[inline]
177 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
178 Self::NodeRunnerDropped
179 }
180}
181
182#[derive(Debug, Error)]
183pub enum GetProvidersError {
184 #[error("Failed to send command to the node runner: {0}")]
186 SendCommand(#[from] mpsc::SendError),
187 #[error("Node runner was dropped")]
189 NodeRunnerDropped,
190 #[error("Failed to get providers.")]
192 GetProviders,
193}
194
195impl From<oneshot::Canceled> for GetProvidersError {
196 #[inline]
197 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
198 Self::NodeRunnerDropped
199 }
200}
201
202#[derive(Debug, Error)]
204pub enum SendRequestError {
205 #[error("Failed to send command to the node runner: {0}")]
207 SendCommand(#[from] mpsc::SendError),
208 #[error("Node runner was dropped")]
210 NodeRunnerDropped,
211 #[error("Underlying protocol returned an error: {0}")]
213 ProtocolFailure(#[from] request_response_factory::RequestFailure),
214 #[error("Received incorrectly formatted response: {0}")]
216 IncorrectResponseFormat(#[from] parity_scale_codec::Error),
217}
218
219impl From<oneshot::Canceled> for SendRequestError {
220 #[inline]
221 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
222 Self::NodeRunnerDropped
223 }
224}
225
226#[derive(Debug, Error)]
227pub enum ConnectedPeersError {
228 #[error("Failed to send command to the node runner: {0}")]
230 SendCommand(#[from] mpsc::SendError),
231 #[error("Node runner was dropped")]
233 NodeRunnerDropped,
234 #[error("Failed to get connected peers.")]
236 ConnectedPeers,
237}
238
239impl From<oneshot::Canceled> for ConnectedPeersError {
240 #[inline]
241 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
242 Self::NodeRunnerDropped
243 }
244}
245
246#[derive(Debug, Error)]
247pub enum BootstrapError {
248 #[error("Failed to send command to the node runner: {0}")]
250 SendCommand(#[from] mpsc::SendError),
251 #[error("Node runner was dropped")]
253 NodeRunnerDropped,
254 #[error("Failed to bootstrap a peer.")]
256 Bootstrap,
257}
258
259impl From<oneshot::Canceled> for BootstrapError {
260 #[inline]
261 fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
262 Self::NodeRunnerDropped
263 }
264}
265
266#[derive(Debug, Clone)]
268#[must_use = "Node doesn't do anything if dropped"]
269pub struct Node {
270 shared: Arc<Shared>,
271}
272
273impl Node {
274 pub(crate) fn new(shared: Arc<Shared>) -> Self {
275 Self { shared }
276 }
277
278 pub fn id(&self) -> PeerId {
280 self.shared.id
281 }
282
283 pub async fn get_value(
285 &self,
286 key: Multihash,
287 ) -> Result<impl Stream<Item = PeerRecord>, GetValueError> {
288 let permit = self.shared.rate_limiter.acquire_permit().await;
289 let (result_sender, result_receiver) = mpsc::unbounded();
290
291 self.shared
292 .command_sender
293 .clone()
294 .send(Command::GetValue {
295 key,
296 result_sender,
297 permit,
298 })
299 .await?;
300
301 Ok(result_receiver)
303 }
304
305 pub async fn put_value(
307 &self,
308 key: Multihash,
309 value: Vec<u8>,
310 ) -> Result<impl Stream<Item = ()>, PutValueError> {
311 let permit = self.shared.rate_limiter.acquire_permit().await;
312 let (result_sender, result_receiver) = mpsc::unbounded();
313
314 self.shared
315 .command_sender
316 .clone()
317 .send(Command::PutValue {
318 key,
319 value,
320 result_sender,
321 permit,
322 })
323 .await?;
324
325 Ok(result_receiver)
327 }
328
329 pub async fn subscribe(&self, topic: Sha256Topic) -> Result<TopicSubscription, SubscribeError> {
331 let permit = self.shared.rate_limiter.acquire_permit().await;
332 let (result_sender, result_receiver) = oneshot::channel();
333
334 self.shared
335 .command_sender
336 .clone()
337 .send(Command::Subscribe {
338 topic: topic.clone(),
339 result_sender,
340 })
341 .await?;
342
343 let CreatedSubscription {
344 subscription_id,
345 receiver,
346 } = result_receiver.await??;
347
348 Ok(TopicSubscription {
349 topic: Some(topic),
350 subscription_id,
351 command_sender: Some(self.shared.command_sender.clone()),
352 receiver,
353 _permit: permit,
354 })
355 }
356
357 pub async fn publish(&self, topic: Sha256Topic, message: Vec<u8>) -> Result<(), PublishError> {
359 let _permit = self.shared.rate_limiter.acquire_permit().await;
360 let (result_sender, result_receiver) = oneshot::channel();
361
362 self.shared
363 .command_sender
364 .clone()
365 .send(Command::Publish {
366 topic,
367 message,
368 result_sender,
369 })
370 .await?;
371
372 result_receiver.await?.map_err(PublishError::Publish)
373 }
374
375 async fn send_generic_request_internal<Request>(
376 &self,
377 peer_id: PeerId,
378 addresses: Vec<Multiaddr>,
379 request: Request,
380 acquire_permit: bool,
381 ) -> Result<Request::Response, SendRequestError>
382 where
383 Request: GenericRequest,
384 {
385 let _permit = if acquire_permit {
386 Some(self.shared.rate_limiter.acquire_permit().await)
387 } else {
388 None
389 };
390
391 let (result_sender, result_receiver) = oneshot::channel();
392 let command = Command::GenericRequest {
393 peer_id,
394 addresses,
395 protocol_name: Request::PROTOCOL_NAME,
396 request: request.encode(),
397 result_sender,
398 };
399
400 self.shared.command_sender.clone().send(command).await?;
401
402 let result = result_receiver.await??;
403
404 Request::Response::decode(&mut result.as_slice()).map_err(Into::into)
405 }
406
407 pub async fn send_generic_request<Request>(
411 &self,
412 peer_id: PeerId,
413 addresses: Vec<Multiaddr>,
414 request: Request,
415 ) -> Result<Request::Response, SendRequestError>
416 where
417 Request: GenericRequest,
418 {
419 self.send_generic_request_internal(peer_id, addresses, request, true)
420 .await
421 }
422
423 pub async fn get_closest_peers(
425 &self,
426 key: Multihash,
427 ) -> Result<impl Stream<Item = PeerId>, GetClosestPeersError> {
428 self.get_closest_peers_internal(key, true).await
429 }
430
431 pub async fn get_closest_local_peers(
437 &self,
438 key: Multihash,
439 source: Option<PeerId>,
440 ) -> Result<Vec<(PeerId, Vec<Multiaddr>)>, GetClosestLocalPeersError> {
441 let (result_sender, result_receiver) = oneshot::channel();
442
443 self.shared
444 .command_sender
445 .clone()
446 .send(Command::GetClosestLocalPeers {
447 key,
448 source,
449 result_sender,
450 })
451 .await?;
452
453 Ok(result_receiver.await?)
454 }
455
456 async fn get_closest_peers_internal(
458 &self,
459 key: Multihash,
460 acquire_permit: bool,
461 ) -> Result<impl Stream<Item = PeerId>, GetClosestPeersError> {
462 let permit = if acquire_permit {
463 Some(self.shared.rate_limiter.acquire_permit().await)
464 } else {
465 None
466 };
467
468 trace!(?key, "Starting 'GetClosestPeers' request.");
469
470 let (result_sender, result_receiver) = mpsc::unbounded();
471
472 self.shared
473 .command_sender
474 .clone()
475 .send(Command::GetClosestPeers {
476 key,
477 result_sender,
478 permit,
479 })
480 .await?;
481
482 Ok(result_receiver)
484 }
485
486 pub async fn get_providers(
488 &self,
489 key: RecordKey,
490 ) -> Result<impl Stream<Item = PeerId>, GetProvidersError> {
491 self.get_providers_internal(key, true).await
492 }
493
494 async fn get_providers_internal(
495 &self,
496 key: RecordKey,
497 acquire_permit: bool,
498 ) -> Result<impl Stream<Item = PeerId>, GetProvidersError> {
499 let permit = if acquire_permit {
500 Some(self.shared.rate_limiter.acquire_permit().await)
501 } else {
502 None
503 };
504
505 let (result_sender, result_receiver) = mpsc::unbounded();
506
507 trace!(key = hex::encode(&key), "Starting 'get_providers' request");
508
509 self.shared
510 .command_sender
511 .clone()
512 .send(Command::GetProviders {
513 key,
514 result_sender,
515 permit,
516 })
517 .await?;
518
519 Ok(result_receiver)
521 }
522
523 pub async fn ban_peer(&self, peer_id: PeerId) -> Result<(), mpsc::SendError> {
525 self.shared
526 .command_sender
527 .clone()
528 .send(Command::BanPeer { peer_id })
529 .await
530 }
531
532 #[doc(hidden)]
536 pub async fn dial(&self, address: Multiaddr) -> Result<(), mpsc::SendError> {
537 self.shared
538 .command_sender
539 .clone()
540 .send(Command::Dial { address })
541 .await
542 }
543
544 pub fn listeners(&self) -> Vec<Multiaddr> {
546 self.shared.listeners.lock().clone()
547 }
548
549 pub fn external_addresses(&self) -> Vec<Multiaddr> {
551 self.shared.external_addresses.lock().clone()
552 }
553
554 pub fn on_new_listener(&self, callback: HandlerFn<Multiaddr>) -> HandlerId {
556 self.shared.handlers.new_listener.add(callback)
557 }
558
559 pub fn on_num_established_peer_connections_change(
561 &self,
562 callback: HandlerFn<usize>,
563 ) -> HandlerId {
564 self.shared
565 .handlers
566 .num_established_peer_connections_change
567 .add(callback)
568 }
569
570 pub async fn connected_peers(&self) -> Result<Vec<PeerId>, ConnectedPeersError> {
572 let (result_sender, result_receiver) = oneshot::channel();
573
574 trace!("Starting `connected_peers` request");
575
576 self.shared
577 .command_sender
578 .clone()
579 .send(Command::ConnectedPeers { result_sender })
580 .await?;
581
582 result_receiver
583 .await
584 .map_err(|_| ConnectedPeersError::ConnectedPeers)
585 }
586
587 pub async fn connected_servers(&self) -> Result<Vec<PeerId>, ConnectedPeersError> {
589 let (result_sender, result_receiver) = oneshot::channel();
590
591 trace!("Starting `connected_servers` request.");
592
593 self.shared
594 .command_sender
595 .clone()
596 .send(Command::ConnectedServers { result_sender })
597 .await?;
598
599 result_receiver
600 .await
601 .map_err(|_| ConnectedPeersError::ConnectedPeers)
602 }
603
604 pub async fn bootstrap(&self) -> Result<(), BootstrapError> {
606 let (result_sender, mut result_receiver) = mpsc::unbounded();
607
608 debug!("Starting `bootstrap` request");
609
610 self.shared
611 .command_sender
612 .clone()
613 .send(Command::Bootstrap {
614 result_sender: Some(result_sender),
615 })
616 .await?;
617
618 for step in 0.. {
619 let result = result_receiver.next().await;
620
621 if result.is_some() {
622 debug!(%step, "Kademlia bootstrapping...");
623 } else {
624 break;
625 }
626 }
627
628 Ok(())
629 }
630
631 pub fn on_connected_peer(&self, callback: HandlerFn<PeerId>) -> HandlerId {
633 self.shared.handlers.connected_peer.add(callback)
634 }
635
636 pub fn on_disconnected_peer(&self, callback: HandlerFn<PeerId>) -> HandlerId {
638 self.shared.handlers.disconnected_peer.add(callback)
639 }
640
641 pub fn on_discovered_peer(&self, callback: HandlerFn<PeerDiscovered>) -> HandlerId {
643 self.shared.handlers.peer_discovered.add(callback)
644 }
645
646 pub async fn get_requests_batch_handle(&self) -> NodeRequestsBatchHandle {
648 let _permit = self.shared.rate_limiter.acquire_permit().await;
649
650 NodeRequestsBatchHandle {
651 _permit,
652 node: self.clone(),
653 }
654 }
655
656 pub fn downgrade(&self) -> WeakNode {
658 WeakNode {
659 shared: Arc::downgrade(&self.shared),
660 }
661 }
662}
663
664#[derive(Debug, Clone)]
666pub struct WeakNode {
667 shared: Weak<Shared>,
668}
669
670impl WeakNode {
671 pub fn upgrade(&self) -> Option<Node> {
673 self.shared.upgrade().map(|shared| Node { shared })
674 }
675}
676
677pub struct NodeRequestsBatchHandle {
682 node: Node,
683 _permit: OwnedSemaphorePermit,
684}
685
686impl NodeRequestsBatchHandle {
687 pub async fn get_providers(
689 &self,
690 key: RecordKey,
691 ) -> Result<impl Stream<Item = PeerId>, GetProvidersError> {
692 self.node.get_providers_internal(key, false).await
693 }
694
695 pub async fn get_closest_peers(
697 &self,
698 key: Multihash,
699 ) -> Result<impl Stream<Item = PeerId>, GetClosestPeersError> {
700 self.node.get_closest_peers_internal(key, false).await
701 }
702 pub async fn send_generic_request<Request>(
706 &self,
707 peer_id: PeerId,
708 addresses: Vec<Multiaddr>,
709 request: Request,
710 ) -> Result<Request::Response, SendRequestError>
711 where
712 Request: GenericRequest,
713 {
714 self.node
715 .send_generic_request_internal(peer_id, addresses, request, false)
716 .await
717 }
718}