#[cfg(test)]
mod tests;
use crate::protocols::request_response::handlers::generic_request_handler::GenericRequest;
use crate::protocols::request_response::request_response_factory;
use crate::shared::{Command, CreatedSubscription, PeerDiscovered, Shared};
use crate::utils::multihash::Multihash;
use crate::utils::HandlerFn;
use bytes::Bytes;
use event_listener_primitives::HandlerId;
use futures::channel::{mpsc, oneshot};
use futures::{SinkExt, Stream, StreamExt};
use libp2p::gossipsub::{Sha256Topic, SubscriptionError};
use libp2p::kad::{PeerRecord, RecordKey};
use libp2p::{Multiaddr, PeerId};
use parity_scale_codec::Decode;
use std::pin::Pin;
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
use thiserror::Error;
use tokio::sync::OwnedSemaphorePermit;
use tracing::{debug, error, trace};
/// Stream of [`Bytes`] messages belonging to a single topic subscription.
///
/// Instances are built by `Node::subscribe`. When the value is dropped, the drop handler sends a
/// `Command::Unsubscribe` carrying the topic and subscription ID through the stored command
/// sender.
#[derive(Debug)]
#[pin_project::pin_project(PinnedDrop)]
pub struct TopicSubscription {
/// Subscribed topic; kept in an `Option` so the drop handler can take ownership of it.
topic: Option<Sha256Topic>,
/// Identifier of this subscription, passed back in the `Unsubscribe` command.
subscription_id: usize,
/// Sender for the command channel; kept in an `Option` so the drop handler can take it.
command_sender: Option<mpsc::Sender<Command>>,
/// Receiving half of the channel that yields the items of this stream.
#[pin]
receiver: mpsc::UnboundedReceiver<Bytes>,
/// Rate-limiter permit that is held for as long as the subscription value exists.
_permit: OwnedSemaphorePermit,
}
impl Stream for TopicSubscription {
type Item = Bytes;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().receiver.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.receiver.size_hint()
}
}
/// Sends `Command::Unsubscribe` for this subscription when it is dropped.
#[pin_project::pinned_drop]
impl PinnedDrop for TopicSubscription {
    /// Takes the topic and command sender out of `self` and dispatches the unsubscribe command.
    ///
    /// # Panics
    ///
    /// Panics only if `topic` or `command_sender` is already `None`, which would mean the drop
    /// handler ran twice.
    fn drop(mut self: Pin<&mut Self>) {
        let topic = self
            .topic
            .take()
            .expect("Always specified on creation and only removed on drop; qed");
        let subscription_id = self.subscription_id;
        let mut command_sender = self
            .command_sender
            .take()
            .expect("Always specified on creation and only removed on drop; qed");

        let command = Command::Unsubscribe {
            topic,
            subscription_id,
        };

        // `tokio::spawn` panics when there is no runtime context on the current thread, and a
        // panic inside `drop` may abort the process. Only spawn when a runtime is available.
        match tokio::runtime::Handle::try_current() {
            Ok(handle) => {
                handle.spawn(async move {
                    // Send errors are ignored: unsubscribing is best-effort.
                    let _ = command_sender.send(command).await;
                });
            }
            Err(_) => {
                // No runtime to drive an async send; attempt a non-blocking send instead.
                // Errors are ignored for the same best-effort reason as above.
                let _ = command_sender.try_send(command);
            }
        }
    }
}
/// Errors returned by `Node::get_value`.
#[derive(Debug, Error)]
pub enum GetValueError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// Produced when converting from `oneshot::Canceled`.
#[error("Node runner was dropped")]
NodeRunnerDropped,
}
impl From<oneshot::Canceled> for GetValueError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by `Node::put_value`.
#[derive(Debug, Error)]
pub enum PutValueError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// Produced when converting from `oneshot::Canceled`.
#[error("Node runner was dropped")]
NodeRunnerDropped,
}
impl From<oneshot::Canceled> for PutValueError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by the closest-peers lookups (`get_closest_peers` on `Node` and on
/// `NodeRequestsBatchHandle`).
#[derive(Debug, Error)]
pub enum GetClosestPeersError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// Produced when converting from `oneshot::Canceled`.
#[error("Node runner was dropped")]
NodeRunnerDropped,
}
impl From<oneshot::Canceled> for GetClosestPeersError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by `Node::get_closest_local_peers`.
#[derive(Debug, Error)]
pub enum GetClosestLocalPeersError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// The oneshot response channel was cancelled before a result arrived.
#[error("Node runner was dropped")]
NodeRunnerDropped,
}
impl From<oneshot::Canceled> for GetClosestLocalPeersError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by `Node::subscribe`.
#[derive(Debug, Error)]
pub enum SubscribeError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// The oneshot response channel was cancelled before a result arrived.
#[error("Node runner was dropped")]
NodeRunnerDropped,
/// The subscription attempt itself failed; wraps the gossipsub `SubscriptionError`.
#[error("Failed to create subscription: {0}")]
Subscription(#[from] SubscriptionError),
}
impl From<oneshot::Canceled> for SubscribeError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by `Node::publish`.
#[derive(Debug, Error)]
pub enum PublishError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// The oneshot response channel was cancelled before a result arrived.
#[error("Node runner was dropped")]
NodeRunnerDropped,
/// Publishing the message failed; wraps the gossipsub `PublishError`.
#[error("Failed to publish message: {0}")]
Publish(#[from] libp2p::gossipsub::PublishError),
}
impl From<oneshot::Canceled> for PublishError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by the provider lookups (`get_providers` on `Node` and on
/// `NodeRequestsBatchHandle`).
#[derive(Debug, Error)]
pub enum GetProvidersError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// Produced when converting from `oneshot::Canceled`.
#[error("Node runner was dropped")]
NodeRunnerDropped,
/// Generic provider lookup failure.
// NOTE(review): this variant is not constructed anywhere in the visible part of this file.
#[error("Failed to get providers.")]
GetProviders,
}
impl From<oneshot::Canceled> for GetProvidersError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by the generic request methods (`send_generic_request` on `Node` and on
/// `NodeRequestsBatchHandle`).
#[derive(Debug, Error)]
pub enum SendRequestError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// The oneshot response channel was cancelled before a result arrived.
#[error("Node runner was dropped")]
NodeRunnerDropped,
/// The request-response layer reported a failure; wraps its `RequestFailure`.
#[error("Underlying protocol returned an error: {0}")]
ProtocolFailure(#[from] request_response_factory::RequestFailure),
/// The response bytes could not be decoded into the expected response type.
#[error("Received incorrectly formatted response: {0}")]
IncorrectResponseFormat(#[from] parity_scale_codec::Error),
}
impl From<oneshot::Canceled> for SendRequestError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by `Node::connected_peers` and `Node::connected_servers`.
#[derive(Debug, Error)]
pub enum ConnectedPeersError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// Produced when converting from `oneshot::Canceled`.
#[error("Node runner was dropped")]
NodeRunnerDropped,
/// Returned by `connected_peers` / `connected_servers` when awaiting the response channel
/// fails.
#[error("Failed to get connected peers.")]
ConnectedPeers,
}
impl From<oneshot::Canceled> for ConnectedPeersError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Errors returned by `Node::bootstrap`.
#[derive(Debug, Error)]
pub enum BootstrapError {
/// Sending the command over the command channel failed; wraps the channel's send error.
#[error("Failed to send command to the node runner: {0}")]
SendCommand(#[from] mpsc::SendError),
/// Produced when converting from `oneshot::Canceled`.
#[error("Node runner was dropped")]
NodeRunnerDropped,
/// Generic bootstrap failure.
// NOTE(review): this variant is not constructed anywhere in the visible part of this file.
#[error("Failed to bootstrap a peer.")]
Bootstrap,
}
impl From<oneshot::Canceled> for BootstrapError {
#[inline]
fn from(oneshot::Canceled: oneshot::Canceled) -> Self {
Self::NodeRunnerDropped
}
}
/// Cloneable handle used to issue commands and register event callbacks.
///
/// All clones share the same [`Shared`] state through an [`Arc`]. Operations are carried out by
/// sending `Command` values over the shared command channel.
#[derive(Debug, Clone)]
#[must_use = "Node doesn't do anything if dropped"]
pub struct Node {
/// State shared between all clones of this handle.
shared: Arc<Shared>,
}
impl Node {
pub(crate) fn new(shared: Arc<Shared>) -> Self {
Self { shared }
}
pub fn id(&self) -> PeerId {
self.shared.id
}
/// Requests the records stored under `key`.
///
/// Waits for a rate-limiter permit, then sends `Command::GetValue` (carrying the key, the result
/// sender and the permit) over the command channel. The returned stream is the receiving half
/// of the unbounded result channel.
///
/// # Errors
///
/// Returns [`GetValueError::SendCommand`] if the command cannot be sent.
pub async fn get_value(
    &self,
    key: Multihash,
) -> Result<impl Stream<Item = PeerRecord>, GetValueError> {
    let permit = self.shared.rate_limiter.acquire_permit().await;
    let (result_sender, result_receiver) = mpsc::unbounded();

    let command = Command::GetValue {
        key,
        result_sender,
        permit,
    };
    self.shared.command_sender.clone().send(command).await?;

    Ok(result_receiver)
}
/// Requests that `value` be stored under `key`.
///
/// Waits for a rate-limiter permit, then sends `Command::PutValue` (carrying the key, value,
/// result sender and permit) over the command channel. The returned stream is the receiving
/// half of the unbounded result channel.
///
/// # Errors
///
/// Returns [`PutValueError::SendCommand`] if the command cannot be sent.
pub async fn put_value(
    &self,
    key: Multihash,
    value: Vec<u8>,
) -> Result<impl Stream<Item = ()>, PutValueError> {
    let permit = self.shared.rate_limiter.acquire_permit().await;
    let (result_sender, result_receiver) = mpsc::unbounded();

    let command = Command::PutValue {
        key,
        value,
        result_sender,
        permit,
    };
    self.shared.command_sender.clone().send(command).await?;

    Ok(result_receiver)
}
pub async fn subscribe(&self, topic: Sha256Topic) -> Result<TopicSubscription, SubscribeError> {
let permit = self.shared.rate_limiter.acquire_permit().await;
let (result_sender, result_receiver) = oneshot::channel();
self.shared
.command_sender
.clone()
.send(Command::Subscribe {
topic: topic.clone(),
result_sender,
})
.await?;
let CreatedSubscription {
subscription_id,
receiver,
} = result_receiver.await??;
Ok(TopicSubscription {
topic: Some(topic),
subscription_id,
command_sender: Some(self.shared.command_sender.clone()),
receiver,
_permit: permit,
})
}
/// Publishes `message` on `topic`.
///
/// Holds a rate-limiter permit for the duration of the call, sends `Command::Publish`, and
/// awaits the reply on a oneshot channel.
///
/// # Errors
///
/// * [`PublishError::SendCommand`] if the command cannot be sent.
/// * [`PublishError::NodeRunnerDropped`] if the reply channel is cancelled.
/// * [`PublishError::Publish`] if the reply carries a publish error.
pub async fn publish(&self, topic: Sha256Topic, message: Vec<u8>) -> Result<(), PublishError> {
    let _permit = self.shared.rate_limiter.acquire_permit().await;
    let (result_sender, result_receiver) = oneshot::channel();

    let command = Command::Publish {
        topic,
        message,
        result_sender,
    };
    self.shared.command_sender.clone().send(command).await?;

    let publish_result = result_receiver.await?;
    publish_result.map_err(PublishError::Publish)
}
/// Sends an encoded request to `peer_id` and decodes the reply.
///
/// When `acquire_permit` is `true`, a rate-limiter permit is acquired first and held until the
/// function returns; otherwise no permit is taken. The request is sent as
/// `Command::GenericRequest` using `Request::PROTOCOL_NAME`, and the reply bytes are decoded
/// into `Request::Response`.
///
/// # Errors
///
/// * [`SendRequestError::SendCommand`] if the command cannot be sent.
/// * [`SendRequestError::NodeRunnerDropped`] if the reply channel is cancelled.
/// * [`SendRequestError::ProtocolFailure`] if the reply carries a request failure.
/// * [`SendRequestError::IncorrectResponseFormat`] if decoding the reply fails.
async fn send_generic_request_internal<Request>(
    &self,
    peer_id: PeerId,
    addresses: Vec<Multiaddr>,
    request: Request,
    acquire_permit: bool,
) -> Result<Request::Response, SendRequestError>
where
    Request: GenericRequest,
{
    let _permit = if acquire_permit {
        Some(self.shared.rate_limiter.acquire_permit().await)
    } else {
        None
    };

    let (result_sender, result_receiver) = oneshot::channel();
    self.shared
        .command_sender
        .clone()
        .send(Command::GenericRequest {
            peer_id,
            addresses,
            protocol_name: Request::PROTOCOL_NAME,
            request: request.encode(),
            result_sender,
        })
        .await?;

    // First `?` handles channel cancellation, second `?` handles the protocol failure.
    let response_bytes = result_receiver.await??;
    let response = Request::Response::decode(&mut response_bytes.as_slice())?;

    Ok(response)
}
/// Sends `request` to `peer_id` and returns the decoded response.
///
/// Delegates to `send_generic_request_internal` with permit acquisition enabled.
///
/// # Errors
///
/// Returns whatever [`SendRequestError`] the internal implementation reports.
pub async fn send_generic_request<Request>(
    &self,
    peer_id: PeerId,
    addresses: Vec<Multiaddr>,
    request: Request,
) -> Result<Request::Response, SendRequestError>
where
    Request: GenericRequest,
{
    let acquire_permit = true;
    self.send_generic_request_internal(peer_id, addresses, request, acquire_permit)
        .await
}
/// Requests the peers closest to `key` and returns them as a stream.
///
/// Delegates to `get_closest_peers_internal` with permit acquisition enabled.
///
/// # Errors
///
/// Returns whatever [`GetClosestPeersError`] the internal implementation reports.
pub async fn get_closest_peers(
    &self,
    key: Multihash,
) -> Result<impl Stream<Item = PeerId>, GetClosestPeersError> {
    let acquire_permit = true;
    self.get_closest_peers_internal(key, acquire_permit).await
}
/// Requests the locally known peers closest to `key`, together with their addresses.
///
/// Sends `Command::GetClosestLocalPeers` (carrying the key, the optional `source` peer and the
/// result sender) and awaits the reply on a oneshot channel. No rate-limiter permit is acquired
/// here.
///
/// # Errors
///
/// * [`GetClosestLocalPeersError::SendCommand`] if the command cannot be sent.
/// * [`GetClosestLocalPeersError::NodeRunnerDropped`] if the reply channel is cancelled.
pub async fn get_closest_local_peers(
    &self,
    key: Multihash,
    source: Option<PeerId>,
) -> Result<Vec<(PeerId, Vec<Multiaddr>)>, GetClosestLocalPeersError> {
    let (result_sender, result_receiver) = oneshot::channel();

    let command = Command::GetClosestLocalPeers {
        key,
        source,
        result_sender,
    };
    self.shared.command_sender.clone().send(command).await?;

    let closest_peers = result_receiver.await?;
    Ok(closest_peers)
}
/// Sends `Command::GetClosestPeers` for `key` and returns the result stream.
///
/// When `acquire_permit` is `true`, a rate-limiter permit is acquired and passed along with the
/// command; otherwise the command carries `None`.
///
/// # Errors
///
/// Returns [`GetClosestPeersError::SendCommand`] if the command cannot be sent.
async fn get_closest_peers_internal(
    &self,
    key: Multihash,
    acquire_permit: bool,
) -> Result<impl Stream<Item = PeerId>, GetClosestPeersError> {
    let permit = if acquire_permit {
        Some(self.shared.rate_limiter.acquire_permit().await)
    } else {
        None
    };

    trace!(?key, "Starting 'GetClosestPeers' request.");

    let (result_sender, result_receiver) = mpsc::unbounded();
    let command = Command::GetClosestPeers {
        key,
        result_sender,
        permit,
    };
    self.shared.command_sender.clone().send(command).await?;

    Ok(result_receiver)
}
/// Requests the providers for `key` and returns them as a stream.
///
/// Delegates to `get_providers_internal` with permit acquisition enabled.
///
/// # Errors
///
/// Returns whatever [`GetProvidersError`] the internal implementation reports.
pub async fn get_providers(
    &self,
    key: RecordKey,
) -> Result<impl Stream<Item = PeerId>, GetProvidersError> {
    let acquire_permit = true;
    self.get_providers_internal(key, acquire_permit).await
}
/// Sends `Command::GetProviders` for `key` and returns the result stream.
///
/// When `acquire_permit` is `true`, a rate-limiter permit is acquired and passed along with the
/// command; otherwise the command carries `None`.
///
/// # Errors
///
/// Returns [`GetProvidersError::SendCommand`] if the command cannot be sent.
async fn get_providers_internal(
    &self,
    key: RecordKey,
    acquire_permit: bool,
) -> Result<impl Stream<Item = PeerId>, GetProvidersError> {
    let permit = if acquire_permit {
        Some(self.shared.rate_limiter.acquire_permit().await)
    } else {
        None
    };

    let (result_sender, result_receiver) = mpsc::unbounded();

    trace!(key = hex::encode(&key), "Starting 'get_providers' request");

    let command = Command::GetProviders {
        key,
        result_sender,
        permit,
    };
    self.shared.command_sender.clone().send(command).await?;

    Ok(result_receiver)
}
/// Sends `Command::BanPeer` for `peer_id` over the command channel.
///
/// # Errors
///
/// Returns the channel's send error if the command cannot be sent.
pub async fn ban_peer(&self, peer_id: PeerId) -> Result<(), mpsc::SendError> {
    let mut command_sender = self.shared.command_sender.clone();
    command_sender.send(Command::BanPeer { peer_id }).await
}
/// Sends `Command::Dial` for `address` over the command channel.
///
/// # Errors
///
/// Returns the channel's send error if the command cannot be sent.
#[doc(hidden)]
pub async fn dial(&self, address: Multiaddr) -> Result<(), mpsc::SendError> {
    let mut command_sender = self.shared.command_sender.clone();
    command_sender.send(Command::Dial { address }).await
}
/// Returns a copy of the listener addresses currently stored in the shared state.
pub fn listeners(&self) -> Vec<Multiaddr> {
    let listeners = self.shared.listeners.lock();
    listeners.clone()
}
/// Returns a copy of the external addresses currently stored in the shared state.
pub fn external_addresses(&self) -> Vec<Multiaddr> {
    let external_addresses = self.shared.external_addresses.lock();
    external_addresses.clone()
}
/// Adds `callback` to the `new_listener` handlers and returns its [`HandlerId`].
pub fn on_new_listener(&self, callback: HandlerFn<Multiaddr>) -> HandlerId {
    let handlers = &self.shared.handlers;
    handlers.new_listener.add(callback)
}
/// Adds `callback` to the `num_established_peer_connections_change` handlers and returns its
/// [`HandlerId`].
pub fn on_num_established_peer_connections_change(
    &self,
    callback: HandlerFn<usize>,
) -> HandlerId {
    let handlers = &self.shared.handlers;
    handlers.num_established_peer_connections_change.add(callback)
}
/// Requests the list of connected peers.
///
/// Sends `Command::ConnectedPeers` and awaits the reply on a oneshot channel.
///
/// # Errors
///
/// * [`ConnectedPeersError::SendCommand`] if the command cannot be sent.
/// * [`ConnectedPeersError::ConnectedPeers`] if awaiting the reply fails.
pub async fn connected_peers(&self) -> Result<Vec<PeerId>, ConnectedPeersError> {
    let (result_sender, result_receiver) = oneshot::channel();

    trace!("Starting `connected_peers` request");

    let command = Command::ConnectedPeers { result_sender };
    self.shared.command_sender.clone().send(command).await?;

    match result_receiver.await {
        Ok(peers) => Ok(peers),
        Err(_) => Err(ConnectedPeersError::ConnectedPeers),
    }
}
/// Requests the list of connected servers.
///
/// Sends `Command::ConnectedServers` and awaits the reply on a oneshot channel.
///
/// # Errors
///
/// * [`ConnectedPeersError::SendCommand`] if the command cannot be sent.
/// * [`ConnectedPeersError::ConnectedPeers`] if awaiting the reply fails.
pub async fn connected_servers(&self) -> Result<Vec<PeerId>, ConnectedPeersError> {
    let (result_sender, result_receiver) = oneshot::channel();

    trace!("Starting `connected_servers` request.");

    let command = Command::ConnectedServers { result_sender };
    self.shared.command_sender.clone().send(command).await?;

    match result_receiver.await {
        Ok(servers) => Ok(servers),
        Err(_) => Err(ConnectedPeersError::ConnectedPeers),
    }
}
/// Starts a bootstrap and waits until its progress channel is closed.
///
/// Sends `Command::Bootstrap` with the sending half of an unbounded channel, then drains the
/// receiving half, logging one debug line per received item. Returns once the channel yields
/// `None`.
///
/// # Errors
///
/// Returns [`BootstrapError::SendCommand`] if the command cannot be sent.
pub async fn bootstrap(&self) -> Result<(), BootstrapError> {
    let (result_sender, mut result_receiver) = mpsc::unbounded();

    debug!("Starting `bootstrap` request");

    let command = Command::Bootstrap {
        result_sender: Some(result_sender),
    };
    self.shared.command_sender.clone().send(command).await?;

    // Count progress notifications until the sender side is dropped.
    let mut step = 0;
    while result_receiver.next().await.is_some() {
        debug!(%step, "Kademlia bootstrapping...");
        step += 1;
    }

    Ok(())
}
/// Adds `callback` to the `connected_peer` handlers and returns its [`HandlerId`].
pub fn on_connected_peer(&self, callback: HandlerFn<PeerId>) -> HandlerId {
    let handlers = &self.shared.handlers;
    handlers.connected_peer.add(callback)
}
/// Adds `callback` to the `disconnected_peer` handlers and returns its [`HandlerId`].
pub fn on_disconnected_peer(&self, callback: HandlerFn<PeerId>) -> HandlerId {
    let handlers = &self.shared.handlers;
    handlers.disconnected_peer.add(callback)
}
/// Adds `callback` to the `peer_discovered` handlers and returns its [`HandlerId`].
pub fn on_discovered_peer(&self, callback: HandlerFn<PeerDiscovered>) -> HandlerId {
    let handlers = &self.shared.handlers;
    handlers.peer_discovered.add(callback)
}
/// Acquires one rate-limiter permit and returns a batch handle that owns it.
///
/// The handle holds a clone of this node together with the permit.
pub async fn get_requests_batch_handle(&self) -> NodeRequestsBatchHandle {
    let permit = self.shared.rate_limiter.acquire_permit().await;

    NodeRequestsBatchHandle {
        node: self.clone(),
        _permit: permit,
    }
}
/// Creates a [`WeakNode`] that holds only a weak reference to the shared state.
pub fn downgrade(&self) -> WeakNode {
    let shared = Arc::downgrade(&self.shared);
    WeakNode { shared }
}
}
/// Weak counterpart of [`Node`] that does not keep the shared state alive.
///
/// Obtained through `Node::downgrade`; turn it back into a [`Node`] with `upgrade`.
#[derive(Debug, Clone)]
pub struct WeakNode {
/// Weak reference to the state shared with the originating [`Node`].
shared: Weak<Shared>,
}
impl WeakNode {
    /// Tries to obtain a strong [`Node`] handle.
    ///
    /// Returns `None` when the shared state has already been dropped.
    pub fn upgrade(&self) -> Option<Node> {
        let shared = self.shared.upgrade()?;
        Some(Node { shared })
    }
}
/// Handle for issuing several requests under a single rate-limiter permit.
///
/// Returned by `Node::get_requests_batch_handle`. The permit is held for as long as the handle
/// exists, and the handle's request methods call the node's internal implementations with
/// permit acquisition disabled.
///
/// `Debug` is derived for consistency with the other public types in this module (`Node`,
/// `WeakNode`, `TopicSubscription`).
#[derive(Debug)]
pub struct NodeRequestsBatchHandle {
    /// Node used to carry out the requests.
    node: Node,
    /// Rate-limiter permit shared by all requests made through this handle.
    _permit: OwnedSemaphorePermit,
}
impl NodeRequestsBatchHandle {
    /// Requests the providers for `key` without acquiring an additional rate-limiter permit.
    ///
    /// # Errors
    ///
    /// Returns whatever [`GetProvidersError`] the node's internal implementation reports.
    pub async fn get_providers(
        &mut self,
        key: RecordKey,
    ) -> Result<impl Stream<Item = PeerId>, GetProvidersError> {
        let acquire_permit = false;
        self.node.get_providers_internal(key, acquire_permit).await
    }

    /// Requests the peers closest to `key` without acquiring an additional rate-limiter permit.
    ///
    /// # Errors
    ///
    /// Returns whatever [`GetClosestPeersError`] the node's internal implementation reports.
    pub async fn get_closest_peers(
        &mut self,
        key: Multihash,
    ) -> Result<impl Stream<Item = PeerId>, GetClosestPeersError> {
        let acquire_permit = false;
        self.node
            .get_closest_peers_internal(key, acquire_permit)
            .await
    }

    /// Sends `request` to `peer_id` without acquiring an additional rate-limiter permit.
    ///
    /// # Errors
    ///
    /// Returns whatever [`SendRequestError`] the node's internal implementation reports.
    pub async fn send_generic_request<Request>(
        &mut self,
        peer_id: PeerId,
        addresses: Vec<Multiaddr>,
        request: Request,
    ) -> Result<Request::Response, SendRequestError>
    where
        Request: GenericRequest,
    {
        let acquire_permit = false;
        self.node
            .send_generic_request_internal(peer_id, addresses, request, acquire_permit)
            .await
    }
}