subspace_networking/constructor/
transport.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use crate::constructor::temporary_bans::TemporaryBans;
use libp2p::core::multiaddr::{Multiaddr, Protocol};
use libp2p::core::muxing::StreamMuxerBox;
use libp2p::core::transport::{Boxed, DialOpts, ListenerId, TransportError, TransportEvent};
use libp2p::core::Transport;
use libp2p::dns::tokio::Transport as TokioTransport;
use libp2p::tcp::tokio::Transport as TokioTcpTransport;
use libp2p::tcp::Config as GenTcpConfig;
use libp2p::yamux::Config as YamuxConfig;
use libp2p::{core, identity, noise, PeerId};
use parking_lot::Mutex;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tracing::debug;

// Builds the transport stack that LibP2P will communicate over along with a relay client.
pub(super) fn build_transport(
    allow_non_global_addresses_in_dht: bool,
    keypair: &identity::Keypair,
    temporary_bans: Arc<Mutex<TemporaryBans>>,
    timeout: Duration,
    yamux_config: YamuxConfig,
) -> io::Result<Boxed<(PeerId, StreamMuxerBox)>> {
    let wrapped_tcp = {
        let tcp_config = GenTcpConfig::default().nodelay(true);

        CustomTransportWrapper::new(
            TokioTcpTransport::new(tcp_config),
            allow_non_global_addresses_in_dht,
            temporary_bans,
        )
    };

    let tcp_upgraded = {
        let noise =
            noise::Config::new(keypair).expect("Signing libp2p-noise static DH keypair failed.");

        wrapped_tcp
            .upgrade(core::upgrade::Version::V1Lazy)
            .authenticate(noise)
            .multiplex(yamux_config)
            .timeout(timeout)
            .boxed()
    };

    Ok(TokioTransport::system(tcp_upgraded)?.boxed())
}

/// Transport decorator that applies dial-time policy (non-global IP filtering and temporary
/// peer bans) and otherwise forwards every call to the wrapped transport.
#[derive(Debug, Clone)]
struct CustomTransportWrapper<T> {
    /// Transport that performs the actual listening and dialing.
    base_transport: T,
    /// When `true`, dialing non-global IP addresses is permitted.
    allow_non_global_addresses: bool,
    /// Shared ban list consulted before each outgoing dial.
    temporary_bans: Arc<Mutex<TemporaryBans>>,
}

impl<T> CustomTransportWrapper<T> {
    /// Wraps `base_transport` with the given address policy and shared ban list.
    fn new(
        base_transport: T,
        allow_non_global_addresses: bool,
        temporary_bans: Arc<Mutex<TemporaryBans>>,
    ) -> Self {
        Self {
            temporary_bans,
            allow_non_global_addresses,
            base_transport,
        }
    }
}

impl<T> Transport for CustomTransportWrapper<T>
where
    T: Transport + Unpin,
    T::Error: From<io::Error>,
{
    type Output = T::Output;
    type Error = T::Error;
    type ListenerUpgrade = T::ListenerUpgrade;
    type Dial = T::Dial;

    /// Listening is not filtered; delegates directly to the base transport.
    fn listen_on(
        &mut self,
        id: ListenerId,
        addr: Multiaddr,
    ) -> Result<(), TransportError<Self::Error>> {
        self.base_transport.listen_on(id, addr)
    }

    /// Delegates directly to the base transport.
    fn remove_listener(&mut self, id: ListenerId) -> bool {
        self.base_transport.remove_listener(id)
    }

    /// Dials `addr` through the base transport unless a policy check rejects it.
    ///
    /// Checks, in order:
    /// - if the leading component is an IPv4/IPv6 address that is not global and
    ///   `allow_non_global_addresses` is `false`, returns
    ///   [`TransportError::MultiaddrNotSupported`];
    /// - if any `/p2p/<peer id>` component names a temporarily banned peer, returns
    ///   [`TransportError::Other`] wrapping an `io::Error`.
    fn dial(
        &mut self,
        addr: Multiaddr,
        opts: DialOpts,
    ) -> Result<Self::Dial, TransportError<Self::Error>> {
        match addr.iter().next() {
            Some(Protocol::Ip4(a)) => {
                if !(self.allow_non_global_addresses || a.is_global()) {
                    debug!(?a, "Not dialing non global IP address.");
                    return Err(TransportError::MultiaddrNotSupported(addr));
                }
            }
            Some(Protocol::Ip6(a)) => {
                if !(self.allow_non_global_addresses || a.is_global()) {
                    debug!(?a, "Not dialing non global IP address.");
                    return Err(TransportError::MultiaddrNotSupported(addr));
                }
            }
            _ => {
                // TODO: `/dns*` components are not checked here, so on its own this wrapper
                // does not catch DNS records pointing to private addresses.
                // NOTE(review): in `build_transport` a DNS transport wraps this one and should
                // resolve names before `dial` is reached — confirm before dropping this TODO.
            }
        }

        // Scan every component with a fresh iterator, including the first one, so that an
        // address starting with `/p2p/<peer id>` cannot skip the ban check. The lock is
        // released before acting on the result.
        let banned_peer = {
            let temporary_bans = self.temporary_bans.lock();
            addr.iter().find_map(|protocol| match protocol {
                Protocol::P2p(peer_id) if temporary_bans.is_banned(&peer_id) => Some(peer_id),
                _ => None,
            })
        };

        if let Some(peer_id) = banned_peer {
            debug!(%peer_id, "Not dialing temporarily banned peer.");
            let error = io::Error::new(io::ErrorKind::Other, "Peer is temporarily banned");
            return Err(TransportError::Other(error.into()));
        }

        self.base_transport.dial(addr, opts)
    }

    /// Forwards polling of listener events to the base transport.
    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
        Pin::new(&mut self.base_transport).poll(cx)
    }
}