Make sure handshake version negotiation always has a timeout

As part of this change, refactor handshake version negotiation into its
own function.
This commit is contained in:
teor 2021-04-13 13:30:17 +10:00 committed by Deirdre Connolly
parent 43e792b9a4
commit ad272f2bee
1 changed files with 157 additions and 127 deletions

View File

@ -12,7 +12,7 @@ use futures::{
channel::{mpsc, oneshot},
future, FutureExt, SinkExt, StreamExt,
};
use tokio::{net::TcpStream, sync::broadcast, time::timeout};
use tokio::{net::TcpStream, sync::broadcast, task::JoinError, time::timeout};
use tokio_util::codec::Framed;
use tower::Service;
use tracing::{span, Level, Span};
@ -180,55 +180,21 @@ where
}
}
impl<S> Service<(TcpStream, SocketAddr)> for Handshake<S>
where
S: Service<Request, Response = Response, Error = BoxError> + Clone + Send + 'static,
S::Future: Send,
{
type Response = Client;
type Error = BoxError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (TcpStream, SocketAddr)) -> Self::Future {
let (tcp_stream, addr) = req;
let connector_span = span!(Level::INFO, "connector", addr = ?addr);
// set the peer connection span's parent to the global span, as it
// should exist independently of its creation source (inbound
// connection, crawler, initial peer, ...)
let connection_span = span!(parent: &self.parent_span, Level::INFO, "peer", addr = ?addr);
// Clone these upfront, so they can be moved into the future.
let nonces = self.nonces.clone();
let inbound_service = self.inbound_service.clone();
let timestamp_collector = self.timestamp_collector.clone();
let inv_collector = self.inv_collector.clone();
let network = self.config.network;
let our_addr = self.config.listen_addr;
let user_agent = self.user_agent.clone();
let our_services = self.our_services;
let relay = self.relay;
let fut = async move {
debug!("connecting to remote peer");
// CORRECTNESS
//
// As a defence-in-depth against hangs, every send or next on stream
// should be wrapped in a timeout.
let mut stream = Framed::new(
tcp_stream,
Codec::builder()
.for_network(network)
.with_metrics_label(addr.ip().to_string())
.finish(),
);
/// Negotiate the Zcash network protocol version with the remote peer
/// at `addr`, using the connection `peer_conn`.
///
/// We split `Handshake` into its components before calling this function,
/// to avoid infectious `Sync` bounds on the returned future.
pub async fn negotiate_version(
peer_conn: &mut Framed<TcpStream, Codec>,
addr: &SocketAddr,
config: Config,
nonces: Arc<Mutex<HashSet<Nonce>>>,
user_agent: String,
our_services: PeerServices,
relay: bool,
) -> Result<(Version, PeerServices), HandshakeError> {
// Create a random nonce for this connection
let local_nonce = Nonce::default();
nonces
.lock()
@ -255,25 +221,27 @@ where
let now = Utc::now().timestamp();
let timestamp = Utc.timestamp(now - now.rem_euclid(5 * 60), 0);
let version = Message::Version {
let our_version = Message::Version {
version: constants::CURRENT_VERSION,
services: our_services,
timestamp,
address_recv: (PeerServices::NODE_NETWORK, addr),
address_from: (our_services, our_addr),
address_recv: (PeerServices::NODE_NETWORK, *addr),
// TODO: detect external address (#1893)
address_from: (our_services, config.listen_addr),
nonce: local_nonce,
user_agent,
user_agent: user_agent.clone(),
// The protocol works fine if we don't reveal our current block height,
// and not sending it means we don't need to be connected to the chain state.
start_height: block::Height(0),
relay,
};
debug!(?version, "sending initial version message");
timeout(constants::REQUEST_TIMEOUT, stream.send(version)).await??;
debug!(?our_version, "sending initial version message");
peer_conn.send(our_version).await?;
let remote_msg = timeout(constants::REQUEST_TIMEOUT, stream.next())
.await?
let remote_msg = peer_conn
.next()
.await
.ok_or(HandshakeError::ConnectionClosed)??;
// Check that we got a Version and destructure its fields into the local scope.
@ -287,7 +255,7 @@ where
{
(nonce, services, version)
} else {
return Err(HandshakeError::UnexpectedMessage(Box::new(remote_msg)));
Err(HandshakeError::UnexpectedMessage(Box::new(remote_msg)))?
};
// Check for nonce reuse, indicating self-connection.
@ -299,18 +267,19 @@ where
nonce_reuse
};
if nonce_reuse {
return Err(HandshakeError::NonceReuse);
Err(HandshakeError::NonceReuse)?;
}
timeout(constants::REQUEST_TIMEOUT, stream.send(Message::Verack)).await??;
peer_conn.send(Message::Verack).await?;
let remote_msg = timeout(constants::REQUEST_TIMEOUT, stream.next())
.await?
let remote_msg = peer_conn
.next()
.await
.ok_or(HandshakeError::ConnectionClosed)??;
if let Message::Verack = remote_msg {
debug!("got verack from remote peer");
} else {
return Err(HandshakeError::UnexpectedMessage(Box::new(remote_msg)));
Err(HandshakeError::UnexpectedMessage(Box::new(remote_msg)))?;
}
// XXX in zcashd remote peer can only send one version message and
@ -335,11 +304,77 @@ where
// configured network, and height is the best tip's block
// height.
if remote_version < Version::min_for_upgrade(network, constants::MIN_NETWORK_UPGRADE) {
if remote_version < Version::min_for_upgrade(config.network, constants::MIN_NETWORK_UPGRADE) {
// Disconnect if peer is using an obsolete version.
return Err(HandshakeError::ObsoleteVersion(remote_version));
Err(HandshakeError::ObsoleteVersion(remote_version))?;
}
Ok((remote_version, remote_services))
}
impl<S> Service<(TcpStream, SocketAddr)> for Handshake<S>
where
S: Service<Request, Response = Response, Error = BoxError> + Clone + Send + 'static,
S::Future: Send,
{
type Response = Client;
type Error = BoxError;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: (TcpStream, SocketAddr)) -> Self::Future {
let (tcp_stream, addr) = req;
let connector_span = span!(Level::INFO, "connector", ?addr);
// set the peer connection span's parent to the global span, as it
// should exist independently of its creation source (inbound
// connection, crawler, initial peer, ...)
let connection_span = span!(parent: &self.parent_span, Level::INFO, "peer", ?addr);
// Clone these upfront, so they can be moved into the future.
let nonces = self.nonces.clone();
let inbound_service = self.inbound_service.clone();
let timestamp_collector = self.timestamp_collector.clone();
let inv_collector = self.inv_collector.clone();
let config = self.config.clone();
let user_agent = self.user_agent.clone();
let our_services = self.our_services;
let relay = self.relay;
let fut = async move {
debug!(?addr, "negotiating protocol version with remote peer");
// CORRECTNESS
//
// As a defence-in-depth against hangs, every send or next on stream
// should be wrapped in a timeout.
let mut peer_conn = Framed::new(
tcp_stream,
Codec::builder()
.for_network(config.network)
.with_metrics_label(addr.ip().to_string())
.finish(),
);
// Wrap the entire initial connection setup in a timeout.
let (remote_version, remote_services) = timeout(
constants::HANDSHAKE_TIMEOUT,
negotiate_version(
&mut peer_conn,
&addr,
config,
nonces,
user_agent,
our_services,
relay,
),
)
.await??;
// Set the connection's version to the minimum of the received version or our own.
let negotiated_version = std::cmp::min(remote_version, constants::CURRENT_VERSION);
@ -348,7 +383,7 @@ where
// XXX The tokio documentation says not to do this while any frames are still being processed.
// Since we don't know that here, another way might be to release the tcp
// stream from the unversioned Framed wrapper and construct a new one with a versioned codec.
let bare_codec = stream.codec_mut();
let bare_codec = peer_conn.codec_mut();
bare_codec.reconfigure_version(negotiated_version);
debug!("constructing client, spawning server");
@ -365,7 +400,7 @@ where
error_slot: slot.clone(),
};
let (peer_tx, peer_rx) = stream.split();
let (peer_tx, peer_rx) = peer_conn.split();
// Instrument the peer's rx and tx streams.
@ -389,6 +424,7 @@ where
// Every message and error must update the peer address state via
// the inbound_ts_collector.
let inbound_ts_collector = timestamp_collector.clone();
let inv_collector = inv_collector.clone();
let peer_rx = peer_rx
.then(move |msg| {
// Add a metric for inbound messages and errors.
@ -554,13 +590,7 @@ where
// Spawn a new task to drive this handshake.
tokio::spawn(fut.instrument(connector_span))
// This is required to get error types to line up.
// Probably there's a nicer way to express this using combinators.
.map(|x| match x {
Ok(Ok(client)) => Ok(client),
Ok(Err(handshake_err)) => Err(handshake_err.into()),
Err(join_err) => Err(join_err.into()),
})
.map(|x: Result<Result<Client, HandshakeError>, JoinError>| Ok(x??))
.boxed()
}
}