diff --git a/zebra-network/src/peer.rs b/zebra-network/src/peer.rs index fab29d70..1711ee9f 100644 --- a/zebra-network/src/peer.rs +++ b/zebra-network/src/peer.rs @@ -12,6 +12,8 @@ mod error; mod handshake; use client::ClientRequest; +use client::ClientRequestReceiver; +use client::InProgressClientRequest; use client::MustUseOneshotSender; use error::ErrorSlot; diff --git a/zebra-network/src/peer/client.rs b/zebra-network/src/peer/client.rs index 5792f655..33f67742 100644 --- a/zebra-network/src/peer/client.rs +++ b/zebra-network/src/peer/client.rs @@ -7,6 +7,7 @@ use std::{ use futures::{ channel::{mpsc, oneshot}, future, ready, + stream::{Stream, StreamExt}, }; use tower::Service; @@ -25,8 +26,32 @@ pub struct Client { /// A message from the `peer::Client` to the `peer::Server`. #[derive(Debug)] -#[must_use = "tx.send() must be called before drop"] pub(super) struct ClientRequest { + /// The actual request. + pub request: Request, + /// The return message channel, included because `peer::Client::call` returns a + /// future that may be moved around before it resolves. + pub tx: oneshot::Sender>, + /// The tracing context for the request, so that work the connection task does + /// processing messages in the context of this request will have correct context. + pub span: tracing::Span, +} + +/// A receiver for the `peer::Server`, which wraps a `mpsc::Receiver`, +/// converting `ClientRequest`s into `InProgressClientRequest`s. +#[derive(Debug)] +pub(super) struct ClientRequestReceiver { + /// The inner receiver + inner: mpsc::Receiver, +} + +/// A message from the `peer::Client` to the `peer::Server`, +/// after it has been received by the `peer::Server`. +/// +/// +#[derive(Debug)] +#[must_use = "tx.send() must be called before drop"] +pub(super) struct InProgressClientRequest { /// The actual request. pub request: Request, /// The return message channel, included because `peer::Client::call` returns a @@ -34,7 +59,15 @@ pub(super) struct ClientRequest { /// /// INVARIANT: `tx.send()` must be called before dropping `tx`. /// - /// JUSTIFICATION: the `peer::Client` will translate all `Request`s into a `ClientRequest` which it sends to a background task, and if the send replies with `Ok(())` it will assume that it is safe to unconditionally poll the `Receiver` tied to the `Sender` used to create the `ClientRequest`. + /// JUSTIFICATION: the `peer::Client` translates `Request`s into + /// `ClientRequest`s, which it sends to a background task. If the send is + /// `Ok(())`, it will assume that it is safe to unconditionally poll the + /// `Receiver` tied to the `Sender` used to create the `ClientRequest`. + /// + /// We enforce this invariant via the type system, by converting + /// `ClientRequest`s to `InProgressClientRequest`s when they are received by + /// the background task. These conversions are implemented by + /// `ClientRequestReceiver`. pub tx: MustUseOneshotSender>, /// The tracing context for the request, so that work the connection task does /// processing messages in the context of this request will have correct context. @@ -54,6 +87,49 @@ pub(super) struct MustUseOneshotSender { pub tx: Option>, } +impl From for InProgressClientRequest { + fn from(client_request: ClientRequest) -> Self { + let ClientRequest { request, tx, span } = client_request; + InProgressClientRequest { + request, + tx: tx.into(), + span, + } + } +} + +impl ClientRequestReceiver { + /// Forwards to `inner.close()` + pub fn close(&mut self) { + self.inner.close() + } +} + +impl Stream for ClientRequestReceiver { + type Item = InProgressClientRequest; + + /// Converts the successful result of `inner.poll_next()` to an + /// `InProgressClientRequest`. + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.inner.poll_next_unpin(cx) { + Poll::Ready(client_request) => Poll::Ready(client_request.map(Into::into)), + // `inner.poll_next_unpin` parks the task for this future + Poll::Pending => Poll::Pending, + } + } + + /// Returns `inner.size_hint()` + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl From> for ClientRequestReceiver { + fn from(rx: mpsc::Receiver) -> Self { + ClientRequestReceiver { inner: rx } + } +} + impl MustUseOneshotSender { /// Forwards `t` to `tx.send()`, and marks this sender as used. /// @@ -143,11 +219,7 @@ impl Service for Client { // request. let span = tracing::Span::current(); - match self.server_tx.try_send(ClientRequest { - request, - span, - tx: tx.into(), - }) { + match self.server_tx.try_send(ClientRequest { request, span, tx }) { Err(e) => { if e.is_disconnected() { let ClientRequest { tx, .. } = e.into_inner(); diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index bf046ea8..6b65160c 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -10,7 +10,6 @@ use std::{collections::HashSet, sync::Arc}; use futures::{ - channel::mpsc, future::{self, Either}, prelude::*, stream::Stream, @@ -34,7 +33,10 @@ use crate::{ BoxError, }; -use super::{ClientRequest, ErrorSlot, MustUseOneshotSender, PeerError, SharedPeerError}; +use super::{ + ClientRequestReceiver, ErrorSlot, InProgressClientRequest, MustUseOneshotSender, PeerError, + SharedPeerError, +}; #[derive(Debug)] pub(super) enum Handler { @@ -327,7 +329,9 @@ pub struct Connection { /// other state handling. pub(super) request_timer: Option, pub(super) svc: S, - pub(super) client_rx: mpsc::Receiver, + /// A `mpsc::Receiver` that converts its results to + /// `InProgressClientRequest` + pub(super) client_rx: ClientRequestReceiver, /// A slot for an error shared between the Connection and the Client that uses it. pub(super) error_slot: ErrorSlot, //pub(super) peer_rx: Rx, @@ -475,7 +479,7 @@ where // requests before we can return and complete the future. State::Failed => { match self.client_rx.next().await { - Some(ClientRequest { tx, span, .. }) => { + Some(InProgressClientRequest { tx, span, .. }) => { trace!( parent: &span, "erroring pending request to failed connection" @@ -535,11 +539,11 @@ where /// /// NOTE: the caller should use .instrument(msg.span) to instrument the function. #[instrument(skip(self))] - async fn handle_client_request(&mut self, req: ClientRequest) { + async fn handle_client_request(&mut self, req: InProgressClientRequest) { trace!(?req.request); use Request::*; use State::*; - let ClientRequest { request, tx, span } = req; + let InProgressClientRequest { request, tx, span } = req; if tx.is_canceled() { metrics::counter!("peer.canceled", 1); diff --git a/zebra-network/src/peer/handshake.rs b/zebra-network/src/peer/handshake.rs index 7ae1f011..e022fb92 100644 --- a/zebra-network/src/peer/handshake.rs +++ b/zebra-network/src/peer/handshake.rs @@ -435,7 +435,7 @@ where let server = Connection { state: connection::State::AwaitingRequest, svc: inbound_service, - client_rx: server_rx, + client_rx: server_rx.into(), error_slot: slot, peer_tx, request_timer: None, @@ -451,7 +451,7 @@ where let heartbeat_span = tracing::debug_span!(parent: connection_span, "heartbeat"); tokio::spawn( async move { - use super::client::ClientRequest; + use super::ClientRequest; use futures::future::Either; let mut shutdown_rx = shutdown_rx; @@ -466,16 +466,23 @@ where tracing::trace!(?request, "queueing heartbeat request"); match server_tx.try_send(ClientRequest { request, - tx: tx.into(), + tx, span: tracing::Span::current(), }) { Ok(()) => { match server_tx.flush().await { Ok(()) => {} Err(e) => { - // TODO: we can't get the client request for this failure, - // so we can't ensure the invariant holds - panic!("flushing client request failed: {:?}", e); + // We can't get the client request for this failure, + // so we can't send an error back here. But that's ok, + // because: + // - this error never happens (or it's very rare) + // - if the flush() fails, the server hasn't + // received the request + tracing::warn!( + "flushing client request failed: {:?}", + e + ); } } }