fix(panic): Stop panicking on async task cancellation on shutdown in network and state futures (#7219)

* Add an async-error feature and an initial module structure

* Implement checking for panics in OS threads and async tasks

* Implement waiting for panics in OS threads and async tasks

* Add a TODO to simplify some state request error handling

* Use the new panic-checking methods in zebra-state

* Use new panic-checking methods in zebra-network

* fixup! Implement waiting for panics in OS threads and async tasks

* Replace existing async code with generic panic-checking methods

* Simplify trait to a single method

* Move thread panic code into generic trait impls

* Simplify option handling

Co-authored-by: Arya <aryasolhi@gmail.com>

* Fix comment

Co-authored-by: Arya <aryasolhi@gmail.com>

* Add missing track_caller

---------

Co-authored-by: Arya <aryasolhi@gmail.com>
This commit is contained in:
teor 2023-07-18 14:53:26 +10:00 committed by GitHub
parent c885de4abb
commit 3bbe3cec4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 368 additions and 233 deletions

View File

@ -24,6 +24,11 @@ json-conversion = [
"serde_json", "serde_json",
] ]
# Async error handling convenience traits
async-error = [
"tokio",
]
# Experimental mining RPC support # Experimental mining RPC support
getblocktemplate-rpcs = [ getblocktemplate-rpcs = [
"zcash_address", "zcash_address",
@ -39,7 +44,7 @@ proptest-impl = [
"proptest-derive", "proptest-derive",
"rand", "rand",
"rand_chacha", "rand_chacha",
"tokio", "tokio/tracing",
"zebra-test", "zebra-test",
] ]
@ -108,6 +113,9 @@ reddsa = "0.5.0"
# Production feature json-conversion # Production feature json-conversion
serde_json = { version = "1.0.100", optional = true } serde_json = { version = "1.0.100", optional = true }
# Production feature async-error and testing feature proptest-impl
tokio = { version = "1.29.1", optional = true }
# Experimental feature getblocktemplate-rpcs # Experimental feature getblocktemplate-rpcs
zcash_address = { version = "0.3.0", optional = true } zcash_address = { version = "0.3.0", optional = true }
@ -118,8 +126,6 @@ proptest-derive = { version = "0.3.0", optional = true }
rand = { version = "0.8.5", optional = true } rand = { version = "0.8.5", optional = true }
rand_chacha = { version = "0.3.1", optional = true } rand_chacha = { version = "0.3.1", optional = true }
tokio = { version = "1.29.1", features = ["tracing"], optional = true }
zebra-test = { path = "../zebra-test/", version = "1.0.0-beta.27", optional = true } zebra-test = { path = "../zebra-test/", version = "1.0.0-beta.27", optional = true }
[dev-dependencies] [dev-dependencies]

View File

@ -1,6 +1,15 @@
//! Tracing the execution time of functions. //! Diagnostic types and functions for Zebra:
//! //! - code performance
//! TODO: also trace polling time for futures, using a `Future` wrapper //! - task handling
//! - errors and panics
pub mod task;
// Tracing the execution time of functions.
//
// TODO:
// - move this to a `timing` submodule
// - also trace polling time for futures, using a `Future` wrapper
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};

View File

@ -0,0 +1,47 @@
//! Diagnostic types and functions for Zebra tasks:
//! - OS thread handling
//! - async future task handling
//! - errors and panics
#[cfg(feature = "async-error")]
pub mod future;
pub mod thread;
/// A trait that checks a task's return value for panics.
pub trait CheckForPanics {
/// The output type, after removing panics from `Self`.
type Output;
/// Check if `self` contains a panic payload or an unexpected termination, then panic.
/// Otherwise, return the non-panic part of `self`.
///
/// # Panics
///
/// If `self` contains a panic payload or an unexpected termination.
#[track_caller]
fn check_for_panics(self) -> Self::Output;
}
/// A trait that waits for a task to finish, then handles panics and cancellations.
pub trait WaitForPanics {
/// The underlying task output, after removing panics and unwrapping termination results.
type Output;
/// Waits for `self` to finish, then check if its output is:
/// - a panic payload: resume that panic,
/// - an unexpected termination: panic with that error,
/// - an expected termination: hang waiting for shutdown.
///
/// Otherwise, returns the task return value of `self`.
///
/// # Panics
///
/// If `self` contains a panic payload or an unexpected termination.
///
/// # Hangs
///
/// If `self` contains an expected termination, and we're shutting down anyway.
#[track_caller]
fn wait_for_panics(self) -> Self::Output;
}

