diff --git a/Cargo.lock b/Cargo.lock index fcd9dcd3..06c63f66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2403,8 +2403,12 @@ dependencies = [ name = "zebra-test" version = "0.1.0" dependencies = [ + "color-eyre", + "futures", "hex", "lazy_static", + "tokio", + "tower", ] [[package]] diff --git a/zebra-test/Cargo.toml b/zebra-test/Cargo.toml index 41b888a2..dd9d6dc9 100644 --- a/zebra-test/Cargo.toml +++ b/zebra-test/Cargo.toml @@ -10,3 +10,9 @@ edition = "2018" [dependencies] hex = "0.4.2" lazy_static = "1.4.0" +tower = "0.3.1" +futures = "0.3.5" +color-eyre = "0.5" + +[dev-dependencies] +tokio = { version = "0.2", features = ["full"] } \ No newline at end of file diff --git a/zebra-test/src/lib.rs b/zebra-test/src/lib.rs index c31ddd2d..1a157d56 100644 --- a/zebra-test/src/lib.rs +++ b/zebra-test/src/lib.rs @@ -1 +1,4 @@ +//! Miscellaneous test code for Zebra. + +pub mod transcript; pub mod vectors; diff --git a/zebra-test/src/transcript.rs b/zebra-test/src/transcript.rs new file mode 100644 index 00000000..0a34193e --- /dev/null +++ b/zebra-test/src/transcript.rs @@ -0,0 +1,55 @@ +//! A [`Service`](tower::Service) implementation based on a fixed transcript. + +use color_eyre::eyre::{eyre, Report}; +use futures::future::{ready, Ready}; +use std::{ + fmt::Debug, + task::{Context, Poll}, +}; +use tower::Service; + +pub struct Transcript +where + I: Iterator, +{ + messages: I, +} + +impl From for Transcript +where + I: Iterator, +{ + fn from(messages: I) -> Self { + Self { messages } + } +} + +impl Service for Transcript +where + R: Debug + Eq, + I: Iterator, +{ + type Response = S; + type Error = Report; + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, request: R) -> Self::Future { + if let Some((expected_request, response)) = self.messages.next() { + if request == expected_request { + ready(Ok(response)) + } else { + ready(Err(eyre!( + "Expected {:?}, got {:?}", + expected_request, + request + ))) + } + } else { + ready(Err(eyre!("Got request after transcript ended"))) + } + } +} diff --git a/zebra-test/tests/transcript.rs b/zebra-test/tests/transcript.rs new file mode 100644 index 00000000..3d20ec43 --- /dev/null +++ b/zebra-test/tests/transcript.rs @@ -0,0 +1,34 @@ +use tower::{Service, ServiceExt}; + +use zebra_test::transcript::Transcript; + +const TRANSCRIPT_DATA: [(&'static str, &'static str); 4] = [ + ("req1", "rsp1"), + ("req2", "rsp2"), + ("req3", "rsp3"), + ("req4", "rsp4"), +]; + +#[tokio::test] +async fn transcript_returns_responses_and_ends() { + let mut svc = Transcript::from(TRANSCRIPT_DATA.iter().cloned()); + + for (req, rsp) in TRANSCRIPT_DATA.iter() { + assert_eq!( + svc.ready_and().await.unwrap().call(req).await.unwrap(), + *rsp, + ); + } + assert!(svc.ready_and().await.unwrap().call("end").await.is_err()); +} + +#[tokio::test] +async fn transcript_errors_wrong_request() { + let mut svc = Transcript::from(TRANSCRIPT_DATA.iter().cloned()); + + assert_eq!( + svc.ready_and().await.unwrap().call("req1").await.unwrap(), + "rsp1", + ); + assert!(svc.ready_and().await.unwrap().call("bad").await.is_err()); +}