diff --git a/zebra-chain/src/amount.rs b/zebra-chain/src/amount.rs index 0be893bb..28e830f9 100644 --- a/zebra-chain/src/amount.rs +++ b/zebra-chain/src/amount.rs @@ -79,7 +79,10 @@ where type Output = Result>; fn add(self, rhs: Amount) -> Self::Output { - let value = self.0 + rhs.0; + let value = self + .0 + .checked_add(rhs.0) + .expect("adding two constrained Amounts is always within an i64"); value.try_into() } } @@ -125,7 +128,10 @@ where type Output = Result>; fn sub(self, rhs: Amount) -> Self::Output { - let value = self.0 - rhs.0; + let value = self + .0 + .checked_sub(rhs.0) + .expect("subtracting two constrained Amounts is always within an i64"); value.try_into() } } @@ -172,7 +178,7 @@ impl From> for i64 { impl From> for u64 { fn from(amount: Amount) -> Self { - amount.0 as _ + amount.0.try_into().expect("non-negative i64 fits in u64") } } @@ -180,9 +186,14 @@ impl From> for jubjub::Fr { fn from(a: Amount) -> jubjub::Fr { // TODO: this isn't constant time -- does that matter? if a.0 < 0 { - jubjub::Fr::from(a.0.abs() as u64).neg() + let abs_amount = i128::from(a.0) + .checked_abs() + .expect("absolute i64 fits in i128"); + let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64"); + + jubjub::Fr::from(abs_amount).neg() } else { - jubjub::Fr::from(a.0 as u64) + jubjub::Fr::from(u64::try_from(a.0).expect("non-negative i64 fits in u64")) } } } @@ -191,13 +202,31 @@ impl From> for halo2::pasta::pallas::Scalar { fn from(a: Amount) -> halo2::pasta::pallas::Scalar { // TODO: this isn't constant time -- does that matter? if a.0 < 0 { - halo2::pasta::pallas::Scalar::from(a.0.abs() as u64).neg() + let abs_amount = i128::from(a.0) + .checked_abs() + .expect("absolute i64 fits in i128"); + let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64"); + + halo2::pasta::pallas::Scalar::from(abs_amount).neg() } else { - halo2::pasta::pallas::Scalar::from(a.0 as u64) + halo2::pasta::pallas::Scalar::from( + u64::try_from(a.0).expect("non-negative i64 fits in u64"), + ) } } } +impl TryFrom for Amount +where + C: Constraint, +{ + type Error = Error; + + fn try_from(value: i32) -> Result { + C::validate(value.into()).map(|v| Self(v, PhantomData)) + } +} + impl TryFrom for Amount where C: Constraint, @@ -209,17 +238,6 @@ where } } -impl TryFrom for Amount -where - C: Constraint, -{ - type Error = Error; - - fn try_from(value: i32) -> Result { - C::validate(value as _).map(|v| Self(v, PhantomData)) - } -} - impl TryFrom for Amount where C: Constraint, @@ -227,6 +245,25 @@ where type Error = Error; fn try_from(value: u64) -> Result { + let value = value.try_into().map_err(|source| Error::Convert { + value: value.into(), + source, + })?; + + C::validate(value).map(|v| Self(v, PhantomData)) + } +} + +/// Conversion from `i128` to `Amount`. +/// +/// Used to handle the result of multiplying negative `Amount`s by `u64`. +impl TryFrom for Amount +where + C: Constraint, +{ + type Error = Error; + + fn try_from(value: i128) -> Result { let value = value .try_into() .map_err(|source| Error::Convert { value, source })?; @@ -260,8 +297,7 @@ impl PartialEq> for i64 { } } -impl Eq for Amount {} -impl Eq for Amount {} +impl Eq for Amount {} impl PartialOrd> for Amount { fn partial_cmp(&self, other: &Amount) -> Option { @@ -269,47 +305,54 @@ impl PartialOrd> for Amount { } } -impl Ord for Amount { - fn cmp(&self, other: &Amount) -> Ordering { +impl Ord for Amount { + fn cmp(&self, other: &Amount) -> Ordering { self.0.cmp(&other.0) } } -impl Ord for Amount { - fn cmp(&self, other: &Amount) -> Ordering { - self.0.cmp(&other.0) - } -} - -impl std::ops::Mul for Amount { - type Output = Result>; +impl std::ops::Mul for Amount +where + C: Constraint, +{ + type Output = Result>; fn mul(self, rhs: u64) -> Self::Output { - let value = (self.0 as u64) - .checked_mul(rhs) - .ok_or(Error::MultiplicationOverflow { - amount: self.0, - multiplier: rhs, - })?; - value.try_into() + // use i128 for multiplication, so we can handle negative Amounts + let value = i128::from(self.0) + .checked_mul(i128::from(rhs)) + .expect("multiplying i64 by u64 can't overflow i128"); + + value.try_into().map_err(|_| Error::MultiplicationOverflow { + amount: self.0, + multiplier: rhs, + overflowing_result: value, + }) } } -impl std::ops::Mul> for u64 { - type Output = Result>; +impl std::ops::Mul> for u64 +where + C: Constraint, +{ + type Output = Result>; - fn mul(self, rhs: Amount) -> Self::Output { + fn mul(self, rhs: Amount) -> Self::Output { rhs.mul(self) } } -impl std::ops::Div for Amount { - type Output = Result>; +impl std::ops::Div for Amount +where + C: Constraint, +{ + type Output = Result>; fn div(self, rhs: u64) -> Self::Output { - let quotient = (self.0 as u64) - .checked_div(rhs) + let quotient = i128::from(self.0) + .checked_div(i128::from(rhs)) .ok_or(Error::DivideByZero { amount: self.0 })?; + Ok(quotient .try_into() .expect("division by a positive integer always stays within the constraint")) @@ -320,14 +363,16 @@ impl std::iter::Sum> for Result> where C: Constraint, { - fn sum>>(iter: I) -> Self { - let sum = iter - .map(|a| a.0) - .try_fold(0i64, |acc, amount| acc.checked_add(amount)); + fn sum>>(mut iter: I) -> Self { + let sum = iter.try_fold(Amount::zero(), |acc, amount| acc + amount); match sum { - Some(sum) => Amount::try_from(sum), - None => Err(Error::SumOverflow), + Ok(sum) => Ok(sum), + Err(Error::Constraint { value, .. }) => Err(Error::SumOverflow { + partial_sum: value, + remaining_items: iter.count(), + }), + Err(unexpected_error) => unreachable!("unexpected Add error: {:?}", unexpected_error), } } } @@ -352,27 +397,37 @@ where } } -#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)] +#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq, Eq)] #[allow(missing_docs)] /// Errors that can be returned when validating `Amount`s pub enum Error { /// input {value} is outside of valid range for zatoshi Amount, valid_range={range:?} - Contains { - range: RangeInclusive, + Constraint { value: i64, + range: RangeInclusive, }, - /// u64 {value} could not be converted to an i64 Amount + + /// {value} could not be converted to an i64 Amount Convert { - value: u64, + value: i128, source: std::num::TryFromIntError, }, - /// i64 overflow when multiplying i64 non-negative amount {amount} by u64 {multiplier} - MultiplicationOverflow { amount: i64, multiplier: u64 }, + + /// i64 overflow when multiplying i64 amount {amount} by u64 {multiplier}, overflowing result {overflowing_result} + MultiplicationOverflow { + amount: i64, + multiplier: u64, + overflowing_result: i128, + }, + /// cannot divide amount {amount} by zero DivideByZero { amount: i64 }, - /// i64 overflow when summing i64 amounts - SumOverflow, + /// i64 overflow when summing i64 amounts, partial_sum: {partial_sum}, remaining items: {remaining_items} + SumOverflow { + partial_sum: i64, + remaining_items: usize, + }, } /// Marker type for `Amount` that allows negative values. @@ -427,7 +482,7 @@ pub trait Constraint { let range = Self::valid_range(); if !range.contains(&value) { - Err(Error::Contains { range, value }) + Err(Error::Constraint { value, range }) } else { Ok(value) } @@ -778,9 +833,9 @@ mod test { assert_eq!(sum_ref, sum_value); assert_eq!( sum_ref, - Err(Error::Contains { - range: -MAX_MONEY..=MAX_MONEY, - value: integer_sum, + Err(Error::SumOverflow { + partial_sum: integer_sum, + remaining_items: 0 }) ); @@ -795,9 +850,9 @@ mod test { assert_eq!(sum_ref, sum_value); assert_eq!( sum_ref, - Err(Error::Contains { - range: -MAX_MONEY..=MAX_MONEY, - value: integer_sum, + Err(Error::SumOverflow { + partial_sum: integer_sum, + remaining_items: 0 }) ); @@ -813,7 +868,13 @@ mod test { let sum_value = amounts.into_iter().sum::>(); assert_eq!(sum_ref, sum_value); - assert_eq!(sum_ref, Err(Error::SumOverflow)); + assert_eq!( + sum_ref, + Err(Error::SumOverflow { + partial_sum: 4200000000000000, + remaining_items: 4391 + }) + ); // below min of i64 overflow let times: usize = (i64::MAX / MAX_MONEY) @@ -827,7 +888,13 @@ mod test { let sum_value = amounts.into_iter().sum::>(); assert_eq!(sum_ref, sum_value); - assert_eq!(sum_ref, Err(Error::SumOverflow)); + assert_eq!( + sum_ref, + Err(Error::SumOverflow { + partial_sum: -4200000000000000, + remaining_items: 4391 + }) + ); Ok(()) } diff --git a/zebra-chain/src/value_balance.rs b/zebra-chain/src/value_balance.rs index 834b4458..bfaf17cb 100644 --- a/zebra-chain/src/value_balance.rs +++ b/zebra-chain/src/value_balance.rs @@ -1,6 +1,6 @@ //! A type that can hold the four types of Zcash value pools. -use crate::amount::{Amount, Constraint, Error, NonNegative}; +use crate::amount::{Amount, Constraint, Error, NegativeAllowed, NonNegative}; use std::convert::TryInto; @@ -10,8 +10,10 @@ mod arbitrary; #[cfg(test)] mod tests; +use ValueBalanceError::*; + /// An amount spread between different Zcash pools. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct ValueBalance { transparent: Amount, sprout: Amount, @@ -131,6 +133,20 @@ where } } + /// Convert this value balance to a different ValueBalance type, + /// if it satisfies the new constraint + pub fn constrain(self) -> Result, ValueBalanceError> + where + C2: Constraint, + { + Ok(ValueBalance:: { + transparent: self.transparent.constrain().map_err(Transparent)?, + sprout: self.sprout.constrain().map_err(Sprout)?, + sapling: self.sapling.constrain().map_err(Sapling)?, + orchard: self.orchard.constrain().map_err(Orchard)?, + }) + } + /// To byte array pub fn to_bytes(self) -> [u8; 32] { let transparent = self.transparent.to_bytes(); @@ -151,22 +167,29 @@ where bytes[0..8] .try_into() .expect("Extracting the first quarter of a [u8; 32] should always succeed"), - )?; + ) + .map_err(Transparent)?; + let sprout = Amount::from_bytes( bytes[8..16] .try_into() .expect("Extracting the second quarter of a [u8; 32] should always succeed"), - )?; + ) + .map_err(Sprout)?; + let sapling = Amount::from_bytes( bytes[16..24] .try_into() .expect("Extracting the third quarter of a [u8; 32] should always succeed"), - )?; + ) + .map_err(Sapling)?; + let orchard = Amount::from_bytes( bytes[24..32] .try_into() .expect("Extracting the last quarter of a [u8; 32] should always succeed"), - )?; + ) + .map_err(Orchard)?; Ok(ValueBalance { transparent, @@ -177,12 +200,20 @@ where } } -#[derive(thiserror::Error, Debug, Clone, PartialEq)] -/// Errors that can be returned when validating a [`ValueBalance`]. +#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq, Eq)] +/// Errors that can be returned when validating a [`ValueBalance`] pub enum ValueBalanceError { - #[error("value balance contains invalid amounts")] - /// Any error related to [`Amount`]s inside the [`ValueBalance`] - AmountError(#[from] Error), + /// transparent amount error {0} + Transparent(Error), + + /// sprout amount error {0} + Sprout(Error), + + /// sapling amount error {0} + Sapling(Error), + + /// orchard amount error {0} + Orchard(Error), } impl std::ops::Add for ValueBalance @@ -192,10 +223,10 @@ where type Output = Result, ValueBalanceError>; fn add(self, rhs: ValueBalance) -> Self::Output { Ok(ValueBalance:: { - transparent: (self.transparent + rhs.transparent)?, - sprout: (self.sprout + rhs.sprout)?, - sapling: (self.sapling + rhs.sapling)?, - orchard: (self.orchard + rhs.orchard)?, + transparent: (self.transparent + rhs.transparent).map_err(Transparent)?, + sprout: (self.sprout + rhs.sprout).map_err(Sprout)?, + sapling: (self.sapling + rhs.sapling).map_err(Sapling)?, + orchard: (self.orchard + rhs.orchard).map_err(Orchard)?, }) } } @@ -216,10 +247,10 @@ where type Output = Result, ValueBalanceError>; fn sub(self, rhs: ValueBalance) -> Self::Output { Ok(ValueBalance:: { - transparent: (self.transparent - rhs.transparent)?, - sprout: (self.sprout - rhs.sprout)?, - sapling: (self.sapling - rhs.sapling)?, - orchard: (self.orchard - rhs.orchard)?, + transparent: (self.transparent - rhs.transparent).map_err(Transparent)?, + sprout: (self.sprout - rhs.sprout).map_err(Sprout)?, + sapling: (self.sapling - rhs.sapling).map_err(Sapling)?, + orchard: (self.orchard - rhs.orchard).map_err(Orchard)?, }) } } @@ -239,12 +270,7 @@ where { fn sum>>(mut iter: I) -> Self { iter.try_fold(ValueBalance::zero(), |acc, value_balance| { - Ok(ValueBalance { - transparent: (acc.transparent + value_balance.transparent)?, - sprout: (acc.sprout + value_balance.sprout)?, - sapling: (acc.sapling + value_balance.sapling)?, - orchard: (acc.orchard + value_balance.orchard)?, - }) + acc + value_balance }) } } @@ -257,3 +283,19 @@ where iter.copied().sum() } } + +impl std::ops::Neg for ValueBalance +where + C: Constraint, +{ + type Output = ValueBalance; + + fn neg(self) -> Self::Output { + ValueBalance:: { + transparent: self.transparent.neg(), + sprout: self.sprout.neg(), + sapling: self.sapling.neg(), + orchard: self.orchard.neg(), + } + } +} diff --git a/zebra-chain/src/value_balance/tests/prop.rs b/zebra-chain/src/value_balance/tests/prop.rs index 9806d431..f55b22d4 100644 --- a/zebra-chain/src/value_balance/tests/prop.rs +++ b/zebra-chain/src/value_balance/tests/prop.rs @@ -29,7 +29,11 @@ proptest! { ), _ => prop_assert!( matches!( - value_balance1 + value_balance2, Err(ValueBalanceError::AmountError(_)) + value_balance1 + value_balance2, + Err(ValueBalanceError::Transparent(_) + | ValueBalanceError::Sprout(_) + | ValueBalanceError::Sapling(_) + | ValueBalanceError::Orchard(_)) ) ), } @@ -58,7 +62,11 @@ proptest! { ), _ => prop_assert!( matches!( - value_balance1 - value_balance2, Err(ValueBalanceError::AmountError(_)) + value_balance1 - value_balance2, + Err(ValueBalanceError::Transparent(_) + | ValueBalanceError::Sprout(_) + | ValueBalanceError::Sapling(_) + | ValueBalanceError::Orchard(_)) ) ), } @@ -88,7 +96,12 @@ proptest! { orchard, }) ), - _ => prop_assert!(matches!(collection.iter().sum(), Err(ValueBalanceError::AmountError(_)))) + _ => prop_assert!(matches!(collection.iter().sum(), + Err(ValueBalanceError::Transparent(_) + | ValueBalanceError::Sprout(_) + | ValueBalanceError::Sapling(_) + | ValueBalanceError::Orchard(_)) + )) } }