View File

@ -0,0 +1,93 @@
//! Diagnostic types and functions for Zebra async future tasks:
//! - task handles
//! - errors and panics
use std::{future, panic};
use futures::future::{BoxFuture, FutureExt};
use tokio::task::{JoinError, JoinHandle};
use crate::shutdown::is_shutting_down;
use super::{CheckForPanics, WaitForPanics};
/// This is the return type of the [`JoinHandle`] future.
impl<T> CheckForPanics for Result<T, JoinError> {
/// The [`JoinHandle`]'s task output, after resuming any panics,
/// and ignoring task cancellations on shutdown.
type Output = Result<T, JoinError>;
/// Returns the task result if the task finished normally.
/// Otherwise, resumes any panics, logs unexpected errors, and ignores any expected errors.
///
/// If the task finished normally, returns `Some(T)`.
/// If the task was cancelled, returns `None`.
#[track_caller]
fn check_for_panics(self) -> Self::Output {
match self {
Ok(task_output) => Ok(task_output),
Err(join_error) => Err(join_error.check_for_panics()),
}
}
}
impl CheckForPanics for JoinError {
/// The [`JoinError`] after resuming any panics, and logging any unexpected task cancellations.
type Output = JoinError;
/// Resume any panics and panic on unexpected task cancellations.
/// Always returns [`JoinError::Cancelled`](JoinError::is_cancelled).
#[track_caller]
fn check_for_panics(self) -> Self::Output {
match self.try_into_panic() {
Ok(panic_payload) => panic::resume_unwind(panic_payload),
// We could ignore this error, but then we'd have to change the return type.
Err(task_cancelled) if is_shutting_down() => {
debug!(
?task_cancelled,
"ignoring cancelled task because Zebra is shutting down"
);
task_cancelled
}
Err(task_cancelled) => {
panic!("task cancelled during normal Zebra operation: {task_cancelled:?}");
}
}
}
}
impl<T> WaitForPanics for JoinHandle<T>
where
T: Send + 'static,
{
type Output = BoxFuture<'static, T>;
/// Returns a future which waits for `self` to finish, then checks if its output is:
/// - a panic payload: resume that panic,
/// - an unexpected termination: panic with that error,
/// - an expected termination: hang waiting for shutdown.
///
/// Otherwise, returns the task return value of `self`.
///
/// # Panics
///
/// If `self` contains a panic payload, or [`JoinHandle::abort()`] has been called on `self`.
///
/// # Hangs
///
/// If `self` contains an expected termination, and we're shutting down anyway.
/// Futures hang by returning `Pending` and not setting a waker, so this uses minimal resources.
#[track_caller]
fn wait_for_panics(self) -> Self::Output {
async move {
match self.await.check_for_panics() {
Ok(task_output) => task_output,
Err(_expected_cancel_error) => future::pending().await,
}
}
.boxed()
}
}

View File

