diff --git a/zebra-network/src/lib.rs b/zebra-network/src/lib.rs index f658b8d6..b3a6cb2c 100644 --- a/zebra-network/src/lib.rs +++ b/zebra-network/src/lib.rs @@ -31,7 +31,7 @@ #![deny(missing_docs)] // Tracing causes false positives on this lint: // https://github.com/tokio-rs/tracing/issues/553 -#![allow(clippy::cognitive_complexity)] +#![allow(clippy::cognitive_complexity, clippy::try_err)] #[macro_use] extern crate pin_project; diff --git a/zebra-network/src/peer_set/initialize.rs b/zebra-network/src/peer_set/initialize.rs index 1327d50b..02f7577f 100644 --- a/zebra-network/src/peer_set/initialize.rs +++ b/zebra-network/src/peer_set/initialize.rs @@ -72,45 +72,38 @@ where let (peerset_tx, peerset_rx) = mpsc::channel::(100); // Create an mpsc channel for peerset demand signaling. let (mut demand_tx, demand_rx) = mpsc::channel::<()>(100); + let (handle_tx, handle_rx) = tokio::sync::oneshot::channel(); // Connect the rx end to a PeerSet, wrapping new peers in load instruments. - let peer_set = Buffer::new( - PeerSet::new( - PeakEwmaDiscover::new( - ServiceStream::new( - // ServiceStream interprets an error as stream termination, - // so discard any errored connections... - peerset_rx.filter(|result| future::ready(result.is_ok())), - ), - config.ewma_default_rtt, - config.ewma_decay_time, - NoInstrument, + let peer_set = PeerSet::new( + PeakEwmaDiscover::new( + ServiceStream::new( + // ServiceStream interprets an error as stream termination, + // so discard any errored connections... + peerset_rx.filter(|result| future::ready(result.is_ok())), ), - demand_tx.clone(), + config.ewma_default_rtt, + config.ewma_decay_time, + NoInstrument, ), - config.peerset_request_buffer_size, + demand_tx.clone(), + handle_rx, ); + let peer_set = Buffer::new(peer_set, config.peerset_request_buffer_size); // Connect the tx end to the 3 peer sources: // 1. Initial peers, specified in the config. - tokio::spawn(add_initial_peers( + let add_guard = tokio::spawn(add_initial_peers( config.initial_peers(), connector.clone(), peerset_tx.clone(), )); // 2. Incoming peer connections, via a listener. - tokio::spawn( - listen(config.listen_addr, listener, peerset_tx.clone()).map(|result| { - if let Err(e) = result { - error!(%e); - } - }), - ); + let listen_guard = tokio::spawn(listen(config.listen_addr, listener, peerset_tx.clone())); // 3. Outgoing peers we connect to in response to load. - let mut candidates = CandidateSet::new(address_book.clone(), peer_set.clone()); // We need to await candidates.update() here, because Zcashd only sends one @@ -125,21 +118,18 @@ where let _ = demand_tx.try_send(()); } - tokio::spawn( - crawl_and_dial( - config.new_peer_interval, - demand_tx, - demand_rx, - candidates, - connector, - peerset_tx, - ) - .map(|result| { - if let Err(e) = result { - error!(%e); - } - }), - ); + let crawl_guard = tokio::spawn(crawl_and_dial( + config.new_peer_interval, + demand_tx, + demand_rx, + candidates, + connector, + peerset_tx, + )); + + handle_tx + .send(vec![add_guard, listen_guard, crawl_guard]) + .unwrap(); (peer_set, address_book) } @@ -151,7 +141,8 @@ async fn add_initial_peers( initial_peers: std::collections::HashSet, connector: S, mut tx: mpsc::Sender, -) where +) -> Result<(), BoxedStdError> +where S: Service, Error = BoxedStdError> + Clone, S::Future: Send + 'static, @@ -160,9 +151,12 @@ async fn add_initial_peers( use tower::util::CallAllUnordered; let addr_stream = futures::stream::iter(initial_peers.into_iter()); let mut handshakes = CallAllUnordered::new(connector, addr_stream); + while let Some(handshake_result) = handshakes.next().await { - let _ = tx.send(handshake_result).await; + tx.send(handshake_result).await?; } + + Ok(()) } /// Bind to `addr`, listen for peers using `handshaker`, then send the diff --git a/zebra-network/src/peer_set/set.rs b/zebra-network/src/peer_set/set.rs index eeedb98e..a7e2c49f 100644 --- a/zebra-network/src/peer_set/set.rs +++ b/zebra-network/src/peer_set/set.rs @@ -14,6 +14,8 @@ use futures::{ stream::FuturesUnordered, }; use indexmap::IndexMap; +use tokio::sync::oneshot::error::TryRecvError; +use tokio::task::JoinHandle; use tower::{ discover::{Change, Discover}, Service, @@ -77,6 +79,15 @@ where unready_services: FuturesUnordered>, next_idx: Option, demand_signal: mpsc::Sender<()>, + /// Channel for passing ownership of tokio JoinHandles from PeerSet's background tasks + /// + /// The join handles passed into the PeerSet are used populate the `guards` member + handle_rx: tokio::sync::oneshot::Receiver>>>, + /// Unordered set of handles to background tasks associated with the `PeerSet` + /// + /// These guards are checked for errors as part of `poll_ready` which lets + /// the `PeerSet` propagate errors from background tasks back to the user + guards: futures::stream::FuturesUnordered>>, } impl PeerSet @@ -90,7 +101,11 @@ where ::Metric: Debug, { /// Construct a peerset which uses `discover` internally. - pub fn new(discover: D, demand_signal: mpsc::Sender<()>) -> Self { + pub fn new( + discover: D, + demand_signal: mpsc::Sender<()>, + handle_rx: tokio::sync::oneshot::Receiver>>>, + ) -> Self { Self { discover, ready_services: IndexMap::new(), @@ -98,6 +113,8 @@ where unready_services: FuturesUnordered::new(), next_idx: None, demand_signal, + guards: futures::stream::FuturesUnordered::new(), + handle_rx, } } @@ -152,6 +169,30 @@ where }); } + fn check_for_background_errors(&mut self, cx: &mut Context) -> Result<(), BoxedStdError> { + if self.guards.is_empty() { + match self.handle_rx.try_recv() { + Ok(handles) => { + for handle in handles { + self.guards.push(handle); + } + } + Err(TryRecvError::Closed) => unreachable!( + "try_recv will never be called if the futures have already been received" + ), + Err(TryRecvError::Empty) => return Ok(()), + } + } + + match Pin::new(&mut self.guards).poll_next(cx) { + Poll::Pending => {} + Poll::Ready(Some(res)) => res??, + Poll::Ready(None) => Err("all background tasks have exited")?, + } + + Ok(()) + } + fn poll_unready(&mut self, cx: &mut Context<'_>) { loop { match Pin::new(&mut self.unready_services).poll_next(cx) { @@ -223,6 +264,7 @@ where Pin> + Send + 'static>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.check_for_background_errors(cx)?; // Process peer discovery updates. let _ = self.poll_discover(cx)?; diff --git a/zebrad/Cargo.toml b/zebrad/Cargo.toml index 1abcb271..6602fe35 100644 --- a/zebrad/Cargo.toml +++ b/zebrad/Cargo.toml @@ -14,7 +14,7 @@ serde = { version = "1", features = ["serde_derive"] } toml = "0.5" thiserror = "1" -tokio = { version = "0.2", features = ["time", "rt-threaded", "stream"] } +tokio = { version = "0.2", features = ["time", "rt-threaded", "stream", "macros"] } futures = "0.3" tracing = "0.1" diff --git a/zebrad/src/commands/connect.rs b/zebrad/src/commands/connect.rs index 79f202f4..db804602 100644 --- a/zebrad/src/commands/connect.rs +++ b/zebrad/src/commands/connect.rs @@ -33,10 +33,17 @@ impl Runnable for ConnectCmd { .rt .take(); - rt.expect("runtime should not already be taken") - .block_on(self.connect()) - // Surface any error that occurred executing the future. - .unwrap(); + let result = rt + .expect("runtime should not already be taken") + .block_on(self.connect()); + + match result { + Ok(()) => {} + Err(e) => { + eprintln!("Error: {:?}", e); + std::process::exit(1); + } + } } } diff --git a/zebrad/src/commands/seed.rs b/zebrad/src/commands/seed.rs index 2c0c11a0..c99bf3f7 100644 --- a/zebrad/src/commands/seed.rs +++ b/zebrad/src/commands/seed.rs @@ -138,7 +138,7 @@ impl SeedCmd { let config = app_config().network.clone(); - let (mut peer_set, address_book) = zebra_network::init(config, buffered_svc.clone()).await; + let (mut peer_set, address_book) = zebra_network::init(config, buffered_svc).await; let _ = addressbook_tx.send(address_book); diff --git a/zebrad/src/commands/start.rs b/zebrad/src/commands/start.rs index 35c2b67e..5109e774 100644 --- a/zebrad/src/commands/start.rs +++ b/zebrad/src/commands/start.rs @@ -36,7 +36,7 @@ impl Runnable for StartCmd { let default_config = ZebradConfig::default(); println!("Default config: {:?}", default_config); - println!("Toml:\n{}", toml::to_string(&default_config).unwrap()); + println!("Toml:\n{}", toml::Value::try_from(&default_config).unwrap()); info!("Starting placeholder loop"); diff --git a/zebrad/src/config.rs b/zebrad/src/config.rs index 6b4e502a..0ddc0138 100644 --- a/zebrad/src/config.rs +++ b/zebrad/src/config.rs @@ -53,3 +53,16 @@ impl Default for MetricsSection { } } } + +#[cfg(test)] +mod test { + #[test] + fn test_toml_ser() -> color_eyre::Result<()> { + let default_config = super::ZebradConfig::default(); + println!("Default config: {:?}", default_config); + + println!("Toml:\n{}", toml::Value::try_from(&default_config)?); + + Ok(()) + } +}