diff --git a/Cargo.lock b/Cargo.lock
index 3021451f..768b68b6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1923,6 +1923,17 @@ dependencies = [
  "tower-util",
 ]
 
+[[package]]
+name = "tower-batch"
+version = "0.1.0"
+dependencies = [
+ "futures-core",
+ "pin-project",
+ "tokio",
+ "tower",
+ "tracing",
+]
+
 [[package]]
 name = "tower-buffer"
 version = "0.3.0"
diff --git a/Cargo.toml b/Cargo.toml
index 440c177d..4243c625 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,6 @@
 [workspace]
 members = [
+    "tower-batch",
     "zebra-chain",
     "zebra-network",
     "zebra-state",
diff --git a/tower-batch/Cargo.toml b/tower-batch/Cargo.toml
new file mode 100644
index 00000000..8290e58c
--- /dev/null
+++ b/tower-batch/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "tower-batch"
+version = "0.1.0"
+authors = ["Zcash Foundation <zebra@zfnd.org>"]
+license = "MIT"
+edition = "2018"
+
+[dependencies]
+tokio = { version = "0.2", features = ["time"] }
+tower = "0.3"
+futures-core = "0.3.5"
+pin-project = "0.4.20"
+tracing = "0.1.15"
diff --git a/tower-batch/src/error.rs b/tower-batch/src/error.rs
new file mode 100644
index 00000000..f8753902
--- /dev/null
+++ b/tower-batch/src/error.rs
@@ -0,0 +1,65 @@
+//! Error types for the `Buffer` middleware.
+
+use crate::BoxError;
+use std::{fmt, sync::Arc};
+
+/// An error produced by a `Service` wrapped by a `Buffer`.
+#[derive(Debug)]
+pub struct ServiceError {
+    inner: Arc<BoxError>,
+}
+
+/// An error produced when a buffer's worker closes unexpectedly.
+pub struct Closed {
+    _p: (),
+}
+
+// ===== impl ServiceError =====
+
+impl ServiceError {
+    pub(crate) fn new(inner: BoxError) -> ServiceError {
+        let inner = Arc::new(inner);
+        ServiceError { inner }
+    }
+
+    // Private to avoid exposing `Clone` trait as part of the public API
+    pub(crate) fn clone(&self) -> ServiceError {
+        ServiceError {
+            inner: self.inner.clone(),
+        }
+    }
+}
+
+impl fmt::Display for ServiceError {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        write!(fmt, "buffered service failed: {}", self.inner)
+    }
+}
+
+impl std::error::Error for ServiceError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        Some(&**self.inner)
+    }
+}
+
+// ===== impl Closed =====
+
+impl Closed {
+    pub(crate) fn new() -> Self {
+        Closed { _p: () }
+    }
+}
+
+impl fmt::Debug for Closed {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        fmt.debug_tuple("Closed").finish()
+    }
+}
+
+impl fmt::Display for Closed {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        fmt.write_str("buffer's worker closed unexpectedly")
+    }
+}
+
+impl std::error::Error for Closed {}
diff --git a/tower-batch/src/future.rs b/tower-batch/src/future.rs
new file mode 100644
index 00000000..6b5ae641
--- /dev/null
+++ b/tower-batch/src/future.rs
@@ -0,0 +1,68 @@
+//! Future types for the `Buffer` middleware.
+
+use super::{error::Closed, message};
+use futures_core::ready;
+use pin_project::{pin_project, project};
+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
+/// Future that completes when the buffered service eventually services the submitted request.
+#[pin_project]
+#[derive(Debug)]
+pub struct ResponseFuture<T> {
+    #[pin]
+    state: ResponseState<T>,
+}
+
+#[pin_project]
+#[derive(Debug)]
+enum ResponseState<T> {
+    Failed(Option<crate::BoxError>),
+    Rx(#[pin] message::Rx<T>),
+    Poll(#[pin] T),
+}
+
+impl<T> ResponseFuture<T> {
+    pub(crate) fn new(rx: message::Rx<T>) -> Self {
+        ResponseFuture {
+            state: ResponseState::Rx(rx),
+        }
+    }
+
+    pub(crate) fn failed(err: crate::BoxError) -> Self {
+        ResponseFuture {
+            state: ResponseState::Failed(Some(err)),
+        }
+    }
+}
+
+impl<F, T, E> Future for ResponseFuture<F>
+where
+    F: Future<Output = Result<T, E>>,
+    E: Into<crate::BoxError>,
+{
+    type Output = Result<T, crate::BoxError>;
+
+    #[project]
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let mut this = self.project();
+
+        loop {
+            #[project]
+            match this.state.as_mut().project() {
+                ResponseState::Failed(e) => {
+                    return Poll::Ready(Err(e.take().expect("polled after error")));
+                }
+                ResponseState::Rx(rx) => match ready!(rx.poll(cx)) {
+                    Ok(Ok(f)) => this.state.set(ResponseState::Poll(f)),
+                    Ok(Err(e)) => return Poll::Ready(Err(e.into())),
+                    Err(_) => return Poll::Ready(Err(Closed::new().into())),
+                },
+                ResponseState::Poll(fut) => return fut.poll(cx).map_err(Into::into),
+            }
+        }
+    }
+}
diff --git a/tower-batch/src/layer.rs b/tower-batch/src/layer.rs
new file mode 100644
index 00000000..5bddc924
--- /dev/null
+++ b/tower-batch/src/layer.rs
@@ -0,0 +1,60 @@
+use super::service::Buffer;
+use std::{fmt, marker::PhantomData};
+use tower_layer::Layer;
+use tower_service::Service;
+
+/// Adds an mpsc buffer in front of an inner service.
+///
+/// The default Tokio executor is used to run the given service,
+/// which means that this layer can only be used on the Tokio runtime.
+///
+/// See the module documentation for more details.
+pub struct BufferLayer<Request> {
+    bound: usize,
+    _p: PhantomData<fn(Request)>,
+}
+
+impl<Request> BufferLayer<Request> {
+    /// Creates a new `BufferLayer` with the provided `bound`.
+    ///
+    /// `bound` gives the maximal number of requests that can be queued for the service before
+    /// backpressure is applied to callers.
+    ///
+    /// # A note on choosing a `bound`
+    ///
+    /// When `Buffer`'s implementation of `poll_ready` returns `Poll::Ready`, it reserves a
+    /// slot in the channel for the forthcoming `call()`. However, if this call doesn't arrive,
+    /// this reserved slot may be held up for a long time. As a result, it's advisable to set
+    /// `bound` to be at least the maximum number of concurrent requests the `Buffer` will see.
+    /// If you do not, all the slots in the buffer may be held up by futures that have just called
+    /// `poll_ready` but will not issue a `call`, which prevents other senders from issuing new
+    /// requests.
+    pub fn new(bound: usize) -> Self {
+        BufferLayer {
+            bound,
+            _p: PhantomData,
+        }
+    }
+}
+
+impl<S, Request> Layer<S> for BufferLayer<Request>
+where
+    S: Service<Request> + Send + 'static,
+    S::Future: Send,
+    S::Error: Into<crate::BoxError> + Send + Sync,
+    Request: Send + 'static,
+{
+    type Service = Buffer<S, Request>;
+
+    fn layer(&self, service: S) -> Self::Service {
+        Buffer::new(service, self.bound)
+    }
+}
+
+impl<Request> fmt::Debug for BufferLayer<Request> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("BufferLayer")
+            .field("bound", &self.bound)
+            .finish()
+    }
+}
diff --git a/tower-batch/src/lib.rs b/tower-batch/src/lib.rs
new file mode 100644
index 00000000..eebee513
--- /dev/null
+++ b/tower-batch/src/lib.rs
@@ -0,0 +1,11 @@
+pub mod error;
+pub mod future;
+mod layer;
+mod message;
+mod service;
+mod worker;
+
+type BoxError = Box<dyn std::error::Error + Send + Sync>;
+
+pub use self::layer::BufferLayer;
+pub use self::service::Buffer;
diff --git a/tower-batch/src/message.rs b/tower-batch/src/message.rs
new file mode 100644
index 00000000..6d13aa12
--- /dev/null
+++ b/tower-batch/src/message.rs
@@ -0,0 +1,16 @@
+use super::error::ServiceError;
+use tokio::sync::oneshot;
+
+/// Message sent over buffer
+#[derive(Debug)]
+pub(crate) struct Message<Request, Fut> {
+    pub(crate) request: Request,
+    pub(crate) tx: Tx<Fut>,
+    pub(crate) span: tracing::Span,
+}
+
+/// Response sender
+pub(crate) type Tx<Fut> = oneshot::Sender<Result<Fut, ServiceError>>;
+
+/// Response receiver
+pub(crate) type Rx<Fut> = oneshot::Receiver<Result<Fut, ServiceError>>;
diff --git a/tower-batch/src/service.rs b/tower-batch/src/service.rs
new file mode 100644
index 00000000..14d11ab3
--- /dev/null
+++ b/tower-batch/src/service.rs
@@ -0,0 +1,139 @@
+use super::{
+    future::ResponseFuture,
+    message::Message,
+    worker::{Handle, Worker},
+};
+
+use futures_core::ready;
+use std::task::{Context, Poll};
+use tokio::sync::{mpsc, oneshot};
+use tower::Service;
+
+/// Adds an mpsc buffer in front of an inner service.
+///
+/// See the module documentation for more details.
+#[derive(Debug)]
+pub struct Buffer<T, Request>
+where
+    T: Service<Request>,
+{
+    tx: mpsc::Sender<Message<Request, T::Future>>,
+    handle: Handle,
+}
+
+impl<T, Request> Buffer<T, Request>
+where
+    T: Service<Request>,
+    T::Error: Into<crate::BoxError>,
+{
+    /// Creates a new `Buffer` wrapping `service`.
+    ///
+    /// `bound` gives the maximal number of requests that can be queued for the service before
+    /// backpressure is applied to callers.
+    ///
+    /// The default Tokio executor is used to run the given service, which means that this method
+    /// must be called while on the Tokio runtime.
+    ///
+    /// # A note on choosing a `bound`
+    ///
+    /// When `Buffer`'s implementation of `poll_ready` returns `Poll::Ready`, it reserves a
+    /// slot in the channel for the forthcoming `call()`. However, if this call doesn't arrive,
+    /// this reserved slot may be held up for a long time. As a result, it's advisable to set
+    /// `bound` to be at least the maximum number of concurrent requests the `Buffer` will see.
+    /// If you do not, all the slots in the buffer may be held up by futures that have just called
+    /// `poll_ready` but will not issue a `call`, which prevents other senders from issuing new
+    /// requests.
+    pub fn new(service: T, bound: usize) -> Self
+    where
+        T: Send + 'static,
+        T::Future: Send,
+        T::Error: Send + Sync,
+        Request: Send + 'static,
+    {
+        let (tx, rx) = mpsc::channel(bound);
+        let (handle, worker) = Worker::new(service, rx);
+        tokio::spawn(worker);
+        Buffer { tx, handle }
+    }
+
+    /// Creates a new `Buffer` wrapping `service`, but returns the background worker.
+    ///
+    /// This is useful if you do not want to spawn directly onto the `tokio` runtime
+    /// but instead want to use your own executor. This will return the `Buffer` and
+    /// the background `Worker` that you can then spawn.
+    pub fn pair(service: T, bound: usize) -> (Buffer<T, Request>, Worker<T, Request>)
+    where
+        T: Send + 'static,
+        T::Error: Send + Sync,
+        Request: Send + 'static,
+    {
+        let (tx, rx) = mpsc::channel(bound);
+        let (handle, worker) = Worker::new(service, rx);
+        (Buffer { tx, handle }, worker)
+    }
+
+    fn get_worker_error(&self) -> crate::BoxError {
+        self.handle.get_error_on_closed()
+    }
+}
+
+impl<T, Request> Service<Request> for Buffer<T, Request>
+where
+    T: Service<Request>,
+    T::Error: Into<crate::BoxError>,
+{
+    type Response = T::Response;
+    type Error = crate::BoxError;
+    type Future = ResponseFuture<T::Future>;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        // If the inner service has errored, then we error here.
+        if let Err(_) = ready!(self.tx.poll_ready(cx)) {
+            Poll::Ready(Err(self.get_worker_error()))
+        } else {
+            Poll::Ready(Ok(()))
+        }
+    }
+
+    fn call(&mut self, request: Request) -> Self::Future {
+        // TODO:
+        // ideally we'd poll_ready again here so we don't allocate the oneshot
+        // if the try_send is about to fail, but sadly we can't call poll_ready
+        // outside of task context.
+        let (tx, rx) = oneshot::channel();
+
+        // Get the current Span so that we can explicitly propagate it to the worker.
+        // If we didn't do this, events on the worker related to this span wouldn't be counted
+        // towards that span since the worker would have no way of entering it.
+        let span = tracing::Span::current();
+        tracing::trace!(parent: &span, "sending request to buffer worker");
+        match self.tx.try_send(Message { request, span, tx }) {
+            Err(mpsc::error::TrySendError::Closed(_)) => {
+                ResponseFuture::failed(self.get_worker_error())
+            }
+            Err(mpsc::error::TrySendError::Full(_)) => {
+                // When `mpsc::Sender::poll_ready` returns `Ready`, a slot
+                // in the channel is reserved for the handle. Other `Sender`
+                // handles may not send a message using that slot. This
+                // guarantees capacity for `request`.
+                //
+                // Given this, the only way to hit this code path is if
+                // `poll_ready` has not been called & `Ready` returned.
+                panic!("buffer full; poll_ready must be called first");
+            }
+            Ok(_) => ResponseFuture::new(rx),
+        }
+    }
+}
+
+impl<T, Request> Clone for Buffer<T, Request>
+where
+    T: Service<Request>,
+{
+    fn clone(&self) -> Self {
+        Self {
+            tx: self.tx.clone(),
+            handle: self.handle.clone(),
+        }
+    }
+}
diff --git a/tower-batch/src/worker.rs b/tower-batch/src/worker.rs
new file mode 100644
index 00000000..b77b6dc4
--- /dev/null
+++ b/tower-batch/src/worker.rs
@@ -0,0 +1,228 @@
+use super::{
+    error::{Closed, ServiceError},
+    message::Message,
+};
+use futures_core::ready;
+use pin_project::pin_project;
+use std::sync::{Arc, Mutex};
+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
+use tokio::sync::mpsc;
+use tower::Service;
+
+/// Task that handles processing the buffer. This type should not be used
+/// directly; instead, `Buffer` requires an `Executor` that can accept this task.
+///
+/// The struct is `pub` in the private module and the type is *not* re-exported
+/// as part of the public API. This is the "sealed" pattern to include "private"
+/// types in public traits that are not meant for consumers of the library to
+/// implement (only call).
+#[pin_project]
+#[derive(Debug)]
+pub struct Worker<T, Request>
+where
+    T: Service<Request>,
+    T::Error: Into<crate::BoxError>,
+{
+    current_message: Option<Message<Request, T::Future>>,
+    rx: mpsc::Receiver<Message<Request, T::Future>>,
+    service: T,
+    finish: bool,
+    failed: Option<ServiceError>,
+    handle: Handle,
+}
+
+/// Get the error out
+#[derive(Debug)]
+pub(crate) struct Handle {
+    inner: Arc<Mutex<Option<ServiceError>>>,
+}
+
+impl<T, Request> Worker<T, Request>
+where
+    T: Service<Request>,
+    T::Error: Into<crate::BoxError>,
+{
+    pub(crate) fn new(
+        service: T,
+        rx: mpsc::Receiver<Message<Request, T::Future>>,
+    ) -> (Handle, Worker<T, Request>) {
+        let handle = Handle {
+            inner: Arc::new(Mutex::new(None)),
+        };
+
+        let worker = Worker {
+            current_message: None,
+            finish: false,
+            failed: None,
+            rx,
+            service,
+            handle: handle.clone(),
+        };
+
+        (handle, worker)
+    }
+
+    /// Return the next queued Message that hasn't been canceled.
+    ///
+    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
+    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
+    fn poll_next_msg(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
+        if self.finish {
+            // We've already received None and are shutting down
+            return Poll::Ready(None);
+        }
+
+        tracing::trace!("worker polling for next message");
+        if let Some(mut msg) = self.current_message.take() {
+            // poll_closed returns Poll::Ready if the receiver is dropped.
+            // Returning Pending means it is still alive, so we should still
+            // use it.
+            if msg.tx.poll_closed(cx).is_pending() {
+                tracing::trace!("resuming buffered request");
+                return Poll::Ready(Some((msg, false)));
+            }
+
+            tracing::trace!("dropping cancelled buffered request");
+        }
+
+        // Get the next request
+        while let Some(mut msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
+            if msg.tx.poll_closed(cx).is_pending() {
+                tracing::trace!("processing new request");
+                return Poll::Ready(Some((msg, true)));
+            }
+            // Otherwise, the request is canceled, so pop the next one.
+            tracing::trace!("dropping cancelled request");
+        }
+
+        Poll::Ready(None)
+    }
+
+    fn failed(&mut self, error: crate::BoxError) {
+        // The underlying service failed when we called `poll_ready` on it with the given `error`. We
+        // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
+        // an `Arc`, send that `Arc` to all pending requests, and store it so that subsequent
+        // requests will also fail with the same error.
+
+        // Note that we need to handle the case where some handle is concurrently trying to send us
+        // a request. We need to make sure that *either* the send of the request fails *or* it
+        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
+        // case where we send errors to all outstanding requests, and *then* the caller sends its
+        // request. We do this by *first* exposing the error, *then* closing the channel used to
+        // send more requests (so the client will see the error when the send fails), and *then*
+        // sending the error to all outstanding requests.
+        let error = ServiceError::new(error);
+
+        let mut inner = self.handle.inner.lock().unwrap();
+
+        if inner.is_some() {
+            // Future::poll was called after we've already errored out!
+            return;
+        }
+
+        *inner = Some(error.clone());
+        drop(inner);
+
+        self.rx.close();
+
+        // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
+        // which will trigger the `self.finish == true` phase. We just need to make sure that any
+        // requests that we receive before we've exhausted the receiver receive the error:
+        self.failed = Some(error);
+    }
+}
+
+impl<T, Request> Future for Worker<T, Request>
+where
+    T: Service<Request>,
+    T::Error: Into<crate::BoxError>,
+{
+    type Output = ();
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        if self.finish {
+            return Poll::Ready(());
+        }
+
+        loop {
+            match ready!(self.poll_next_msg(cx)) {
+                Some((msg, first)) => {
+                    let _guard = msg.span.enter();
+                    if let Some(ref failed) = self.failed {
+                        tracing::trace!("notifying caller about worker failure");
+                        let _ = msg.tx.send(Err(failed.clone()));
+                        continue;
+                    }
+
+                    // Wait for the service to be ready
+                    tracing::trace!(
+                        resumed = !first,
+                        message = "worker received request; waiting for service readiness"
+                    );
+                    match self.service.poll_ready(cx) {
+                        Poll::Ready(Ok(())) => {
+                            tracing::debug!(service.ready = true, message = "processing request");
+                            let response = self.service.call(msg.request);
+
+                            // Send the response future back to the sender.
+                            //
+                            // An error means the request had been canceled in-between
+                            // our calls, so the response future will just be dropped.
+                            tracing::trace!("returning response future");
+                            let _ = msg.tx.send(Ok(response));
+                        }
+                        Poll::Pending => {
+                            tracing::trace!(service.ready = false, message = "delay");
+                            // Put our current message back in its slot.
+                            drop(_guard);
+                            self.current_message = Some(msg);
+                            return Poll::Pending;
+                        }
+                        Poll::Ready(Err(e)) => {
+                            let error = e.into();
+                            tracing::debug!({ %error }, "service failed");
+                            drop(_guard);
+                            self.failed(error);
+                            let _ = msg.tx.send(Err(self
+                                .failed
+                                .as_ref()
+                                .expect("Worker::failed did not set self.failed?")
+                                .clone()));
+                        }
+                    }
+                }
+                None => {
+                    // No more requests _ever_.
+                    self.finish = true;
+                    return Poll::Ready(());
+                }
+            }
+        }
+    }
+}
+
+impl Handle {
+    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
+        self.inner
+            .lock()
+            .unwrap()
+            .as_ref()
+            .map(|svc_err| svc_err.clone().into())
+            .unwrap_or_else(|| Closed::new().into())
+    }
+}
+
+impl Clone for Handle {
+    fn clone(&self) -> Handle {
+        Handle {
+            inner: self.inner.clone(),
+        }
+    }
+}
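
Usage sketch (not part of the diff above): at this point `tower-batch` is still a near-verbatim copy of `tower-buffer`, so a caller wraps an inner `tower::Service` in `Buffer::new` (or `Buffer::pair` when spawning the worker on an executor of their choice) and then drives requests through `poll_ready`/`call` like any other service. The snippet below is a minimal, hypothetical example of that flow; the `Echo` service and the `main` wiring are illustrative assumptions, and it presumes a test or binary crate that also depends on `futures = "0.3"` and enables tokio 0.2's "macros" and "rt-threaded" features, none of which are added by this change.

use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::future::poll_fn;
use tower::Service;
use tower_batch::Buffer;

// A toy inner service that echoes its request back (hypothetical, for illustration only).
struct Echo;

impl Service<String> for Echo {
    type Response = String;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<String, Infallible>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Infallible>> {
        // Always ready: this toy service has no internal resources to wait on.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: String) -> Self::Future {
        Box::pin(async move { Ok(req) })
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // `Buffer::new` spawns the worker task onto the Tokio runtime, so it must be
    // called from inside one. The bound (10) caps how many requests can be queued
    // before callers see backpressure through `poll_ready`.
    let mut svc = Buffer::new(Echo, 10);

    // Reserve a channel slot, then send the request to the worker and await the
    // response future it hands back.
    poll_fn(|cx| svc.poll_ready(cx)).await?;
    let rsp = svc.call("hello".to_string()).await?;
    assert_eq!(rsp, "hello");
    Ok(())
}

As with `tower-buffer`, the backpressure behaviour depends on the `bound` passed to `Buffer::new`: `poll_ready` reserves a channel slot, so a caller that never follows up with `call` can pin that slot down, which is why the doc comments above recommend sizing `bound` to at least the expected number of concurrent callers.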