@ -0,0 +1,108 @@
//! Diagnostic types and functions for Zebra OS thread tasks:
//! - task handles
//! - errors and panics
use std::{
panic,
sync::Arc,
thread::{self, JoinHandle},
};
use super::{CheckForPanics, WaitForPanics};
impl<T> CheckForPanics for thread::Result<T> {
type Output = T;
/// Panics if the thread panicked.
///
/// Threads can't be cancelled except by using a panic, so there are no thread errors here.
#[track_caller]
fn check_for_panics(self) -> Self::Output {
match self {
// The value returned by the thread when it finished.
Ok(thread_output) => thread_output,
// A thread error is always a panic.
Err(panic_payload) => panic::resume_unwind(panic_payload),
}
}
}
impl<T> WaitForPanics for JoinHandle<T> {
type Output = T;
/// Waits for the thread to finish, then panics if the thread panicked.
#[track_caller]
fn wait_for_panics(self) -> Self::Output {
self.join().check_for_panics()
}
}
impl<T> WaitForPanics for Arc<JoinHandle<T>> {
type Output = Option<T>;
/// If this is the final `Arc`, waits for the thread to finish, then panics if the thread
/// panicked. Otherwise, returns the thread's return value.
///
/// If this is not the final `Arc`, drops the handle and immediately returns `None`.
#[track_caller]
fn wait_for_panics(self) -> Self::Output {
// If we are the last Arc with a reference to this handle,
// we can wait for it and propagate any panics.
//
// We use into_inner() because it guarantees that exactly one of the tasks gets the
// JoinHandle. try_unwrap() lets us keep the JoinHandle, but it can also miss panics.
//
// This is more readable as an expanded statement.
#[allow(clippy::manual_map)]
if let Some(handle) = Arc::into_inner(self) {
Some(handle.wait_for_panics())
} else {
None
}
}
}
impl<T> CheckForPanics for &mut Option<Arc<JoinHandle<T>>> {
type Output = Option<T>;
/// If this is the final `Arc`, checks if the thread has finished, then panics if the thread
/// panicked. Otherwise, returns the thread's return value.
///
/// If the thread has not finished, or this is not the final `Arc`, returns `None`.
#[track_caller]
fn check_for_panics(self) -> Self::Output {
let handle = self.take()?;
if handle.is_finished() {
// This is the same as calling `self.wait_for_panics()`, but we can't do that,
// because we've taken `self`.
#[allow(clippy::manual_map)]
return handle.wait_for_panics();
}
*self = Some(handle);
None
}
}
impl<T> WaitForPanics for &mut Option<Arc<JoinHandle<T>>> {
type Output = Option<T>;
/// If this is the final `Arc`, waits for the thread to finish, then panics if the thread
/// panicked. Otherwise, returns the thread's return value.
///
/// If this is not the final `Arc`, drops the handle and returns `None`.
#[track_caller]
fn wait_for_panics(self) -> Self::Output {
// This is more readable as an expanded statement.
#[allow(clippy::manual_map)]
if let Some(output) = self.take()?.wait_for_panics() {
Some(output)
} else {
// Some other task has a reference, so we should give up ours to let them use it.
None
}
}
}

View File

@ -83,7 +83,7 @@ howudoin = { version = "0.1.2", optional = true }
proptest = { version = "1.2.0", optional = true } proptest = { version = "1.2.0", optional = true }
proptest-derive = { version = "0.3.0", optional = true } proptest-derive = { version = "0.3.0", optional = true }
zebra-chain = { path = "../zebra-chain", version = "1.0.0-beta.27" } zebra-chain = { path = "../zebra-chain", version = "1.0.0-beta.27", features = ["async-error"] }
[dev-dependencies] [dev-dependencies]
proptest = "1.2.0" proptest = "1.2.0"

View File

@ -8,7 +8,7 @@ use tokio::time::{sleep_until, timeout, Instant};
use tower::{Service, ServiceExt}; use tower::{Service, ServiceExt};
use tracing::Span; use tracing::Span;
use zebra_chain::serialization::DateTime32; use zebra_chain::{diagnostic::task::WaitForPanics, serialization::DateTime32};
use crate::{ use crate::{
constants, meta_addr::MetaAddrChange, peer_set::set::MorePeers, types::MetaAddr, AddressBook, constants, meta_addr::MetaAddrChange, peer_set::set::MorePeers, types::MetaAddr, AddressBook,
@ -348,8 +348,8 @@ where
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
span.in_scope(|| address_book.lock().unwrap().extend(addrs)) span.in_scope(|| address_book.lock().unwrap().extend(addrs))
}) })
.wait_for_panics()
.await .await
.expect("panic in new peers address book update task");
} }
/// Returns the next candidate for a connection attempt, if any are available. /// Returns the next candidate for a connection attempt, if any are available.
@ -403,8 +403,8 @@ where
// Correctness: Spawn address book accesses on a blocking thread, to avoid deadlocks (see #1976). // Correctness: Spawn address book accesses on a blocking thread, to avoid deadlocks (see #1976).
let span = Span::current(); let span = Span::current();
let next_peer = tokio::task::spawn_blocking(move || span.in_scope(next_peer)) let next_peer = tokio::task::spawn_blocking(move || span.in_scope(next_peer))
.await .wait_for_panics()
.expect("panic in next peer address book task")?; .await?;
// Security: rate-limit new outbound peer connections // Security: rate-limit new outbound peer connections
sleep_until(self.min_next_handshake).await; sleep_until(self.min_next_handshake).await;

