From fc4b8c1e707e2b425549f6cf9f39fdf27ab439a1 Mon Sep 17 00:00:00 2001 From: Jane Lusby Date: Mon, 15 Mar 2021 10:21:29 -0700 Subject: [PATCH] add basic test for batch waker behaviour --- Cargo.lock | 85 +++++++++++++++++++++++- tower-batch/Cargo.toml | 2 + tower-batch/src/service.rs | 4 +- tower-batch/src/worker.rs | 31 ++++++++- tower-batch/tests/worker.rs | 124 ++++++++++++++++++++++++++++++++++++ 5 files changed, 241 insertions(+), 5 deletions(-) create mode 100644 tower-batch/tests/worker.rs diff --git a/Cargo.lock b/Cargo.lock index d2d76740..d6705446 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -131,6 +131,27 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +[[package]] +name = "async-stream" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3670df70cbc01729f901f94c887814b3c68db038aad1329a418bae178bc5295c" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3548b8efc9f8e8a5a0a2808c5bd8451a9031b9e5b879a79590304ae928b0a70" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.7", + "syn 1.0.60", +] + [[package]] name = "atomic-shim" version = "0.1.0" @@ -458,6 +479,12 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0dcbc35f504eb6fc275a6d20e4ebcda18cf50d40ba6fabff8c711fa16cb3b16" +[[package]] +name = "bytes" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" + [[package]] name = "canonical-path" version = "2.0.2" @@ -3283,6 +3310,16 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "tokio" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d56477f6ed99e10225f38f9f75f872f29b8b8bd8c0b946f63345bb144e9eeda" +dependencies = [ + "autocfg", + "pin-project-lite 0.2.4", +] + [[package]] name = "tokio-macros" version = "0.3.2" @@ -3306,6 +3343,30 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-stream" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c535f53c0cfa1acace62995a8994fc9cc1f12d202420da96ff306ee24d576469" +dependencies = [ + "futures-core", + "pin-project-lite 0.2.4", + "tokio 1.3.0", +] + +[[package]] +name = "tokio-test" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f58403903e94d4bc56805e46597fced893410b2e753e229d3f7f22423ea03f67" +dependencies = [ + "async-stream", + "bytes 1.0.1", + "futures-core", + "tokio 1.3.0", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.3.1" @@ -3367,7 +3428,7 @@ dependencies = [ "hdrhistogram 6.3.4", "pin-project 1.0.2", "tokio 0.3.6", - "tower-layer", + "tower-layer 0.3.0", "tower-service", "tracing", ] @@ -3383,8 +3444,10 @@ dependencies = [ "pin-project 0.4.27", "rand 0.7.3", "tokio 0.3.6", + "tokio-test", "tower", "tower-fallback", + "tower-test", "tracing", "tracing-futures", "zebra-test", @@ -3407,12 +3470,32 @@ name = "tower-layer" version = "0.3.0" source = "git+https://github.com/tower-rs/tower?rev=d4d1c67c6a0e4213a52abcc2b9df6cc58276ee39#d4d1c67c6a0e4213a52abcc2b9df6cc58276ee39" +[[package]] +name = "tower-layer" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" + [[package]] name = "tower-service" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e987b6bf443f4b5b3b6f38704195592cca41c5bb7aedd3c3693c7081f8289860" +[[package]] +name = "tower-test" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4546773ffeab9e4ea02b8872faa49bb616a80a7da66afc2f32688943f97efa7" +dependencies = [ + "futures-util", + "pin-project 1.0.2", + "tokio 1.3.0", + "tokio-test", + "tower-layer 0.3.1", + "tower-service", +] + [[package]] name = "tracing" version = "0.1.25" diff --git a/tower-batch/Cargo.toml b/tower-batch/Cargo.toml index 716dcf7b..10496e01 100644 --- a/tower-batch/Cargo.toml +++ b/tower-batch/Cargo.toml @@ -22,3 +22,5 @@ tracing = "0.1.25" zebra-test = { path = "../zebra-test/" } tower-fallback = { path = "../tower-fallback/" } color-eyre = "0.5.10" +tokio-test = "0.4.1" +tower-test = "0.4.0" diff --git a/tower-batch/src/service.rs b/tower-batch/src/service.rs index 176d63fa..7f99aa19 100644 --- a/tower-batch/src/service.rs +++ b/tower-batch/src/service.rs @@ -101,9 +101,9 @@ where // We choose a bound that allows callers to check readiness for every item in // a batch, then actually submit those items. let bound = max_items; - let semaphore = Semaphore::new(bound); + let (semaphore, close) = Semaphore::new_with_close(bound); - let (handle, worker) = Worker::new(service, rx, max_items, max_latency); + let (handle, worker) = Worker::new(service, rx, max_items, max_latency, close); let batch = Batch { tx, semaphore, diff --git a/tower-batch/src/worker.rs b/tower-batch/src/worker.rs index ff793951..4137dfed 100644 --- a/tower-batch/src/worker.rs +++ b/tower-batch/src/worker.rs @@ -1,4 +1,7 @@ -use std::sync::{Arc, Mutex}; +use std::{ + pin::Pin, + sync::{Arc, Mutex}, +}; use futures::future::TryFutureExt; use pin_project::pin_project; @@ -10,6 +13,8 @@ use tokio::{ use tower::{Service, ServiceExt}; use tracing_futures::Instrument; +use crate::semaphore; + use super::{ error::{Closed, ServiceError}, message::{self, Message}, @@ -23,7 +28,7 @@ use super::{ /// as part of the public API. This is the "sealed" pattern to include "private" /// types in public traits that are not meant for consumers of the library to /// implement (only call). -#[pin_project] +#[pin_project(PinnedDrop)] #[derive(Debug)] pub struct Worker where @@ -36,6 +41,7 @@ where handle: Handle, max_items: usize, max_latency: std::time::Duration, + close: Option, } /// Get the error out @@ -54,6 +60,7 @@ where rx: mpsc::UnboundedReceiver>, max_items: usize, max_latency: std::time::Duration, + close: semaphore::Close, ) -> (Handle, Worker) { let handle = Handle { inner: Arc::new(Mutex::new(None)), @@ -66,6 +73,7 @@ where failed: None, max_items, max_latency, + close: Some(close), }; (handle, worker) @@ -88,6 +96,12 @@ where .as_ref() .expect("Worker::failed did not set self.failed?") .clone())); + + // Wake any tasks waiting on channel capacity. + if let Some(close) = self.close.take() { + tracing::debug!("waking pending tasks"); + close.close(); + } } } } @@ -221,3 +235,16 @@ impl Clone for Handle { } } } + +#[pin_project::pinned_drop] +impl PinnedDrop for Worker +where + T: Service>, + T::Error: Into, +{ + fn drop(mut self: Pin<&mut Self>) { + if let Some(close) = self.as_mut().close.take() { + close.close(); + } + } +} diff --git a/tower-batch/tests/worker.rs b/tower-batch/tests/worker.rs new file mode 100644 index 00000000..e5d13fce --- /dev/null +++ b/tower-batch/tests/worker.rs @@ -0,0 +1,124 @@ +use std::time::Duration; +use tokio_test::{assert_pending, assert_ready, assert_ready_err, task}; +use tower::{Service, ServiceExt}; +use tower_batch::{error, Batch}; +use tower_test::mock; + +#[tokio::test] +async fn wakes_pending_waiters_on_close() { + zebra_test::init(); + + let (service, mut handle) = mock::pair::<_, ()>(); + + let (mut service, worker) = Batch::pair(service, 1, Duration::from_secs(1)); + let mut worker = task::spawn(worker.run()); + + // // keep the request in the worker + handle.allow(0); + let service1 = service.ready_and().await.unwrap(); + let poll = worker.poll(); + assert_pending!(poll); + let mut response = task::spawn(service1.call(())); + + let mut service1 = service.clone(); + let mut ready1 = task::spawn(service1.ready_and()); + assert_pending!(worker.poll()); + assert_pending!(ready1.poll(), "no capacity"); + + let mut service1 = service.clone(); + let mut ready2 = task::spawn(service1.ready_and()); + assert_pending!(worker.poll()); + assert_pending!(ready2.poll(), "no capacity"); + + // kill the worker task + drop(worker); + + let err = assert_ready_err!(response.poll()); + assert!( + err.is::(), + "response should fail with a Closed, got: {:?}", + err + ); + + assert!( + ready1.is_woken(), + "dropping worker should wake ready task 1" + ); + let err = assert_ready_err!(ready1.poll()); + assert!( + err.is::(), + "ready 1 should fail with a Closed, got: {:?}", + err + ); + + assert!( + ready2.is_woken(), + "dropping worker should wake ready task 2" + ); + let err = assert_ready_err!(ready1.poll()); + assert!( + err.is::(), + "ready 2 should fail with a Closed, got: {:?}", + err + ); +} + +#[tokio::test] +async fn wakes_pending_waiters_on_failure() { + zebra_test::init(); + + let (service, mut handle) = mock::pair::<_, ()>(); + + let (mut service, worker) = Batch::pair(service, 1, Duration::from_secs(1)); + let mut worker = task::spawn(worker.run()); + + // keep the request in the worker + handle.allow(0); + let service1 = service.ready_and().await.unwrap(); + assert_pending!(worker.poll()); + let mut response = task::spawn(service1.call("hello")); + + let mut service1 = service.clone(); + let mut ready1 = task::spawn(service1.ready_and()); + assert_pending!(worker.poll()); + assert_pending!(ready1.poll(), "no capacity"); + + let mut service1 = service.clone(); + let mut ready2 = task::spawn(service1.ready_and()); + assert_pending!(worker.poll()); + assert_pending!(ready2.poll(), "no capacity"); + + // fail the inner service + handle.send_error("foobar"); + // worker task terminates + assert_ready!(worker.poll()); + + let err = assert_ready_err!(response.poll()); + assert!( + err.is::(), + "response should fail with a ServiceError, got: {:?}", + err + ); + + assert!( + ready1.is_woken(), + "dropping worker should wake ready task 1" + ); + let err = assert_ready_err!(ready1.poll()); + assert!( + err.is::(), + "ready 1 should fail with a ServiceError, got: {:?}", + err + ); + + assert!( + ready2.is_woken(), + "dropping worker should wake ready task 2" + ); + let err = assert_ready_err!(ready1.poll()); + assert!( + err.is::(), + "ready 2 should fail with a ServiceError, got: {:?}", + err + ); +}