diff --git a/zebra-state/src/service/non_finalized_state.rs b/zebra-state/src/service/non_finalized_state.rs index b6ad1fad..034e5484 100644 --- a/zebra-state/src/service/non_finalized_state.rs +++ b/zebra-state/src/service/non_finalized_state.rs @@ -37,7 +37,7 @@ pub struct NonFinalizedState { /// Verified, non-finalized chains, in ascending order. /// /// The best chain is `chain_set.last()` or `chain_set.iter().next_back()`. - pub chain_set: BTreeSet>, + pub chain_set: BTreeSet>, /// The configured Zcash network. // @@ -89,11 +89,14 @@ impl NonFinalizedState { // extract best chain let mut best_chain = chains.next_back().expect("there's at least one chain"); + // clone if required + let write_best_chain = Arc::make_mut(&mut best_chain); + // extract the rest into side_chains so they can be mutated let side_chains = chains; // remove the lowest height block from the best_chain to be finalized - let finalizing = best_chain.pop_root(); + let finalizing = write_best_chain.pop_root(); // add best_chain back to `self.chain_set` if !best_chain.is_empty() { @@ -102,16 +105,25 @@ impl NonFinalizedState { // for each remaining chain in side_chains for mut chain in side_chains { - // remove the first block from `chain` - let chain_start = chain.pop_root(); - // if block equals finalized_block - if !chain.is_empty() && chain_start.hash == finalizing.hash { - // add the chain back to `self.chain_set` - self.chain_set.insert(chain); - } else { - // else discard `chain` + if chain.non_finalized_root_hash() != finalizing.hash { + // If we popped the root, the chain would be empty or orphaned, + // so just drop it now. drop(chain); + + continue; } + + // otherwise, the popped root block is the same as the finalizing block + + // clone if required + let write_chain = Arc::make_mut(&mut chain); + + // remove the first block from `chain` + let chain_start = write_chain.pop_root(); + assert_eq!(chain_start.hash, finalizing.hash); + + // add the chain back to `self.chain_set` + self.chain_set.insert(chain); } self.update_metrics_for_chains(); @@ -139,26 +151,21 @@ impl NonFinalizedState { finalized_state.history_tree(), )?; - // We might have taken a chain, so all validation must happen within - // validate_and_commit, so that the chain is restored correctly. - match self.validate_and_commit(*parent_chain.clone(), prepared, finalized_state) { - Ok(child_chain) => { - // if the block is valid, keep the child chain, and drop the parent chain - self.chain_set.insert(Box::new(child_chain)); - self.update_metrics_for_committed_block(height, hash); - Ok(()) - } - Err(err) => { - // if the block is invalid, restore the unmodified parent chain - // (the child chain might have been modified before the error) - // - // If the chain was forked, this adds an extra chain to the - // chain set. This extra chain will eventually get deleted - // (or re-used for a valid fork). - self.chain_set.insert(parent_chain); - Err(err) - } - } + // If the block is invalid, return the error, + // and drop the cloned parent Arc, or newly created chain fork. + let modified_chain = self.validate_and_commit(parent_chain, prepared, finalized_state)?; + + // If the block is valid: + // - add the new chain fork or updated chain to the set of recent chains + // - remove the parent chain, if it was in the chain set + // (if it was a newly created fork, it won't be in the chain set) + self.chain_set.insert(modified_chain); + self.chain_set + .retain(|chain| chain.non_finalized_tip_hash() != parent_hash); + + self.update_metrics_for_committed_block(height, hash); + + Ok(()) } /// Commit block to the non-finalized state as a new chain where its parent @@ -179,38 +186,44 @@ impl NonFinalizedState { ); let (height, hash) = (prepared.height, prepared.hash); - // if the block is invalid, drop the newly created chain fork - let chain = self.validate_and_commit(chain, prepared, finalized_state)?; - self.chain_set.insert(Box::new(chain)); + // If the block is invalid, return the error, and drop the newly created chain fork + let chain = self.validate_and_commit(Arc::new(chain), prepared, finalized_state)?; + + // If the block is valid, add the new chain fork to the set of recent chains. + self.chain_set.insert(chain); self.update_metrics_for_committed_block(height, hash); + Ok(()) } /// Contextually validate `prepared` using `finalized_state`. - /// If validation succeeds, push `prepared` onto `parent_chain`. - #[tracing::instrument(level = "debug", skip(self, finalized_state, parent_chain))] + /// If validation succeeds, push `prepared` onto `new_chain`. + /// + /// `new_chain` should start as a clone of the parent chain fork, + /// or the finalized tip. + #[tracing::instrument(level = "debug", skip(self, finalized_state, new_chain))] fn validate_and_commit( &self, - parent_chain: Chain, + mut new_chain: Arc, prepared: PreparedBlock, finalized_state: &FinalizedState, - ) -> Result { + ) -> Result, ValidateContextError> { let spent_utxos = check::utxo::transparent_spend( &prepared, - &parent_chain.unspent_utxos(), - &parent_chain.spent_utxos, + &new_chain.unspent_utxos(), + &new_chain.spent_utxos, finalized_state, )?; check::prepared_block_commitment_is_valid_for_chain_history( &prepared, self.network, - &parent_chain.history_tree, + &new_chain.history_tree, )?; check::anchors::anchors_refer_to_earlier_treestates( finalized_state, - &parent_chain, + &new_chain, &prepared, )?; @@ -228,7 +241,11 @@ impl NonFinalizedState { } })?; - parent_chain.push(contextual) + // We're pretty sure the new block is valid, + // so clone the inner chain if needed, then add the new block. + Arc::make_mut(&mut new_chain).push(contextual)?; + + Ok(new_chain) } /// Returns the length of the non-finalized portion of the current best chain. @@ -248,34 +265,16 @@ impl NonFinalizedState { .any(|chain| chain.height_by_hash.contains_key(hash)) } - /// Remove and return the first chain satisfying the given predicate. - fn take_chain_if(&mut self, predicate: F) -> Option> + /// Removes and returns the first chain satisfying the given predicate. + /// + /// If multiple chains satisfy the predicate, returns the chain with the highest difficulty. + /// (Using the tip block hash tie-breaker.) + fn find_chain

(&mut self, mut predicate: P) -> Option<&Arc> where - F: Fn(&Chain) -> bool, + P: FnMut(&Chain) -> bool, { - // Chain::cmp uses the partial cumulative work, and the hash of the tip block. - // Neither of these fields has interior mutability. - #[allow(clippy::mutable_key_type)] - let chains = mem::take(&mut self.chain_set); - let mut best_chain_iter = chains.into_iter().rev(); - - while let Some(next_best_chain) = best_chain_iter.next() { - // if the predicate says we should remove it - if predicate(&next_best_chain) { - // add back the remaining chains - for remaining_chain in best_chain_iter { - self.chain_set.insert(remaining_chain); - } - - // and return the chain - return Some(next_best_chain); - } else { - // add the chain back to the set and continue - self.chain_set.insert(next_best_chain); - } - } - - None + // Reverse the iteration order, to find highest difficulty chains first. + self.chain_set.iter().rev().find(|chain| predicate(chain)) } /// Returns the `transparent::Output` pointed to by the given @@ -410,13 +409,16 @@ impl NonFinalizedState { sapling_note_commitment_tree: sapling::tree::NoteCommitmentTree, orchard_note_commitment_tree: orchard::tree::NoteCommitmentTree, history_tree: HistoryTree, - ) -> Result, ValidateContextError> { - match self.take_chain_if(|chain| chain.non_finalized_tip_hash() == parent_hash) { - // An existing chain in the non-finalized state - Some(chain) => Ok(chain), + ) -> Result, ValidateContextError> { + match self.find_chain(|chain| chain.non_finalized_tip_hash() == parent_hash) { + // Clone the existing Arc in the non-finalized state + Some(chain) => Ok(chain.clone()), // Create a new fork - None => Ok(Box::new( - self.chain_set + None => { + // Check the lowest difficulty chains first, + // because the fork could be closer to their tip. + let fork_chain = self + .chain_set .iter() .find_map(|chain| { chain @@ -431,8 +433,10 @@ impl NonFinalizedState { }) .expect( "commit_block is only called with blocks that are ready to be committed", - )?, - )), + )?; + + Ok(Arc::new(fork_chain)) + } } } diff --git a/zebra-state/src/service/non_finalized_state/chain.rs b/zebra-state/src/service/non_finalized_state/chain.rs index 27903148..01a6b059 100644 --- a/zebra-state/src/service/non_finalized_state/chain.rs +++ b/zebra-state/src/service/non_finalized_state/chain.rs @@ -189,18 +189,31 @@ impl Chain { /// Push a contextually valid non-finalized block into this chain as the new tip. /// - /// If the block is invalid, drop this chain and return an error. + /// If the block is invalid, clears the chain, and returns an error. /// /// Note: a [`ContextuallyValidBlock`] isn't actually contextually valid until /// [`update_chain_state_with`] returns success. #[instrument(level = "debug", skip(self, block), fields(block = %block.block))] - pub fn push(mut self, block: ContextuallyValidBlock) -> Result { + pub fn push(&mut self, block: ContextuallyValidBlock) -> Result<(), ValidateContextError> { // update cumulative data members - self.update_chain_tip_with(&block)?; + if let Err(error) = self.update_chain_tip_with(&block) { + // The chain could be in an invalid half-updated state, so clear its data. + *self = Chain::new( + self.network, + sprout::tree::NoteCommitmentTree::default(), + sapling::tree::NoteCommitmentTree::default(), + orchard::tree::NoteCommitmentTree::default(), + HistoryTree::default(), + ValueBalance::zero(), + ); + + return Err(error); + } + tracing::debug!(block = %block.block, "adding block to chain"); self.blocks.insert(block.height, block); - Ok(self) + Ok(()) } /// Remove the lowest height block of the non-finalized portion of a chain. @@ -308,6 +321,7 @@ impl Chain { Ok(Some(forked)) } + /// Returns the block hash of the tip block. pub fn non_finalized_tip_hash(&self) -> block::Hash { self.blocks .values() @@ -316,6 +330,23 @@ impl Chain { .hash } + /// Returns the block hash of the non-finalized root block. + pub fn non_finalized_root_hash(&self) -> block::Hash { + self.blocks + .values() + .next() + .expect("only called while blocks is populated") + .hash + } + + /// Returns the block hash of the `n`th block from the non-finalized root. + /// + /// This is the block at `lowest_height() + n`. + #[allow(dead_code)] + pub fn non_finalized_nth_hash(&self, n: usize) -> Option { + self.blocks.values().nth(n).map(|block| block.hash) + } + /// Remove the highest height block of the non-finalized portion of a chain. fn pop_tip(&mut self) { let block_height = self.non_finalized_tip_height(); @@ -350,10 +381,17 @@ impl Chain { self.blocks.keys().next_back().cloned() } + /// Returns true if the non-finalized part of this chain is empty. pub fn is_empty(&self) -> bool { self.blocks.is_empty() } + /// Returns the non-finalized length of this chain. + #[allow(dead_code)] + pub fn len(&self) -> usize { + self.blocks.len() + } + /// Returns the unspent transaction outputs (UTXOs) in this non-finalized chain. /// /// Callers should also check the finalized state for available UTXOs. diff --git a/zebra-state/src/service/non_finalized_state/queued_blocks.rs b/zebra-state/src/service/non_finalized_state/queued_blocks.rs index 839cc9fe..79ed471f 100644 --- a/zebra-state/src/service/non_finalized_state/queued_blocks.rs +++ b/zebra-state/src/service/non_finalized_state/queued_blocks.rs @@ -9,7 +9,7 @@ use zebra_chain::{block, transparent}; use crate::service::QueuedBlock; /// A queue of blocks, awaiting the arrival of parent blocks. -#[derive(Default)] +#[derive(Debug, Default)] pub struct QueuedBlocks { /// Blocks awaiting their parent blocks for contextual verification. blocks: HashMap, diff --git a/zebra-state/src/service/non_finalized_state/tests/prop.rs b/zebra-state/src/service/non_finalized_state/tests/prop.rs index a185b464..9cb5da21 100644 --- a/zebra-state/src/service/non_finalized_state/tests/prop.rs +++ b/zebra-state/src/service/non_finalized_state/tests/prop.rs @@ -1,3 +1,5 @@ +//! Randomised property tests for the non-finalized state. + use std::{collections::BTreeMap, env, sync::Arc}; use zebra_test::prelude::*; @@ -62,7 +64,7 @@ fn push_genesis_chain() -> Result<()> { chain_values.insert(block.height.into(), (block.chain_value_pool_change.into(), None)); - only_chain = only_chain + only_chain .push(block.clone()) .map_err(|e| (e, chain_values.clone())) .expect("invalid chain value pools"); @@ -103,7 +105,7 @@ fn push_history_tree_chain() -> Result<()> { .iter() .take(count) .map(ContextuallyValidBlock::test_with_zero_chain_pool_change) { - only_chain = only_chain.push(block)?; + only_chain.push(block)?; } prop_assert_eq!(only_chain.blocks.len(), count); @@ -152,7 +154,7 @@ fn forked_equals_pushed_genesis() -> Result<()> { block, partial_chain.unspent_utxos(), )?; - partial_chain = partial_chain + partial_chain .push(block) .expect("partial chain push is valid"); } @@ -169,7 +171,7 @@ fn forked_equals_pushed_genesis() -> Result<()> { for block in chain.iter().cloned() { let block = ContextuallyValidBlock::with_block_and_spent_utxos(block, full_chain.unspent_utxos())?; - full_chain = full_chain + full_chain .push(block.clone()) .expect("full chain push is valid"); @@ -219,7 +221,7 @@ fn forked_equals_pushed_genesis() -> Result<()> { for block in chain.iter().skip(fork_at_count).cloned() { let block = ContextuallyValidBlock::with_block_and_spent_utxos(block, forked.unspent_utxos())?; - forked = forked.push(block).expect("forked chain push is valid"); + forked.push(block).expect("forked chain push is valid"); } prop_assert_eq!(forked.blocks.len(), full_chain.blocks.len()); @@ -259,13 +261,13 @@ fn forked_equals_pushed_history_tree() -> Result<()> { .iter() .take(fork_at_count) .map(ContextuallyValidBlock::test_with_zero_chain_pool_change) { - partial_chain = partial_chain.push(block)?; + partial_chain.push(block)?; } for block in chain .iter() .map(ContextuallyValidBlock::test_with_zero_chain_pool_change) { - full_chain = full_chain.push(block.clone())?; + full_chain.push(block.clone())?; } let mut forked = full_chain @@ -289,7 +291,7 @@ fn forked_equals_pushed_history_tree() -> Result<()> { .iter() .skip(fork_at_count) .map(ContextuallyValidBlock::test_with_zero_chain_pool_change) { - forked = forked.push(block)?; + forked.push(block)?; } prop_assert_eq!(forked.blocks.len(), full_chain.blocks.len()); @@ -326,7 +328,7 @@ fn finalized_equals_pushed_genesis() -> Result<()> { .iter() .take(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - full_chain = full_chain.push(block)?; + full_chain.push(block)?; } let mut partial_chain = Chain::new( @@ -341,14 +343,14 @@ fn finalized_equals_pushed_genesis() -> Result<()> { .iter() .skip(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - partial_chain = partial_chain.push(block.clone())?; + partial_chain.push(block.clone())?; } for block in chain .iter() .skip(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - full_chain = full_chain.push(block.clone())?; + full_chain.push(block.clone())?; } for _ in 0..finalized_count { @@ -396,7 +398,7 @@ fn finalized_equals_pushed_history_tree() -> Result<()> { .iter() .take(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - full_chain = full_chain.push(block)?; + full_chain.push(block)?; } let mut partial_chain = Chain::new( @@ -412,14 +414,14 @@ fn finalized_equals_pushed_history_tree() -> Result<()> { .iter() .skip(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - partial_chain = partial_chain.push(block.clone())?; + partial_chain.push(block.clone())?; } for block in chain .iter() .skip(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - full_chain = full_chain.push(block.clone())?; + full_chain.push(block.clone())?; } for _ in 0..finalized_count { @@ -561,8 +563,8 @@ fn different_blocks_different_chains() -> Result<()> { } else { Default::default() }; - let chain1 = Chain::new(Network::Mainnet, Default::default(), Default::default(), Default::default(), finalized_tree1, ValueBalance::fake_populated_pool()); - let chain2 = Chain::new(Network::Mainnet, Default::default(), Default::default(), Default::default(), finalized_tree2, ValueBalance::fake_populated_pool()); + let mut chain1 = Chain::new(Network::Mainnet, Default::default(), Default::default(), Default::default(), finalized_tree1, ValueBalance::fake_populated_pool()); + let mut chain2 = Chain::new(Network::Mainnet, Default::default(), Default::default(), Default::default(), finalized_tree2, ValueBalance::fake_populated_pool()); let block1 = vec1[1].clone().prepare().test_with_zero_spent_utxos(); let block2 = vec2[1].clone().prepare().test_with_zero_spent_utxos(); @@ -570,8 +572,8 @@ fn different_blocks_different_chains() -> Result<()> { let result1 = chain1.push(block1.clone()); let result2 = chain2.push(block2.clone()); - // if there is an error, we don't get the chains back - if let (Ok(mut chain1), Ok(chain2)) = (result1, result2) { + // if there is an error, the chains come back empty + if result1.is_ok() && result2.is_ok() { if block1 == block2 { // the blocks were equal, so the chains should be equal diff --git a/zebra-state/src/service/non_finalized_state/tests/vectors.rs b/zebra-state/src/service/non_finalized_state/tests/vectors.rs index 87a9952f..7a3d326b 100644 --- a/zebra-state/src/service/non_finalized_state/tests/vectors.rs +++ b/zebra-state/src/service/non_finalized_state/tests/vectors.rs @@ -1,3 +1,5 @@ +//! Fixed test vectors for the non-finalized state. + use std::sync::Arc; use zebra_chain::{ @@ -48,7 +50,7 @@ fn construct_single() -> Result<()> { ValueBalance::fake_populated_pool(), ); - chain = chain.push(block.prepare().test_with_zero_spent_utxos())?; + chain.push(block.prepare().test_with_zero_spent_utxos())?; assert_eq!(1, chain.blocks.len()); @@ -79,7 +81,7 @@ fn construct_many() -> Result<()> { ); for block in blocks { - chain = chain.push(block.prepare().test_with_zero_spent_utxos())?; + chain.push(block.prepare().test_with_zero_spent_utxos())?; } assert_eq!(100, chain.blocks.len()); @@ -103,7 +105,7 @@ fn ord_matches_work() -> Result<()> { Default::default(), ValueBalance::fake_populated_pool(), ); - lesser_chain = lesser_chain.push(less_block.prepare().test_with_zero_spent_utxos())?; + lesser_chain.push(less_block.prepare().test_with_zero_spent_utxos())?; let mut bigger_chain = Chain::new( Network::Mainnet, @@ -113,7 +115,7 @@ fn ord_matches_work() -> Result<()> { Default::default(), ValueBalance::zero(), ); - bigger_chain = bigger_chain.push(more_block.prepare().test_with_zero_spent_utxos())?; + bigger_chain.push(more_block.prepare().test_with_zero_spent_utxos())?; assert!(bigger_chain > lesser_chain);