View File

@ -23,8 +23,7 @@ use rand::seq::SliceRandom;
use tokio::{ use tokio::{
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
sync::broadcast, sync::broadcast,
task::JoinError, time::{sleep, Instant},
time::{error::Elapsed, sleep, Instant},
}; };
use tokio_stream::wrappers::IntervalStream; use tokio_stream::wrappers::IntervalStream;
use tower::{ use tower::{
@ -33,11 +32,11 @@ use tower::{
use tracing::Span; use tracing::Span;
use tracing_futures::Instrument; use tracing_futures::Instrument;
use zebra_chain::chain_tip::ChainTip; use zebra_chain::{chain_tip::ChainTip, diagnostic::task::WaitForPanics};
use crate::{ use crate::{
address_book_updater::AddressBookUpdater, address_book_updater::AddressBookUpdater,
constants::{self, HANDSHAKE_TIMEOUT}, constants,
meta_addr::{MetaAddr, MetaAddrChange}, meta_addr::{MetaAddr, MetaAddrChange},
peer::{ peer::{
self, address_is_valid_for_inbound_listeners, HandshakeRequest, MinimumPeerVersion, self, address_is_valid_for_inbound_listeners, HandshakeRequest, MinimumPeerVersion,
@ -207,18 +206,8 @@ where
// Wait for the initial seed peer count // Wait for the initial seed peer count
let mut active_outbound_connections = initial_peers_join let mut active_outbound_connections = initial_peers_join
.wait_for_panics()
.await .await
.unwrap_or_else(|e @ JoinError { .. }| {
if e.is_panic() {
panic!("panic in initial peer connections task: {e:?}");
} else {
info!(
"task error during initial peer connections: {e:?},\
is Zebra shutting down?"
);
Err(e.into())
}
})
.expect("unexpected error connecting to initial peers"); .expect("unexpected error connecting to initial peers");
let active_initial_peer_count = active_outbound_connections.update_count(); let active_initial_peer_count = active_outbound_connections.update_count();
@ -354,22 +343,11 @@ where
} }
.in_current_span(), .in_current_span(),
) )
.wait_for_panics()
}) })
.collect(); .collect();
while let Some(handshake_result) = handshakes.next().await { while let Some(handshake_result) = handshakes.next().await {
let handshake_result = handshake_result.unwrap_or_else(|e @ JoinError { .. }| {
if e.is_panic() {
panic!("panic in initial peer connection: {e:?}");
} else {
info!(
"task error during initial peer connection: {e:?},\
is Zebra shutting down?"
);
// Fake the address, it doesn't matter because we're shutting down anyway
Err((PeerSocketAddr::unspecified(), e.into()))
}
});
match handshake_result { match handshake_result {
Ok(change) => { Ok(change) => {
handshake_success_total += 1; handshake_success_total += 1;
@ -637,36 +615,9 @@ where
peerset_tx.clone(), peerset_tx.clone(),
) )
.await? .await?
.map(|res| match res { .wait_for_panics();
Ok(()) => (),
Err(e @ JoinError { .. }) => {
if e.is_panic() {
panic!("panic during inbound handshaking: {e:?}");
} else {
info!(
"task error during inbound handshaking: {e:?}, is Zebra shutting down?"
)
}
}
});
let handshake_timeout = tokio::time::timeout( handshakes.push(handshake_task);
// Only trigger this timeout if the inner handshake timeout fails
HANDSHAKE_TIMEOUT + Duration::from_millis(500),
handshake_task,
)
.map(|res| match res {
Ok(()) => (),
Err(_e @ Elapsed { .. }) => {
info!(
"timeout in spawned accept_inbound_handshake() task: \
inner task should have timed out already"
);
}
});
// This timeout helps locate inbound peer connection hangs, see #6763 for details.
handshakes.push(Box::pin(handshake_timeout));
// Rate-limit inbound connection handshakes. // Rate-limit inbound connection handshakes.
// But sleep longer after a successful connection, // But sleep longer after a successful connection,
@ -918,80 +869,64 @@ where
// Spawn each handshake or crawl into an independent task, so handshakes can make // Spawn each handshake or crawl into an independent task, so handshakes can make
// progress while crawls are running. // progress while crawls are running.
let handshake_or_crawl_handle = tokio::spawn(async move { let handshake_or_crawl_handle = tokio::spawn(
// Try to get the next available peer for a handshake. async move {
// // Try to get the next available peer for a handshake.
// candidates.next() has a short timeout, and briefly holds the address //
// book lock, so it shouldn't hang. // candidates.next() has a short timeout, and briefly holds the address
// // book lock, so it shouldn't hang.
// Hold the lock for as short a time as possible. //
let candidate = { candidates.lock().await.next().await }; // Hold the lock for as short a time as possible.
let candidate = { candidates.lock().await.next().await };
if let Some(candidate) = candidate { if let Some(candidate) = candidate {
// we don't need to spawn here, because there's nothing running concurrently // we don't need to spawn here, because there's nothing running concurrently
dial( dial(
candidate, candidate,
outbound_connector, outbound_connector,
outbound_connection_tracker, outbound_connection_tracker,
peerset_tx, peerset_tx,
address_book, address_book,
demand_tx, demand_tx,
) )
.await?; .await?;
Ok(HandshakeFinished) Ok(HandshakeFinished)
} else {
// There weren't any peers, so try to get more peers.
debug!("demand for peers but no available candidates");
crawl(candidates, demand_tx).await?;
Ok(DemandCrawlFinished)
}
}.in_current_span())
.map(|res| match res {
Ok(crawler_action) => crawler_action,
Err(e @ JoinError {..}) => {
if e.is_panic() {
panic!("panic during outbound handshake: {e:?}");
} else { } else {
info!("task error during outbound handshake: {e:?}, is Zebra shutting down?") // There weren't any peers, so try to get more peers.
} debug!("demand for peers but no available candidates");
// Just fake it
Ok(HandshakeFinished)
}
});
handshakes.push(Box::pin(handshake_or_crawl_handle)); crawl(candidates, demand_tx).await?;
Ok(DemandCrawlFinished)
}
}
.in_current_span(),
)
.wait_for_panics();
handshakes.push(handshake_or_crawl_handle);
} }
Ok(TimerCrawl { tick }) => { Ok(TimerCrawl { tick }) => {
let candidates = candidates.clone(); let candidates = candidates.clone();
let demand_tx = demand_tx.clone(); let demand_tx = demand_tx.clone();
let crawl_handle = tokio::spawn(async move { let crawl_handle = tokio::spawn(
debug!( async move {
?tick, debug!(
"crawling for more peers in response to the crawl timer" ?tick,
); "crawling for more peers in response to the crawl timer"
);
crawl(candidates, demand_tx).await?; crawl(candidates, demand_tx).await?;
Ok(TimerCrawlFinished)
}.in_current_span())
.map(move |res| match res {
Ok(crawler_action) => crawler_action,
Err(e @ JoinError {..}) => {
if e.is_panic() {
panic!("panic during outbound TimerCrawl: {tick:?} {e:?}");
} else {
info!("task error during outbound TimerCrawl: {e:?}, is Zebra shutting down?")
}
// Just fake it
Ok(TimerCrawlFinished) Ok(TimerCrawlFinished)
} }
}); .in_current_span(),
)
.wait_for_panics();
handshakes.push(Box::pin(crawl_handle)); handshakes.push(crawl_handle);
} }
// Completed spawned tasks // Completed spawned tasks
@ -1162,27 +1097,16 @@ async fn report_failed(address_book: Arc<std::sync::Mutex<AddressBook>>, addr: M
// //
// Spawn address book accesses on a blocking thread, to avoid deadlocks (see #1976). // Spawn address book accesses on a blocking thread, to avoid deadlocks (see #1976).
let span = Span::current(); let span = Span::current();
let task_result = tokio::task::spawn_blocking(move || { let updated_addr = tokio::task::spawn_blocking(move || {
span.in_scope(|| address_book.lock().unwrap().update(addr)) span.in_scope(|| address_book.lock().unwrap().update(addr))
}) })
.wait_for_panics()
.await; .await;
match task_result { assert_eq!(
Ok(updated_addr) => assert_eq!( updated_addr.map(|addr| addr.addr()),
updated_addr.map(|addr| addr.addr()), Some(addr.addr()),
Some(addr.addr()), "incorrect address updated by address book: \
"incorrect address updated by address book: \ original: {addr:?}, updated: {updated_addr:?}"
original: {addr:?}, updated: {updated_addr:?}" );
),
Err(e @ JoinError { .. }) => {
if e.is_panic() {
panic!("panic in peer failure address book update task: {e:?}");
} else {
info!(
"task error during peer failure address book update task: {e:?},\
is Zebra shutting down?"
)
}
}
}
} }

View File

@ -71,7 +71,7 @@ tracing = "0.1.37"
elasticsearch = { version = "8.5.0-alpha.1", default-features = false, features = ["rustls-tls"], optional = true } elasticsearch = { version = "8.5.0-alpha.1", default-features = false, features = ["rustls-tls"], optional = true }
serde_json = { version = "1.0.100", package = "serde_json", optional = true } serde_json = { version = "1.0.100", package = "serde_json", optional = true }
zebra-chain = { path = "../zebra-chain", version = "1.0.0-beta.27" } zebra-chain = { path = "../zebra-chain", version = "1.0.0-beta.27", features = ["async-error"] }
# prod feature progress-bar # prod feature progress-bar
howudoin = { version = "0.1.2", optional = true } howudoin = { version = "0.1.2", optional = true }

View File

@ -32,6 +32,9 @@ pub enum Response {
Depth(Option<u32>), Depth(Option<u32>),
/// Response to [`Request::Tip`] with the current best chain tip. /// Response to [`Request::Tip`] with the current best chain tip.
//
// TODO: remove this request, and replace it with a call to
// `LatestChainTip::best_tip_height_and_hash()`
Tip(Option<(block::Height, block::Hash)>), Tip(Option<(block::Height, block::Hash)>),
/// Response to [`Request::BlockLocator`] with a block locator object. /// Response to [`Request::BlockLocator`] with a block locator object.

View File

@ -43,7 +43,7 @@ use tower::buffer::Buffer;
use zebra_chain::{ use zebra_chain::{
block::{self, CountedHeader, HeightDiff}, block::{self, CountedHeader, HeightDiff},
diagnostic::CodeTimer, diagnostic::{task::WaitForPanics, CodeTimer},
parameters::{Network, NetworkUpgrade}, parameters::{Network, NetworkUpgrade},
}; };
@ -1209,8 +1209,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::Tip(tip)) Ok(ReadResponse::Tip(tip))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::Tip")) .wait_for_panics()
.boxed()
} }
// Used by the StateService. // Used by the StateService.
@ -1231,8 +1230,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::Depth(depth)) Ok(ReadResponse::Depth(depth))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::Depth")) .wait_for_panics()
.boxed()
} }
// Used by the StateService. // Used by the StateService.
@ -1255,10 +1253,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::BestChainNextMedianTimePast(median_time_past?)) Ok(ReadResponse::BestChainNextMedianTimePast(median_time_past?))
}) })
}) })
.map(|join_result| { .wait_for_panics()
join_result.expect("panic in ReadRequest::BestChainNextMedianTimePast")
})
.boxed()
} }
// Used by the get_block (raw) RPC and the StateService. // Used by the get_block (raw) RPC and the StateService.
@ -1283,8 +1278,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::Block(block)) Ok(ReadResponse::Block(block))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::Block")) .wait_for_panics()
.boxed()
} }
// For the get_raw_transaction RPC and the StateService. // For the get_raw_transaction RPC and the StateService.
@ -1302,8 +1296,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::Transaction(response)) Ok(ReadResponse::Transaction(response))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::Transaction")) .wait_for_panics()
.boxed()
} }
// Used by the getblock (verbose) RPC. // Used by the getblock (verbose) RPC.
@ -1332,10 +1325,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::TransactionIdsForBlock(transaction_ids)) Ok(ReadResponse::TransactionIdsForBlock(transaction_ids))
}) })
}) })
.map(|join_result| { .wait_for_panics()
join_result.expect("panic in ReadRequest::TransactionIdsForBlock")
})
.boxed()
} }
ReadRequest::UnspentBestChainUtxo(outpoint) => { ReadRequest::UnspentBestChainUtxo(outpoint) => {
@ -1359,8 +1349,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::UnspentBestChainUtxo(utxo)) Ok(ReadResponse::UnspentBestChainUtxo(utxo))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::UnspentBestChainUtxo")) .wait_for_panics()
.boxed()
} }
// Manually used by the StateService to implement part of AwaitUtxo. // Manually used by the StateService to implement part of AwaitUtxo.
@ -1381,8 +1370,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::AnyChainUtxo(utxo)) Ok(ReadResponse::AnyChainUtxo(utxo))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::AnyChainUtxo")) .wait_for_panics()
.boxed()
} }
// Used by the StateService. // Used by the StateService.
@ -1405,8 +1393,7 @@ impl Service<ReadRequest> for ReadStateService {
)) ))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::BlockLocator")) .wait_for_panics()
.boxed()
} }
// Used by the StateService. // Used by the StateService.
@ -1433,8 +1420,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::BlockHashes(block_hashes)) Ok(ReadResponse::BlockHashes(block_hashes))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::FindBlockHashes")) .wait_for_panics()
.boxed()
} }
// Used by the StateService. // Used by the StateService.
@ -1466,8 +1452,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::BlockHeaders(block_headers)) Ok(ReadResponse::BlockHeaders(block_headers))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::FindBlockHeaders")) .wait_for_panics()
.boxed()
} }
ReadRequest::SaplingTree(hash_or_height) => { ReadRequest::SaplingTree(hash_or_height) => {
@ -1491,8 +1476,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::SaplingTree(sapling_tree)) Ok(ReadResponse::SaplingTree(sapling_tree))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::SaplingTree")) .wait_for_panics()
.boxed()
} }
ReadRequest::OrchardTree(hash_or_height) => { ReadRequest::OrchardTree(hash_or_height) => {
@ -1516,8 +1500,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::OrchardTree(orchard_tree)) Ok(ReadResponse::OrchardTree(orchard_tree))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::OrchardTree")) .wait_for_panics()
.boxed()
} }
// For the get_address_balance RPC. // For the get_address_balance RPC.
@ -1542,8 +1525,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::AddressBalance(balance)) Ok(ReadResponse::AddressBalance(balance))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::AddressBalance")) .wait_for_panics()
.boxed()
} }
// For the get_address_tx_ids RPC. // For the get_address_tx_ids RPC.
@ -1576,10 +1558,7 @@ impl Service<ReadRequest> for ReadStateService {
tx_ids.map(ReadResponse::AddressesTransactionIds) tx_ids.map(ReadResponse::AddressesTransactionIds)
}) })
}) })
.map(|join_result| { .wait_for_panics()
join_result.expect("panic in ReadRequest::TransactionIdsByAddresses")
})
.boxed()
} }
// For the get_address_utxos RPC. // For the get_address_utxos RPC.
@ -1605,8 +1584,7 @@ impl Service<ReadRequest> for ReadStateService {
utxos.map(ReadResponse::AddressUtxos) utxos.map(ReadResponse::AddressUtxos)
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::UtxosByAddresses")) .wait_for_panics()
.boxed()
} }
ReadRequest::CheckBestChainTipNullifiersAndAnchors(unmined_tx) => { ReadRequest::CheckBestChainTipNullifiersAndAnchors(unmined_tx) => {
@ -1639,11 +1617,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::ValidBestChainTipNullifiersAndAnchors) Ok(ReadResponse::ValidBestChainTipNullifiersAndAnchors)
}) })
}) })
.map(|join_result| { .wait_for_panics()
join_result
.expect("panic in ReadRequest::CheckBestChainTipNullifiersAndAnchors")
})
.boxed()
} }
// Used by the get_block and get_block_hash RPCs. // Used by the get_block and get_block_hash RPCs.
@ -1672,8 +1646,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::BlockHash(hash)) Ok(ReadResponse::BlockHash(hash))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::BestChainBlockHash")) .wait_for_panics()
.boxed()
} }
// Used by get_block_template RPC. // Used by get_block_template RPC.
@ -1712,8 +1685,7 @@ impl Service<ReadRequest> for ReadStateService {
get_block_template_info.map(ReadResponse::ChainInfo) get_block_template_info.map(ReadResponse::ChainInfo)
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::ChainInfo")) .wait_for_panics()
.boxed()
} }
// Used by getmininginfo, getnetworksolps, and getnetworkhashps RPCs. // Used by getmininginfo, getnetworksolps, and getnetworkhashps RPCs.
@ -1766,8 +1738,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::SolutionRate(solution_rate)) Ok(ReadResponse::SolutionRate(solution_rate))
}) })
}) })
.map(|join_result| join_result.expect("panic in ReadRequest::SolutionRate")) .wait_for_panics()
.boxed()
} }
#[cfg(feature = "getblocktemplate-rpcs")] #[cfg(feature = "getblocktemplate-rpcs")]
@ -1815,10 +1786,7 @@ impl Service<ReadRequest> for ReadStateService {
Ok(ReadResponse::ValidBlockProposal) Ok(ReadResponse::ValidBlockProposal)
}) })
}) })
.map(|join_result| { .wait_for_panics()
join_result.expect("panic in ReadRequest::CheckBlockProposalValidity")
})
.boxed()
} }
} }
} }

