diff --git a/zebra-test/src/transcript.rs b/zebra-test/src/transcript.rs index 0a34193e..d5d7442e 100644 --- a/zebra-test/src/transcript.rs +++ b/zebra-test/src/transcript.rs @@ -1,12 +1,12 @@ //! A [`Service`](tower::Service) implementation based on a fixed transcript. -use color_eyre::eyre::{eyre, Report}; +use color_eyre::eyre::{ensure, eyre, Report}; use futures::future::{ready, Ready}; use std::{ fmt::Debug, task::{Context, Poll}, }; -use tower::Service; +use tower::{Service, ServiceExt}; pub struct Transcript where @@ -24,6 +24,32 @@ where } } +impl Transcript +where + I: Iterator, + R: Debug, + S: Debug + Eq, +{ + pub async fn check(mut self, mut to_check: C) -> Result<(), Report> + where + C: Service, + C::Error: Debug, + { + while let Some((req, expected_rsp)) = self.messages.next() { + // These unwraps could propagate errors with the correct + // bound on C::Error + let rsp = to_check.ready_and().await.unwrap().call(req).await.unwrap(); + ensure!( + rsp == expected_rsp, + "Expected {:?}, got {:?}", + expected_rsp, + rsp + ); + } + Ok(()) + } +} + impl Service for Transcript where R: Debug + Eq, diff --git a/zebra-test/tests/transcript.rs b/zebra-test/tests/transcript.rs index 3d20ec43..687ecd5c 100644 --- a/zebra-test/tests/transcript.rs +++ b/zebra-test/tests/transcript.rs @@ -32,3 +32,10 @@ async fn transcript_errors_wrong_request() { ); assert!(svc.ready_and().await.unwrap().call("bad").await.is_err()); } + +#[tokio::test] +async fn self_check() { + let t1 = Transcript::from(TRANSCRIPT_DATA.iter().cloned()); + let t2 = Transcript::from(TRANSCRIPT_DATA.iter().cloned()); + assert!(t1.check(t2).await.is_ok()); +}