diff --git a/zebra-chain/src/block.rs b/zebra-chain/src/block.rs index b24e3511..966c8097 100644 --- a/zebra-chain/src/block.rs +++ b/zebra-chain/src/block.rs @@ -14,7 +14,7 @@ pub mod arbitrary; #[cfg(any(test, feature = "bench"))] pub mod tests; -use std::fmt; +use std::{collections::HashMap, fmt}; pub use commitment::{ChainHistoryMmrRootHash, Commitment, CommitmentError}; pub use hash::Hash; @@ -28,6 +28,7 @@ pub use arbitrary::LedgerState; use serde::{Deserialize, Serialize}; use crate::{ + amount::NegativeAllowed, fmt::DisplayToDebug, orchard, parameters::{Network, NetworkUpgrade}, @@ -36,6 +37,7 @@ use crate::{ sprout, transaction::Transaction, transparent, + value_balance::{ValueBalance, ValueBalanceError}, }; /// A Zcash block, containing a header and a list of transactions. @@ -143,6 +145,22 @@ impl Block { .map(|transaction| transaction.orchard_nullifiers()) .flatten() } + + /// Get all the value balances from this block by summing all the value balances + /// in each transaction the block has. + /// + /// `utxos` must contain the utxos of every input in the block, + /// including UTXOs created by a transaction in this block, + /// then spent by a later transaction that's also in this block. + pub fn value_balance( + &self, + utxos: &HashMap, + ) -> Result, ValueBalanceError> { + self.transactions + .iter() + .flat_map(|t| t.value_balance(utxos)) + .sum() + } } impl<'a> From<&'a Block> for Hash { diff --git a/zebra-chain/src/value_balance.rs b/zebra-chain/src/value_balance.rs index 526d7ea4..164d7a42 100644 --- a/zebra-chain/src/value_balance.rs +++ b/zebra-chain/src/value_balance.rs @@ -184,3 +184,28 @@ where self? - rhs } } + +impl std::iter::Sum> for Result, ValueBalanceError> +where + C: Constraint + Copy, +{ + fn sum>>(mut iter: I) -> Self { + iter.try_fold(ValueBalance::zero(), |acc, value_balance| { + Ok(ValueBalance { + transparent: (acc.transparent + value_balance.transparent)?, + sprout: (acc.sprout + value_balance.sprout)?, + sapling: (acc.sapling + value_balance.sapling)?, + orchard: (acc.orchard + value_balance.orchard)?, + }) + }) + } +} + +impl<'amt, C> std::iter::Sum<&'amt ValueBalance> for Result, ValueBalanceError> +where + C: Constraint + std::marker::Copy + 'amt, +{ + fn sum>>(iter: I) -> Self { + iter.copied().sum() + } +} diff --git a/zebra-chain/src/value_balance/tests/prop.rs b/zebra-chain/src/value_balance/tests/prop.rs index 5ef4de98..010976ca 100644 --- a/zebra-chain/src/value_balance/tests/prop.rs +++ b/zebra-chain/src/value_balance/tests/prop.rs @@ -60,4 +60,32 @@ proptest! { ), } } + + #[test] + fn test_sum( + value_balance1 in any::>(), + value_balance2 in any::>(), + ) { + zebra_test::init(); + + let collection = vec![value_balance1, value_balance2]; + + let transparent = value_balance1.transparent + value_balance2.transparent; + let sprout = value_balance1.sprout + value_balance2.sprout; + let sapling = value_balance1.sapling + value_balance2.sapling; + let orchard = value_balance1.orchard + value_balance2.orchard; + + match (transparent, sprout, sapling, orchard) { + (Ok(transparent), Ok(sprout), Ok(sapling), Ok(orchard)) => prop_assert_eq!( + collection.iter().sum::, ValueBalanceError>>(), + Ok(ValueBalance { + transparent, + sprout, + sapling, + orchard, + }) + ), + _ => prop_assert!(matches!(collection.iter().sum(), Err(ValueBalanceError::AmountError(_)))) + } + } }