View File

@ -2,7 +2,6 @@
use std::{ use std::{
cmp::Ordering, cmp::Ordering,
panic,
sync::{mpsc, Arc}, sync::{mpsc, Arc},
thread::{self, JoinHandle}, thread::{self, JoinHandle},
}; };
@ -10,7 +9,11 @@ use std::{
use semver::Version; use semver::Version;
use tracing::Span; use tracing::Span;
use zebra_chain::{block::Height, parameters::Network}; use zebra_chain::{
block::Height,
diagnostic::task::{CheckForPanics, WaitForPanics},
parameters::Network,
};
use DbFormatChange::*; use DbFormatChange::*;
@ -482,42 +485,16 @@ impl DbFormatChangeThreadHandle {
/// ///
/// This method should be called regularly, so that panics are detected as soon as possible. /// This method should be called regularly, so that panics are detected as soon as possible.
pub fn check_for_panics(&mut self) { pub fn check_for_panics(&mut self) {
let update_task = self.update_task.take(); self.update_task.check_for_panics();
if let Some(update_task) = update_task {
if update_task.is_finished() {
// We use into_inner() because it guarantees that exactly one of the tasks
// gets the JoinHandle. try_unwrap() lets us keep the JoinHandle, but it can also
// miss panics.
if let Some(update_task) = Arc::into_inner(update_task) {
// We are the last handle with a reference to this task,
// so we can propagate any panics
if let Err(thread_panic) = update_task.join() {
panic::resume_unwind(thread_panic);
}
}
} else {
// It hasn't finished, so we need to put it back
self.update_task = Some(update_task);
}
}
} }
/// Wait for the spawned thread to finish. If it exited with a panic, resume that panic. /// Wait for the spawned thread to finish. If it exited with a panic, resume that panic.
/// ///
/// Exits early if the thread has other outstanding handles.
///
/// This method should be called during shutdown. /// This method should be called during shutdown.
pub fn wait_for_panics(&mut self) { pub fn wait_for_panics(&mut self) {
if let Some(update_task) = self.update_task.take() { self.update_task.wait_for_panics();
// We use into_inner() because it guarantees that exactly one of the tasks
// gets the JoinHandle. See the comments in check_for_panics().
if let Some(update_task) = Arc::into_inner(update_task) {
// We are the last handle with a reference to this task,
// so we can propagate any panics
if let Err(thread_panic) = update_task.join() {
panic::resume_unwind(thread_panic);
}
}
}
} }
} }