Add extra arithmetic operations and error details to Amount and ValueBalance (#2577)

* Make Amount arithmetic more generic

To modify generated amounts, we need some extra operations on `Amount`.

We also need to extend existing operations to both `NonNegative` and
`NegativeAllowed` amounts.

* Add a constrain method for ValueBalance

* Derive Eq for ValueBalance

* impl Neg for ValueBalance

* Make some Amount arithmetic expectations explicit

* Explain why we use i128 for multiplication

And expand the overflow error details.

* Expand Amount::sum error details

* Make amount::Error field order consistent

* Rename an amount::Error variant to Constraint, so it's clearer

* Add specific pool variants to ValueBalanceError
This commit is contained in:
teor 2021-08-09 23:13:27 +10:00 committed by GitHub
parent 910f0ff5dc
commit fc68240fa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 218 additions and 96 deletions

View File

@ -79,7 +79,10 @@ where
type Output = Result<Amount<C>>; type Output = Result<Amount<C>>;
fn add(self, rhs: Amount<C>) -> Self::Output { fn add(self, rhs: Amount<C>) -> Self::Output {
let value = self.0 + rhs.0; let value = self
.0
.checked_add(rhs.0)
.expect("adding two constrained Amounts is always within an i64");
value.try_into() value.try_into()
} }
} }
@ -125,7 +128,10 @@ where
type Output = Result<Amount<C>>; type Output = Result<Amount<C>>;
fn sub(self, rhs: Amount<C>) -> Self::Output { fn sub(self, rhs: Amount<C>) -> Self::Output {
let value = self.0 - rhs.0; let value = self
.0
.checked_sub(rhs.0)
.expect("subtracting two constrained Amounts is always within an i64");
value.try_into() value.try_into()
} }
} }
@ -172,7 +178,7 @@ impl<C> From<Amount<C>> for i64 {
impl From<Amount<NonNegative>> for u64 { impl From<Amount<NonNegative>> for u64 {
fn from(amount: Amount<NonNegative>) -> Self { fn from(amount: Amount<NonNegative>) -> Self {
amount.0 as _ amount.0.try_into().expect("non-negative i64 fits in u64")
} }
} }
@ -180,9 +186,14 @@ impl<C> From<Amount<C>> for jubjub::Fr {
fn from(a: Amount<C>) -> jubjub::Fr { fn from(a: Amount<C>) -> jubjub::Fr {
// TODO: this isn't constant time -- does that matter? // TODO: this isn't constant time -- does that matter?
if a.0 < 0 { if a.0 < 0 {
jubjub::Fr::from(a.0.abs() as u64).neg() let abs_amount = i128::from(a.0)
.checked_abs()
.expect("absolute i64 fits in i128");
let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
jubjub::Fr::from(abs_amount).neg()
} else { } else {
jubjub::Fr::from(a.0 as u64) jubjub::Fr::from(u64::try_from(a.0).expect("non-negative i64 fits in u64"))
} }
} }
} }
@ -191,13 +202,31 @@ impl<C> From<Amount<C>> for halo2::pasta::pallas::Scalar {
fn from(a: Amount<C>) -> halo2::pasta::pallas::Scalar { fn from(a: Amount<C>) -> halo2::pasta::pallas::Scalar {
// TODO: this isn't constant time -- does that matter? // TODO: this isn't constant time -- does that matter?
if a.0 < 0 { if a.0 < 0 {
halo2::pasta::pallas::Scalar::from(a.0.abs() as u64).neg() let abs_amount = i128::from(a.0)
.checked_abs()
.expect("absolute i64 fits in i128");
let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
halo2::pasta::pallas::Scalar::from(abs_amount).neg()
} else { } else {
halo2::pasta::pallas::Scalar::from(a.0 as u64) halo2::pasta::pallas::Scalar::from(
u64::try_from(a.0).expect("non-negative i64 fits in u64"),
)
} }
} }
} }
impl<C> TryFrom<i32> for Amount<C>
where
C: Constraint,
{
type Error = Error;
fn try_from(value: i32) -> Result<Self, Self::Error> {
C::validate(value.into()).map(|v| Self(v, PhantomData))
}
}
impl<C> TryFrom<i64> for Amount<C> impl<C> TryFrom<i64> for Amount<C>
where where
C: Constraint, C: Constraint,
@ -209,17 +238,6 @@ where
} }
} }
impl<C> TryFrom<i32> for Amount<C>
where
C: Constraint,
{
type Error = Error;
fn try_from(value: i32) -> Result<Self, Self::Error> {
C::validate(value as _).map(|v| Self(v, PhantomData))
}
}
impl<C> TryFrom<u64> for Amount<C> impl<C> TryFrom<u64> for Amount<C>
where where
C: Constraint, C: Constraint,
@ -227,6 +245,25 @@ where
type Error = Error; type Error = Error;
fn try_from(value: u64) -> Result<Self, Self::Error> { fn try_from(value: u64) -> Result<Self, Self::Error> {
let value = value.try_into().map_err(|source| Error::Convert {
value: value.into(),
source,
})?;
C::validate(value).map(|v| Self(v, PhantomData))
}
}
/// Conversion from `i128` to `Amount`.
///
/// Used to handle the result of multiplying negative `Amount`s by `u64`.
impl<C> TryFrom<i128> for Amount<C>
where
C: Constraint,
{
type Error = Error;
fn try_from(value: i128) -> Result<Self, Self::Error> {
let value = value let value = value
.try_into() .try_into()
.map_err(|source| Error::Convert { value, source })?; .map_err(|source| Error::Convert { value, source })?;
@ -260,8 +297,7 @@ impl<C> PartialEq<Amount<C>> for i64 {
} }
} }
impl Eq for Amount<NegativeAllowed> {} impl<C> Eq for Amount<C> {}
impl Eq for Amount<NonNegative> {}
impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> { impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
fn partial_cmp(&self, other: &Amount<C2>) -> Option<Ordering> { fn partial_cmp(&self, other: &Amount<C2>) -> Option<Ordering> {
@ -269,47 +305,54 @@ impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
} }
} }
impl Ord for Amount<NegativeAllowed> { impl<C> Ord for Amount<C> {
fn cmp(&self, other: &Amount<NegativeAllowed>) -> Ordering { fn cmp(&self, other: &Amount<C>) -> Ordering {
self.0.cmp(&other.0) self.0.cmp(&other.0)
} }
} }
impl Ord for Amount<NonNegative> { impl<C> std::ops::Mul<u64> for Amount<C>
fn cmp(&self, other: &Amount<NonNegative>) -> Ordering { where
self.0.cmp(&other.0) C: Constraint,
} {
} type Output = Result<Amount<C>>;
impl std::ops::Mul<u64> for Amount<NonNegative> {
type Output = Result<Amount<NonNegative>>;
fn mul(self, rhs: u64) -> Self::Output { fn mul(self, rhs: u64) -> Self::Output {
let value = (self.0 as u64) // use i128 for multiplication, so we can handle negative Amounts
.checked_mul(rhs) let value = i128::from(self.0)
.ok_or(Error::MultiplicationOverflow { .checked_mul(i128::from(rhs))
amount: self.0, .expect("multiplying i64 by u64 can't overflow i128");
multiplier: rhs,
})?; value.try_into().map_err(|_| Error::MultiplicationOverflow {
value.try_into() amount: self.0,
multiplier: rhs,
overflowing_result: value,
})
} }
} }
impl std::ops::Mul<Amount<NonNegative>> for u64 { impl<C> std::ops::Mul<Amount<C>> for u64
type Output = Result<Amount<NonNegative>>; where
C: Constraint,
{
type Output = Result<Amount<C>>;
fn mul(self, rhs: Amount<NonNegative>) -> Self::Output { fn mul(self, rhs: Amount<C>) -> Self::Output {
rhs.mul(self) rhs.mul(self)
} }
} }
impl std::ops::Div<u64> for Amount<NonNegative> { impl<C> std::ops::Div<u64> for Amount<C>
type Output = Result<Amount<NonNegative>>; where
C: Constraint,
{
type Output = Result<Amount<C>>;
fn div(self, rhs: u64) -> Self::Output { fn div(self, rhs: u64) -> Self::Output {
let quotient = (self.0 as u64) let quotient = i128::from(self.0)
.checked_div(rhs) .checked_div(i128::from(rhs))
.ok_or(Error::DivideByZero { amount: self.0 })?; .ok_or(Error::DivideByZero { amount: self.0 })?;
Ok(quotient Ok(quotient
.try_into() .try_into()
.expect("division by a positive integer always stays within the constraint")) .expect("division by a positive integer always stays within the constraint"))
@ -320,14 +363,16 @@ impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
where where
C: Constraint, C: Constraint,
{ {
fn sum<I: Iterator<Item = Amount<C>>>(iter: I) -> Self { fn sum<I: Iterator<Item = Amount<C>>>(mut iter: I) -> Self {
let sum = iter let sum = iter.try_fold(Amount::zero(), |acc, amount| acc + amount);
.map(|a| a.0)
.try_fold(0i64, |acc, amount| acc.checked_add(amount));
match sum { match sum {
Some(sum) => Amount::try_from(sum), Ok(sum) => Ok(sum),
None => Err(Error::SumOverflow), Err(Error::Constraint { value, .. }) => Err(Error::SumOverflow {
partial_sum: value,
remaining_items: iter.count(),
}),
Err(unexpected_error) => unreachable!("unexpected Add error: {:?}", unexpected_error),
} }
} }
} }
@ -352,27 +397,37 @@ where
} }
} }
#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)] #[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq, Eq)]
#[allow(missing_docs)] #[allow(missing_docs)]
/// Errors that can be returned when validating `Amount`s /// Errors that can be returned when validating `Amount`s
pub enum Error { pub enum Error {
/// input {value} is outside of valid range for zatoshi Amount, valid_range={range:?} /// input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}
Contains { Constraint {
range: RangeInclusive<i64>,
value: i64, value: i64,
range: RangeInclusive<i64>,
}, },
/// u64 {value} could not be converted to an i64 Amount
/// {value} could not be converted to an i64 Amount
Convert { Convert {
value: u64, value: i128,
source: std::num::TryFromIntError, source: std::num::TryFromIntError,
}, },
/// i64 overflow when multiplying i64 non-negative amount {amount} by u64 {multiplier}
MultiplicationOverflow { amount: i64, multiplier: u64 }, /// i64 overflow when multiplying i64 amount {amount} by u64 {multiplier}, overflowing result {overflowing_result}
MultiplicationOverflow {
amount: i64,
multiplier: u64,
overflowing_result: i128,
},
/// cannot divide amount {amount} by zero /// cannot divide amount {amount} by zero
DivideByZero { amount: i64 }, DivideByZero { amount: i64 },
/// i64 overflow when summing i64 amounts /// i64 overflow when summing i64 amounts, partial_sum: {partial_sum}, remaining items: {remaining_items}
SumOverflow, SumOverflow {
partial_sum: i64,
remaining_items: usize,
},
} }
/// Marker type for `Amount` that allows negative values. /// Marker type for `Amount` that allows negative values.
@ -427,7 +482,7 @@ pub trait Constraint {
let range = Self::valid_range(); let range = Self::valid_range();
if !range.contains(&value) { if !range.contains(&value) {
Err(Error::Contains { range, value }) Err(Error::Constraint { value, range })
} else { } else {
Ok(value) Ok(value)
} }
@ -778,9 +833,9 @@ mod test {
assert_eq!(sum_ref, sum_value); assert_eq!(sum_ref, sum_value);
assert_eq!( assert_eq!(
sum_ref, sum_ref,
Err(Error::Contains { Err(Error::SumOverflow {
range: -MAX_MONEY..=MAX_MONEY, partial_sum: integer_sum,
value: integer_sum, remaining_items: 0
}) })
); );
@ -795,9 +850,9 @@ mod test {
assert_eq!(sum_ref, sum_value); assert_eq!(sum_ref, sum_value);
assert_eq!( assert_eq!(
sum_ref, sum_ref,
Err(Error::Contains { Err(Error::SumOverflow {
range: -MAX_MONEY..=MAX_MONEY, partial_sum: integer_sum,
value: integer_sum, remaining_items: 0
}) })
); );
@ -813,7 +868,13 @@ mod test {
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>(); let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
assert_eq!(sum_ref, sum_value); assert_eq!(sum_ref, sum_value);
assert_eq!(sum_ref, Err(Error::SumOverflow)); assert_eq!(
sum_ref,
Err(Error::SumOverflow {
partial_sum: 4200000000000000,
remaining_items: 4391
})
);
// below min of i64 overflow // below min of i64 overflow
let times: usize = (i64::MAX / MAX_MONEY) let times: usize = (i64::MAX / MAX_MONEY)
@ -827,7 +888,13 @@ mod test {
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>(); let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
assert_eq!(sum_ref, sum_value); assert_eq!(sum_ref, sum_value);
assert_eq!(sum_ref, Err(Error::SumOverflow)); assert_eq!(
sum_ref,
Err(Error::SumOverflow {
partial_sum: -4200000000000000,
remaining_items: 4391
})
);
Ok(()) Ok(())
} }

View File

@ -1,6 +1,6 @@
//! A type that can hold the four types of Zcash value pools. //! A type that can hold the four types of Zcash value pools.
use crate::amount::{Amount, Constraint, Error, NonNegative}; use crate::amount::{Amount, Constraint, Error, NegativeAllowed, NonNegative};
use std::convert::TryInto; use std::convert::TryInto;
@ -10,8 +10,10 @@ mod arbitrary;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
use ValueBalanceError::*;
/// An amount spread between different Zcash pools. /// An amount spread between different Zcash pools.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ValueBalance<C> { pub struct ValueBalance<C> {
transparent: Amount<C>, transparent: Amount<C>,
sprout: Amount<C>, sprout: Amount<C>,
@ -131,6 +133,20 @@ where
} }
} }
/// Convert this value balance to a different ValueBalance type,
/// if it satisfies the new constraint
pub fn constrain<C2>(self) -> Result<ValueBalance<C2>, ValueBalanceError>
where
C2: Constraint,
{
Ok(ValueBalance::<C2> {
transparent: self.transparent.constrain().map_err(Transparent)?,
sprout: self.sprout.constrain().map_err(Sprout)?,
sapling: self.sapling.constrain().map_err(Sapling)?,
orchard: self.orchard.constrain().map_err(Orchard)?,
})
}
/// To byte array /// To byte array
pub fn to_bytes(self) -> [u8; 32] { pub fn to_bytes(self) -> [u8; 32] {
let transparent = self.transparent.to_bytes(); let transparent = self.transparent.to_bytes();
@ -151,22 +167,29 @@ where
bytes[0..8] bytes[0..8]
.try_into() .try_into()
.expect("Extracting the first quarter of a [u8; 32] should always succeed"), .expect("Extracting the first quarter of a [u8; 32] should always succeed"),
)?; )
.map_err(Transparent)?;
let sprout = Amount::from_bytes( let sprout = Amount::from_bytes(
bytes[8..16] bytes[8..16]
.try_into() .try_into()
.expect("Extracting the second quarter of a [u8; 32] should always succeed"), .expect("Extracting the second quarter of a [u8; 32] should always succeed"),
)?; )
.map_err(Sprout)?;
let sapling = Amount::from_bytes( let sapling = Amount::from_bytes(
bytes[16..24] bytes[16..24]
.try_into() .try_into()
.expect("Extracting the third quarter of a [u8; 32] should always succeed"), .expect("Extracting the third quarter of a [u8; 32] should always succeed"),
)?; )
.map_err(Sapling)?;
let orchard = Amount::from_bytes( let orchard = Amount::from_bytes(
bytes[24..32] bytes[24..32]
.try_into() .try_into()
.expect("Extracting the last quarter of a [u8; 32] should always succeed"), .expect("Extracting the last quarter of a [u8; 32] should always succeed"),
)?; )
.map_err(Orchard)?;
Ok(ValueBalance { Ok(ValueBalance {
transparent, transparent,
@ -177,12 +200,20 @@ where
} }
} }
#[derive(thiserror::Error, Debug, Clone, PartialEq)] #[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq, Eq)]
/// Errors that can be returned when validating a [`ValueBalance`]. /// Errors that can be returned when validating a [`ValueBalance`]
pub enum ValueBalanceError { pub enum ValueBalanceError {
#[error("value balance contains invalid amounts")] /// transparent amount error {0}
/// Any error related to [`Amount`]s inside the [`ValueBalance`] Transparent(Error),
AmountError(#[from] Error),
/// sprout amount error {0}
Sprout(Error),
/// sapling amount error {0}
Sapling(Error),
/// orchard amount error {0}
Orchard(Error),
} }
impl<C> std::ops::Add for ValueBalance<C> impl<C> std::ops::Add for ValueBalance<C>
@ -192,10 +223,10 @@ where
type Output = Result<ValueBalance<C>, ValueBalanceError>; type Output = Result<ValueBalance<C>, ValueBalanceError>;
fn add(self, rhs: ValueBalance<C>) -> Self::Output { fn add(self, rhs: ValueBalance<C>) -> Self::Output {
Ok(ValueBalance::<C> { Ok(ValueBalance::<C> {
transparent: (self.transparent + rhs.transparent)?, transparent: (self.transparent + rhs.transparent).map_err(Transparent)?,
sprout: (self.sprout + rhs.sprout)?, sprout: (self.sprout + rhs.sprout).map_err(Sprout)?,
sapling: (self.sapling + rhs.sapling)?, sapling: (self.sapling + rhs.sapling).map_err(Sapling)?,
orchard: (self.orchard + rhs.orchard)?, orchard: (self.orchard + rhs.orchard).map_err(Orchard)?,
}) })
} }
} }
@ -216,10 +247,10 @@ where
type Output = Result<ValueBalance<C>, ValueBalanceError>; type Output = Result<ValueBalance<C>, ValueBalanceError>;
fn sub(self, rhs: ValueBalance<C>) -> Self::Output { fn sub(self, rhs: ValueBalance<C>) -> Self::Output {
Ok(ValueBalance::<C> { Ok(ValueBalance::<C> {
transparent: (self.transparent - rhs.transparent)?, transparent: (self.transparent - rhs.transparent).map_err(Transparent)?,
sprout: (self.sprout - rhs.sprout)?, sprout: (self.sprout - rhs.sprout).map_err(Sprout)?,
sapling: (self.sapling - rhs.sapling)?, sapling: (self.sapling - rhs.sapling).map_err(Sapling)?,
orchard: (self.orchard - rhs.orchard)?, orchard: (self.orchard - rhs.orchard).map_err(Orchard)?,
}) })
} }
} }
@ -239,12 +270,7 @@ where
{ {
fn sum<I: Iterator<Item = ValueBalance<C>>>(mut iter: I) -> Self { fn sum<I: Iterator<Item = ValueBalance<C>>>(mut iter: I) -> Self {
iter.try_fold(ValueBalance::zero(), |acc, value_balance| { iter.try_fold(ValueBalance::zero(), |acc, value_balance| {
Ok(ValueBalance { acc + value_balance
transparent: (acc.transparent + value_balance.transparent)?,
sprout: (acc.sprout + value_balance.sprout)?,
sapling: (acc.sapling + value_balance.sapling)?,
orchard: (acc.orchard + value_balance.orchard)?,
})
}) })
} }
} }
@ -257,3 +283,19 @@ where
iter.copied().sum() iter.copied().sum()
} }
} }
impl<C> std::ops::Neg for ValueBalance<C>
where
C: Constraint,
{
type Output = ValueBalance<NegativeAllowed>;
fn neg(self) -> Self::Output {
ValueBalance::<NegativeAllowed> {
transparent: self.transparent.neg(),
sprout: self.sprout.neg(),
sapling: self.sapling.neg(),
orchard: self.orchard.neg(),
}
}
}

