From 4be0a8edc368b95ae62482515d45605fdaf3084a Mon Sep 17 00:00:00 2001 From: Henry de Valence Date: Tue, 14 Jul 2020 16:21:01 -0700 Subject: [PATCH] tower-fallback: add implementation. --- Cargo.lock | 19 ++++ Cargo.toml | 1 + tower-fallback/Cargo.toml | 17 ++++ tower-fallback/src/future.rs | 158 +++++++++++++++++++++++++++++++ tower-fallback/src/lib.rs | 9 ++ tower-fallback/src/service.rs | 54 +++++++++++ tower-fallback/tests/fallback.rs | 33 +++++++ 7 files changed, 291 insertions(+) create mode 100644 tower-fallback/Cargo.toml create mode 100644 tower-fallback/src/future.rs create mode 100644 tower-fallback/src/lib.rs create mode 100644 tower-fallback/src/service.rs create mode 100644 tower-fallback/tests/fallback.rs diff --git a/Cargo.lock b/Cargo.lock index 1d37855d..c62cc6cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -489,6 +489,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "either" +version = "1.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" + [[package]] name = "equihash" version = "0.1.0" @@ -2144,6 +2150,19 @@ dependencies = [ "tower-service", ] +[[package]] +name = "tower-fallback" +version = "0.1.0" +dependencies = [ + "either", + "futures-core", + "pin-project", + "tokio", + "tower", + "tracing", + "zebra-test", +] + [[package]] name = "tower-layer" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 79dc108b..a7ee0ba8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "zebra-test", "zebra-utils", "tower-batch", + "tower-fallback", ] [profile.dev] diff --git a/tower-fallback/Cargo.toml b/tower-fallback/Cargo.toml new file mode 100644 index 00000000..1ed992d2 --- /dev/null +++ b/tower-fallback/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "tower-fallback" +version = "0.1.0" +authors = ["Zcash Foundation "] +license = "MIT" +edition = "2018" + +[dependencies] +either = "1.5" +tower = "0.3" +futures-core = "0.3.5" +pin-project = "0.4.20" +tracing = "0.1" + +[dev-dependencies] +zebra-test = { path = "../zebra-test/" } +tokio = { version = "0.2", features = ["full"]} diff --git a/tower-fallback/src/future.rs b/tower-fallback/src/future.rs new file mode 100644 index 00000000..312ac66e --- /dev/null +++ b/tower-fallback/src/future.rs @@ -0,0 +1,158 @@ +//! Future types for the `Fallback` middleware. + +use std::{ + fmt::Debug, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use either::Either; +use futures_core::ready; +use pin_project::pin_project; +use tower::Service; + +/// Future that completes either with the first service's successful response, or +/// with the second service's response. +#[pin_project] +pub struct ResponseFuture +where + S1: Service, + S2: Service>::Response>, +{ + #[pin] + state: ResponseState, +} + +#[pin_project(project_replace, project = ResponseStateProj)] +enum ResponseState +where + S1: Service, + S2: Service, +{ + PollResponse1 { + #[pin] + fut: S1::Future, + req: Request, + svc2: S2, + }, + PollReady2 { + req: Request, + svc2: S2, + }, + PollResponse2 { + #[pin] + fut: S2::Future, + }, + // Placeholder value to swap into the pin projection of the enum so we can take ownership of the fields. + Tmp, +} + +impl ResponseFuture +where + S1: Service, + S2: Service>::Response>, +{ + pub(crate) fn new(fut: S1::Future, req: Request, svc2: S2) -> Self { + ResponseFuture { + state: ResponseState::PollResponse1 { fut, req, svc2 }, + } + } +} + +impl Future for ResponseFuture +where + S1: Service, + S2: Service>::Response>, +{ + type Output = Result< + >::Response, + Either<>::Error, >::Error>, + >; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + loop { + match this.state.as_mut().project() { + ResponseStateProj::PollResponse1 { fut, .. } => match ready!(fut.poll(cx)) { + Ok(rsp) => return Poll::Ready(Ok(rsp)), + Err(_) => { + tracing::debug!("got error from svc1, retrying on svc2"); + if let __ResponseStateProjectionOwned::PollResponse1 { req, svc2, .. } = + this.state.as_mut().project_replace(ResponseState::Tmp) + { + this.state.set(ResponseState::PollReady2 { req, svc2 }); + } else { + unreachable!(); + } + } + }, + ResponseStateProj::PollReady2 { svc2, .. } => match ready!(svc2.poll_ready(cx)) { + Err(e) => return Poll::Ready(Err(Either::Right(e))), + Ok(()) => { + if let __ResponseStateProjectionOwned::PollReady2 { mut svc2, req } = + this.state.as_mut().project_replace(ResponseState::Tmp) + { + this.state.set(ResponseState::PollResponse2 { + fut: svc2.call(req), + }); + } else { + unreachable!(); + } + } + }, + ResponseStateProj::PollResponse2 { fut } => { + return fut.poll(cx).map_err(Either::Right) + } + ResponseStateProj::Tmp => unreachable!(), + } + } + } +} + +impl Debug for ResponseFuture +where + S1: Service, + S2: Service>::Response>, + Request: Debug, + S1::Future: Debug, + S2: Debug, + S2::Future: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResponseFuture") + .field("state", &self.state) + .finish() + } +} + +impl Debug for ResponseState +where + S1: Service, + S2: Service>::Response>, + Request: Debug, + S1::Future: Debug, + S2: Debug, + S2::Future: Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ResponseState::PollResponse1 { fut, req, svc2 } => f + .debug_struct("ResponseState::PollResponse1") + .field("fut", fut) + .field("req", req) + .field("svc2", svc2) + .finish(), + ResponseState::PollReady2 { req, svc2 } => f + .debug_struct("ResponseState::PollReady2") + .field("req", req) + .field("svc2", svc2) + .finish(), + ResponseState::PollResponse2 { fut } => f + .debug_struct("ResponseState::PollResponse2") + .field("fut", fut) + .finish(), + ResponseState::Tmp => unreachable!(), + } + } +} diff --git a/tower-fallback/src/lib.rs b/tower-fallback/src/lib.rs new file mode 100644 index 00000000..705df28e --- /dev/null +++ b/tower-fallback/src/lib.rs @@ -0,0 +1,9 @@ +/// A service combinator that sends requests to a first service, then retries +/// processing on a second fallback service if the first service errors. +/// +/// TODO: similar code exists in linkerd and could be upstreamed into tower +pub mod future; +mod service; + +pub use self::service::Fallback; +pub use either::Either; diff --git a/tower-fallback/src/service.rs b/tower-fallback/src/service.rs new file mode 100644 index 00000000..cd171ff4 --- /dev/null +++ b/tower-fallback/src/service.rs @@ -0,0 +1,54 @@ +use super::future::ResponseFuture; + +use either::Either; +use std::task::{Context, Poll}; +use tower::Service; + +/// Provides fallback processing on a second service if the first service returned an error. +#[derive(Debug)] +pub struct Fallback +where + S2: Clone, +{ + svc1: S1, + svc2: S2, +} + +impl Clone for Fallback { + fn clone(&self) -> Self { + Self { + svc1: self.svc1.clone(), + svc2: self.svc2.clone(), + } + } +} + +impl Fallback { + /// Creates a new `Fallback` wrapping a pair of services. + /// + /// Requests are processed on `svc1`, and retried on `svc2` if `svc1` errored. + pub fn new(svc1: S1, svc2: S2) -> Self { + Self { svc1, svc2 } + } +} + +impl Service for Fallback +where + S1: Service, + S2: Service>::Response>, + S2: Clone, + Request: Clone, +{ + type Response = >::Response; + type Error = Either<>::Error, >::Error>; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.svc1.poll_ready(cx).map_err(Either::Left) + } + + fn call(&mut self, request: Request) -> Self::Future { + let request2 = request.clone(); + ResponseFuture::new(self.svc1.call(request), request2, self.svc2.clone()) + } +} diff --git a/tower-fallback/tests/fallback.rs b/tower-fallback/tests/fallback.rs new file mode 100644 index 00000000..3420fdcf --- /dev/null +++ b/tower-fallback/tests/fallback.rs @@ -0,0 +1,33 @@ +use tower::{service_fn, Service, ServiceExt}; +use tower_fallback::{Either, Fallback}; + +#[tokio::test] +async fn fallback() { + zebra_test::init(); + + // we'd like to use Transcript here but it can't handle errors :( + + let svc1 = service_fn(|val: u64| async move { + if val < 10 { + Ok(val) + } else { + Err("too big value on svc1") + } + }); + let svc2 = service_fn(|val: u64| async move { + if val < 20 { + Ok(100 + val) + } else { + Err("too big value on svc2") + } + }); + + let mut svc = Fallback::new(svc1, svc2); + + assert_eq!(svc.ready_and().await.unwrap().call(1).await, Ok(1)); + assert_eq!(svc.ready_and().await.unwrap().call(11).await, Ok(111)); + assert_eq!( + svc.ready_and().await.unwrap().call(21).await, + Err(Either::Right("too big value on svc2")) + ); +}