subspace_networking/protocols/request_response/handlers/
generic_request_handler.rs

//! Generic request-response handler, typically used with a type implementing [`GenericRequest`]
//! to significantly reduce boilerplate when implementing [`RequestHandler`].
3
4use crate::protocols::request_response::request_response_factory::{
5    IncomingRequest, OutgoingResponse, ProtocolConfig, RequestHandler,
6};
7use async_trait::async_trait;
8use futures::channel::mpsc;
9use futures::prelude::*;
10use libp2p::PeerId;
11use parity_scale_codec::{Decode, Encode};
12use std::pin::Pin;
13use std::sync::Arc;
14use tracing::{debug, trace};
15
/// Capacity of the bounded inbound request queue created for every handler instance.
///
/// The value is provisional and may be tuned once production feedback is available.
const REQUESTS_BUFFER_SIZE: usize = 50;
18
19/// Generic request with associated response
20pub trait GenericRequest: Encode + Decode + Send + Sync + 'static {
21    /// Defines request-response protocol name.
22    const PROTOCOL_NAME: &'static str;
23    /// Specifies log-parameters for tracing.
24    const LOG_TARGET: &'static str;
25    /// Response type that corresponds to this request
26    type Response: Encode + Decode + Send + Sync + 'static;
27}
28
29type RequestHandlerFn<Request> = Arc<
30    dyn (Fn(
31            PeerId,
32            Request,
33        )
34            -> Pin<Box<dyn Future<Output = Option<<Request as GenericRequest>::Response>> + Send>>)
35        + Send
36        + Sync
37        + 'static,
38>;
39
40/// Defines generic request-response protocol handler.
41pub struct GenericRequestHandler<Request: GenericRequest> {
42    request_receiver: mpsc::Receiver<IncomingRequest>,
43    request_handler: RequestHandlerFn<Request>,
44    protocol_config: ProtocolConfig,
45}
46
47impl<Request: GenericRequest> GenericRequestHandler<Request> {
48    /// Creates new [`GenericRequestHandler`] by given handler.
49    pub fn create<RH, Fut>(request_handler: RH) -> Box<dyn RequestHandler>
50    where
51        RH: (Fn(PeerId, Request) -> Fut) + Send + Sync + 'static,
52        Fut: Future<Output = Option<Request::Response>> + Send + 'static,
53    {
54        let (request_sender, request_receiver) = mpsc::channel(REQUESTS_BUFFER_SIZE);
55
56        let mut protocol_config = ProtocolConfig::new(Request::PROTOCOL_NAME);
57        protocol_config.inbound_queue = Some(request_sender);
58
59        Box::new(Self {
60            request_receiver,
61            request_handler: Arc::new(move |peer_id, request| {
62                Box::pin(request_handler(peer_id, request))
63            }),
64            protocol_config,
65        })
66    }
67
68    /// Invokes external protocol handler.
69    async fn handle_request(
70        &self,
71        peer: PeerId,
72        payload: Vec<u8>,
73    ) -> Result<Vec<u8>, RequestHandlerError> {
74        trace!(%peer, protocol=Request::LOG_TARGET, "Handling request...");
75        let request = Request::decode(&mut payload.as_slice())
76            .map_err(|_| RequestHandlerError::InvalidRequestFormat)?;
77        let response = (self.request_handler)(peer, request).await;
78
79        Ok(response.ok_or(RequestHandlerError::NoResponse)?.encode())
80    }
81}
82
83#[async_trait]
84impl<Request: GenericRequest> RequestHandler for GenericRequestHandler<Request> {
85    /// Run [`RequestHandler`].
86    async fn run(&mut self) {
87        while let Some(request) = self.request_receiver.next().await {
88            let IncomingRequest {
89                peer,
90                payload,
91                pending_response,
92            } = request;
93
94            match self.handle_request(peer, payload).await {
95                Ok(response_data) => {
96                    let response = OutgoingResponse {
97                        result: Ok(response_data),
98                        sent_feedback: None,
99                    };
100
101                    match pending_response.send(response) {
102                        Ok(()) => trace!(target = Request::LOG_TARGET, %peer, "Handled request",),
103                        Err(_) => debug!(
104                            target = Request::LOG_TARGET,
105                            protocol = Request::PROTOCOL_NAME,
106                            %peer,
107                            "Failed to handle request: {}",
108                            RequestHandlerError::SendResponse
109                        ),
110                    };
111                }
112                Err(e) => {
113                    debug!(
114                        target = Request::LOG_TARGET,
115                        protocol = Request::PROTOCOL_NAME,
116                        %e,
117                        "Failed to handle request.",
118                    );
119
120                    let response = OutgoingResponse {
121                        result: Err(()),
122                        sent_feedback: None,
123                    };
124
125                    if pending_response.send(response).is_err() {
126                        debug!(
127                            target = Request::LOG_TARGET,
128                            protocol = Request::PROTOCOL_NAME,
129                            %peer,
130                            "Failed to handle request: {}", RequestHandlerError::SendResponse
131                        );
132                    };
133                }
134            }
135        }
136    }
137
138    fn protocol_config(&self) -> ProtocolConfig {
139        self.protocol_config.clone()
140    }
141
142    fn protocol_name(&self) -> &'static str {
143        Request::PROTOCOL_NAME
144    }
145
146    fn clone_box(&self) -> Box<dyn RequestHandler> {
147        let (request_sender, request_receiver) = mpsc::channel(REQUESTS_BUFFER_SIZE);
148
149        let mut protocol_config = ProtocolConfig::new(Request::PROTOCOL_NAME);
150        protocol_config.inbound_queue = Some(request_sender);
151
152        Box::new(Self {
153            request_receiver,
154            request_handler: Arc::clone(&self.request_handler),
155            protocol_config,
156        })
157    }
158}
159
160#[derive(Debug, thiserror::Error)]
161enum RequestHandlerError {
162    #[error("Failed to send response.")]
163    SendResponse,
164
165    #[error("Incorrect request format.")]
166    InvalidRequestFormat,
167
168    #[error("No response.")]
169    NoResponse,
170}