View File

@ -29,7 +29,11 @@ proptest! {
), ),
_ => prop_assert!( _ => prop_assert!(
matches!( matches!(
value_balance1 + value_balance2, Err(ValueBalanceError::AmountError(_)) value_balance1 + value_balance2,
Err(ValueBalanceError::Transparent(_)
| ValueBalanceError::Sprout(_)
| ValueBalanceError::Sapling(_)
| ValueBalanceError::Orchard(_))
) )
), ),
} }
@ -58,7 +62,11 @@ proptest! {
), ),
_ => prop_assert!( _ => prop_assert!(
matches!( matches!(
value_balance1 - value_balance2, Err(ValueBalanceError::AmountError(_)) value_balance1 - value_balance2,
Err(ValueBalanceError::Transparent(_)
| ValueBalanceError::Sprout(_)
| ValueBalanceError::Sapling(_)
| ValueBalanceError::Orchard(_))
) )
), ),
} }
@ -88,7 +96,12 @@ proptest! {
orchard, orchard,
}) })
), ),
_ => prop_assert!(matches!(collection.iter().sum(), Err(ValueBalanceError::AmountError(_)))) _ => prop_assert!(matches!(collection.iter().sum(),
Err(ValueBalanceError::Transparent(_)
| ValueBalanceError::Sprout(_)
| ValueBalanceError::Sapling(_)
| ValueBalanceError::Orchard(_))
))
} }
} }