diff --git a/Cargo.lock b/Cargo.lock
index af9cd465..76d4709a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1927,6 +1927,7 @@ dependencies = [
 name = "tower-batch"
 version = "0.1.0"
 dependencies = [
+ "color-eyre",
 "ed25519-zebra",
 "futures",
 "futures-core",
diff --git a/tower-batch/Cargo.toml b/tower-batch/Cargo.toml
index d1fda004..59753c21 100644
--- a/tower-batch/Cargo.toml
+++ b/tower-batch/Cargo.toml
@@ -21,3 +21,4 @@ tokio = { version = "0.2", features = ["full"]}
 tracing-error = "0.1.2"
 tracing-subscriber = "0.2.5"
 tracing = "0.1.15"
+color-eyre = "0.3.4"
diff --git a/tower-batch/src/error.rs b/tower-batch/src/error.rs
index 418957fc..7d35a8a1 100644
--- a/tower-batch/src/error.rs
+++ b/tower-batch/src/error.rs
@@ -1,47 +1,12 @@
 //! Error types for the `Batch` middleware.
 
-use crate::BoxError;
-use std::{fmt, sync::Arc};
-
-/// An error produced by a `Service` wrapped by a `Batch`.
-#[derive(Debug)]
-pub struct ServiceError {
-    inner: Arc<BoxError>,
-}
+use std::fmt;
 
 /// An error produced when the batch worker closes unexpectedly.
 pub struct Closed {
     _p: (),
 }
 
-// ===== impl ServiceError =====
-
-impl ServiceError {
-    pub(crate) fn new(inner: BoxError) -> ServiceError {
-        let inner = Arc::new(inner);
-        ServiceError { inner }
-    }
-
-    // Private to avoid exposing `Clone` trait as part of the public API
-    pub(crate) fn clone(&self) -> ServiceError {
-        ServiceError {
-            inner: self.inner.clone(),
-        }
-    }
-}
-
-impl fmt::Display for ServiceError {
-    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
-        write!(fmt, "batching service failed: {}", self.inner)
-    }
-}
-
-impl std::error::Error for ServiceError {
-    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
-        Some(&**self.inner)
-    }
-}
-
 // ===== impl Closed =====
 
 impl Closed {
diff --git a/tower-batch/src/future.rs b/tower-batch/src/future.rs
index ed96ce3f..300f79bf 100644
--- a/tower-batch/src/future.rs
+++ b/tower-batch/src/future.rs
@@ -4,47 +4,68 @@ use super::{error::Closed, message};
 use futures_core::ready;
 use pin_project::pin_project;
 use std::{
+    fmt::Debug,
     future::Future,
     pin::Pin,
     task::{Context, Poll},
 };
+use tower::Service;
 
 /// Future that completes when the batch processing is complete.
 #[pin_project]
 #[derive(Debug)]
-pub struct ResponseFuture<T> {
+pub struct ResponseFuture<T, E, R>
+where
+    T: Service<crate::BatchControl<R>>,
+{
     #[pin]
-    state: ResponseState<T>,
+    state: ResponseState<T, E, R>,
 }
 
 #[pin_project(project = ResponseStateProj)]
-#[derive(Debug)]
-enum ResponseState<T> {
-    Failed(Option<crate::BoxError>),
-    Rx(#[pin] message::Rx<T>),
-    Poll(#[pin] T),
+enum ResponseState<T, E, R>
+where
+    T: Service<crate::BatchControl<R>>,
+{
+    Failed(Option<E>),
+    Rx(#[pin] message::Rx<T::Future, T::Error>),
+    Poll(#[pin] T::Future),
 }
 
-impl<T> ResponseFuture<T> {
-    pub(crate) fn new(rx: message::Rx<T>) -> Self {
+impl<T, E, R> Debug for ResponseState<T, E, R>
+where
+    T: Service<crate::BatchControl<R>>,
+{
+    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        todo!()
+    }
+}
+
+impl<T, E, R> ResponseFuture<T, E, R>
+where
+    T: Service<crate::BatchControl<R>>,
+{
+    pub(crate) fn new(rx: message::Rx<T::Future, T::Error>) -> Self {
         ResponseFuture {
             state: ResponseState::Rx(rx),
         }
     }
 
-    pub(crate) fn failed(err: crate::BoxError) -> Self {
+    pub(crate) fn failed(err: E) -> Self {
         ResponseFuture {
             state: ResponseState::Failed(Some(err)),
         }
     }
 }
 
-impl<F, T, E> Future for ResponseFuture<F>
+impl<S, E, R> Future for ResponseFuture<S, E, R>
 where
-    F: Future<Output = Result<T, E>>,
-    E: Into<crate::BoxError>,
+    S: Service<crate::BatchControl<R>>,
+    S::Future: Future<Output = Result<S::Response, S::Error>>,
+    S::Error: Into<E>,
+    crate::error::Closed: Into<E>,
 {
-    type Output = Result<T, crate::BoxError>;
+    type Output = Result<S::Response, E>;
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         let mut this = self.project();
diff --git a/tower-batch/src/layer.rs b/tower-batch/src/layer.rs
index 8fda7ba3..a7631ac5 100644
--- a/tower-batch/src/layer.rs
+++ b/tower-batch/src/layer.rs
@@ -9,13 +9,14 @@ use tower::Service;
 /// which means that this layer can only be used on the Tokio runtime.
 ///
 /// See the module documentation for more details.
-pub struct BatchLayer<Request> {
+pub struct BatchLayer<Request, E> {
     max_items: usize,
     max_latency: std::time::Duration,
     _p: PhantomData<fn(Request)>,
+    _e: PhantomData<E>,
 }
 
-impl<Request> BatchLayer<Request> {
+impl<Request, E> BatchLayer<Request, E> {
     /// Creates a new `BatchLayer`.
 ///
 /// The wrapper is responsible for telling the inner service when to flush a
@@ -28,25 +29,28 @@ impl<Request> BatchLayer<Request> {
             max_items,
             max_latency,
             _p: PhantomData,
+            _e: PhantomData,
         }
     }
 }
 
-impl<S, Request> Layer<S> for BatchLayer<Request>
+impl<S, Request, E> Layer<S> for BatchLayer<Request, E>
 where
     S: Service<BatchControl<Request>> + Send + 'static,
     S::Future: Send,
-    S::Error: Into<crate::BoxError> + Send + Sync,
+    S::Error: Clone + Into<E> + Send + Sync,
     Request: Send + 'static,
+    E: Clone + Send + 'static,
+    crate::error::Closed: Into<E>,
 {
-    type Service = Batch<S, Request>;
+    type Service = Batch<S, Request, E>;
 
     fn layer(&self, service: S) -> Self::Service {
         Batch::new(service, self.max_items, self.max_latency)
     }
 }
 
-impl<Request> fmt::Debug for BatchLayer<Request> {
+impl<Request, E> fmt::Debug for BatchLayer<Request, E> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("BufferLayer")
             .field("max_items", &self.max_items)
diff --git a/tower-batch/src/message.rs b/tower-batch/src/message.rs
index dc73a6ad..7e433592 100644
--- a/tower-batch/src/message.rs
+++ b/tower-batch/src/message.rs
@@ -1,16 +1,15 @@
-use super::error::ServiceError;
 use tokio::sync::oneshot;
 
 /// Message sent to the batch worker
 #[derive(Debug)]
-pub(crate) struct Message<Request, T> {
+pub(crate) struct Message<Request, T, E> {
     pub(crate) request: Request,
-    pub(crate) tx: Tx<T>,
+    pub(crate) tx: Tx<T, E>,
     pub(crate) span: tracing::Span,
 }
 
 /// Response sender
-pub(crate) type Tx<T> = oneshot::Sender<Result<T, ServiceError>>;
+pub(crate) type Tx<T, E> = oneshot::Sender<Result<T, E>>;
 
 /// Response receiver
-pub(crate) type Rx<T> = oneshot::Receiver<Result<T, ServiceError>>;
+pub(crate) type Rx<T, E> = oneshot::Receiver<Result<T, E>>;
diff --git a/tower-batch/src/service.rs b/tower-batch/src/service.rs
index 28c4e713..cba07b60 100644
--- a/tower-batch/src/service.rs
+++ b/tower-batch/src/service.rs
@@ -6,7 +6,10 @@ use super::{
 };
 
 use futures_core::ready;
-use std::task::{Context, Poll};
+use std::{
+    marker::PhantomData,
+    task::{Context, Poll},
+};
 use tokio::sync::{mpsc, oneshot};
 use tower::Service;
 
@@ -14,18 +17,23 @@ use tower::Service;
 ///
 /// See the module documentation for more details.
 #[derive(Debug)]
-pub struct Batch<T, Request>
+pub struct Batch<T, Request, E>
 where
     T: Service<BatchControl<Request>>,
 {
-    tx: mpsc::Sender<Message<Request, T::Future>>,
-    handle: Handle,
+    tx: mpsc::Sender<Message<Request, T::Future, T::Error>>,
+    handle: Handle<E>,
+    _error_type: PhantomData<E>,
 }
 
-impl<T, Request> Batch<T, Request>
+impl<T, Request, E> Batch<T, Request, E>
 where
     T: Service<BatchControl<Request>>,
-    T::Error: Into<crate::BoxError>,
+    T::Error: Into<E>,
+    E: Send + 'static,
+    crate::error::Closed: Into<E>,
+    // crate::error::Closed: Into<<T as Service<BatchControl<Request>>>::Error> + Send + Sync + 'static,
+    // crate::error::ServiceError: Into<<T as Service<BatchControl<Request>>>::Error> + Send + Sync + 'static,
 {
     /// Creates a new `Batch` wrapping `service`.
     ///
@@ -41,29 +49,35 @@ where
     where
         T: Send + 'static,
         T::Future: Send,
-        T::Error: Send + Sync,
+        T::Error: Send + Sync + Clone,
         Request: Send + 'static,
     {
         // XXX(hdevalence): is this bound good
         let (tx, rx) = mpsc::channel(1);
         let (handle, worker) = Worker::new(service, rx, max_items, max_latency);
         tokio::spawn(worker.run());
-        Batch { tx, handle }
+        Batch {
+            tx,
+            handle,
+            _error_type: PhantomData,
+        }
     }
 
-    fn get_worker_error(&self) -> crate::BoxError {
+    fn get_worker_error(&self) -> E {
         self.handle.get_error_on_closed()
     }
 }
 
-impl<T, Request> Service<Request> for Batch<T, Request>
+impl<T, Request, E> Service<Request> for Batch<T, Request, E>
 where
     T: Service<BatchControl<Request>>,
-    T::Error: Into<crate::BoxError>,
+    crate::error::Closed: Into<E>,
+    T::Error: Into<E>,
+    E: Send + 'static,
 {
     type Response = T::Response;
-    type Error = crate::BoxError;
-    type Future = ResponseFuture<T::Future>;
+    type Error = E;
+    type Future = ResponseFuture<T, E, Request>;
 
     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         // If the inner service has errored, then we error here.
@@ -113,6 +127,7 @@ where
         Self {
             tx: self.tx.clone(),
             handle: self.handle.clone(),
+            _error_type: PhantomData,
         }
     }
 }
diff --git a/tower-batch/src/worker.rs b/tower-batch/src/worker.rs
index 8d4ab367..8b314ed7 100644
--- a/tower-batch/src/worker.rs
+++ b/tower-batch/src/worker.rs
@@ -1,11 +1,14 @@
 use super::{
-    error::{Closed, ServiceError},
+    error::Closed,
     message::{self, Message},
     BatchControl,
 };
 use futures::future::TryFutureExt;
 use pin_project::pin_project;
-use std::sync::{Arc, Mutex};
+use std::{
+    marker::PhantomData,
+    sync::{Arc, Mutex},
+};
 use tokio::{
     stream::StreamExt,
     sync::mpsc,
@@ -23,36 +26,37 @@ use tracing_futures::Instrument;
 /// implement (only call).
 #[pin_project]
 #[derive(Debug)]
-pub struct Worker<T, Request>
+pub struct Worker<T, Request, E>
 where
     T: Service<BatchControl<Request>>,
-    T::Error: Into<crate::BoxError>,
+    T::Error: Into<E>,
 {
-    rx: mpsc::Receiver<Message<Request, T::Future>>,
+    rx: mpsc::Receiver<Message<Request, T::Future, T::Error>>,
     service: T,
-    failed: Option<ServiceError>,
-    handle: Handle,
+    failed: Option<T::Error>,
+    handle: Handle<E>,
     max_items: usize,
     max_latency: std::time::Duration,
+    _error_type: PhantomData<E>,
 }
 
 /// Get the error out
 #[derive(Debug)]
-pub(crate) struct Handle {
-    inner: Arc<Mutex<Option<ServiceError>>>,
+pub(crate) struct Handle<E> {
+    inner: Arc<Mutex<Option<E>>>,
 }
 
-impl<T, Request> Worker<T, Request>
+impl<T, Request, E> Worker<T, Request, E>
 where
     T: Service<BatchControl<Request>>,
-    T::Error: Into<crate::BoxError>,
+    T::Error: Into<E> + Clone,
 {
     pub(crate) fn new(
         service: T,
-        rx: mpsc::Receiver<Message<Request, T::Future>>,
+        rx: mpsc::Receiver<Message<Request, T::Future, T::Error>>,
         max_items: usize,
         max_latency: std::time::Duration,
-    ) -> (Handle, Worker<T, Request>) {
+    ) -> (Handle<E>, Worker<T, Request, E>) {
         let handle = Handle {
             inner: Arc::new(Mutex::new(None)),
         };
@@ -64,15 +68,16 @@ where
             failed: None,
             max_items,
             max_latency,
+            _error_type: PhantomData,
         };
 
         (handle, worker)
     }
 
-    async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
-        if let Some(ref failed) = self.failed {
+    async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future, T::Error>) {
+        if let Some(failed) = self.failed.clone() {
             tracing::trace!("notifying caller about worker failure");
-            let _ = tx.send(Err(failed.clone()));
+            let _ = tx.send(Err(failed));
         } else {
             match self.service.ready_and().await {
                 Ok(svc) => {
@@ -80,12 +85,11 @@ where
                     let rsp = svc.call(req.into());
                     let _ = tx.send(Ok(rsp));
                 }
                 Err(e) => {
-                    self.failed(e.into());
+                    self.failed(e);
                     let _ = tx.send(Err(self
                         .failed
-                        .as_ref()
-                        .expect("Worker::failed did not set self.failed?")
-                        .clone()));
+                        .clone()
+                        .expect("Worker::failed did not set self.failed?")));
                 }
             }
         }
@@ -98,7 +102,7 @@ where
             .and_then(|svc| svc.call(BatchControl::Flush))
             .await
         {
-            self.failed(e.into());
+            self.failed(e);
         }
     }
@@ -165,7 +169,7 @@ where
         }
     }
 
-    fn failed(&mut self, error: crate::BoxError) {
+    fn failed(&mut self, error: T::Error) {
         // The underlying service failed when we called `poll_ready` on it with the given `error`. We
         // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
         // an `Arc`, send that `Arc` to all pending requests, and store it so that subsequent
@@ -178,7 +182,6 @@ where
         // request. We do this by *first* exposing the error, *then* closing the channel used to
         // send more requests (so the client will see the error when the send fails), and *then*
         // sending the error to all outstanding requests.
-        let error = ServiceError::new(error);
 
         let mut inner = self.handle.inner.lock().unwrap();
 
@@ -187,7 +190,7 @@ where
             return;
         }
 
-        *inner = Some(error.clone());
+        *inner = Some(error.clone().into());
         drop(inner);
 
         self.rx.close();
@@ -200,19 +203,21 @@ where
     }
 }
 
-impl Handle {
-    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
+impl<E> Handle<E>
+where
+    crate::error::Closed: Into<E>,
+{
+    pub(crate) fn get_error_on_closed(&self) -> E {
         self.inner
             .lock()
             .unwrap()
-            .as_ref()
-            .map(|svc_err| svc_err.clone().into())
+            .take()
             .unwrap_or_else(|| Closed::new().into())
     }
 }
 
-impl Clone for Handle {
-    fn clone(&self) -> Handle {
+impl<E> Clone for Handle<E> {
+    fn clone(&self) -> Handle<E> {
         Handle {
             inner: self.inner.clone(),
         }
diff --git a/tower-batch/tests/ed25519.rs b/tower-batch/tests/ed25519.rs
index d1599fc5..e28bbeff 100644
--- a/tower-batch/tests/ed25519.rs
+++ b/tower-batch/tests/ed25519.rs
@@ -131,31 +131,31 @@ where
 }
 
 #[tokio::test]
-async fn batch_flushes_on_max_items() {
+async fn batch_flushes_on_max_items() -> color_eyre::Result<()> {
     use tokio::time::timeout;
     install_tracing();
 
     // Use a very long max_latency and a short timeout to check that
     // flushing is happening based on hitting max_items.
-    let verifier = Batch::new(Ed25519Verifier::new(), 10, Duration::from_secs(1000));
-    assert!(
-        timeout(Duration::from_secs(1), sign_and_verify(verifier, 100))
-            .await
-            .is_ok()
-    )
+    let verifier = Batch::<_, _, color_eyre::Report>::new(
+        Ed25519Verifier::new(),
+        10,
+        Duration::from_secs(1000),
+    );
+    Ok(timeout(Duration::from_secs(1), sign_and_verify(verifier, 100)).await?)
 }
 
 #[tokio::test]
-async fn batch_flushes_on_max_latency() {
+async fn batch_flushes_on_max_latency() -> color_eyre::Result<()> {
     use tokio::time::timeout;
     install_tracing();
 
     // Use a very high max_items and a short timeout to check that
     // flushing is happening based on hitting max_latency.
-    let verifier = Batch::new(Ed25519Verifier::new(), 100, Duration::from_millis(500));
-    assert!(
-        timeout(Duration::from_secs(1), sign_and_verify(verifier, 10))
-            .await
-            .is_ok()
-    )
+    let verifier = Batch::<_, _, color_eyre::Report>::new(
+        Ed25519Verifier::new(),
+        100,
+        Duration::from_millis(500),
+    );
+    Ok(timeout(Duration::from_secs(1), sign_and_verify(verifier, 10)).await?)
 }
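For reviewers: below is a minimal, self-contained sketch (not part of the diff) of how the new `E` type parameter is meant to be used, mirroring the `Batch::<_, _, color_eyre::Report>` turbofish in the updated ed25519 tests. `Echo` and `EchoError` are hypothetical stand-ins for a real inner service; a real batching service would buffer `Item`s and resolve them when it sees `Flush`.

```rust
use std::{
    task::{Context, Poll},
    time::Duration,
};

use futures::future::{ready, Ready};
use tower::{Service, ServiceExt};
use tower_batch::{Batch, BatchControl};

/// Hypothetical inner service: answers each `Item` immediately and treats
/// `Flush` as a no-op, just enough to satisfy `Service<BatchControl<u32>>`.
#[derive(Clone, Debug)]
struct Echo;

/// Hypothetical error type. It must be `Clone` to satisfy the new
/// `T::Error: Clone` bound, so the worker can fan a single failure out to
/// every pending caller.
#[derive(Clone, Debug)]
struct EchoError;

impl std::fmt::Display for EchoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "echo service failed")
    }
}

impl std::error::Error for EchoError {}

impl Service<BatchControl<u32>> for Echo {
    type Response = u32;
    type Error = EchoError;
    type Future = Ready<Result<u32, EchoError>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: BatchControl<u32>) -> Self::Future {
        match req {
            BatchControl::Item(n) => ready(Ok(n)),
            BatchControl::Flush => ready(Ok(0)),
        }
    }
}

#[tokio::main]
async fn main() -> color_eyre::Result<()> {
    // The third type parameter picks the caller-visible error type.
    // `color_eyre::Report` works here because it converts from any
    // `std::error::Error + Send + Sync + 'static`, which satisfies both
    // the `T::Error: Into<E>` and `error::Closed: Into<E>` bounds.
    let mut svc = Batch::<_, _, color_eyre::Report>::new(Echo, 10, Duration::from_millis(50));

    let rsp = svc.ready_and().await?.call(7).await?;
    assert_eq!(rsp, 7);
    Ok(())
}
```

The trade-off behind the new bounds: dropping the `Arc`-wrapped `ServiceError` means the worker no longer shares one boxed error among all pending callers, so the inner service's error must be `Clone`, and `Closed: Into<E>` must hold so `get_error_on_closed` can still synthesize an error after the worker has shut down.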