Add extra arithmetic operations and error details to Amount and ValueBalance (#2577)

* Make Amount arithmetic more generic

To modify generated amounts, we need some extra operations on `Amount`.

We also need to extend existing operations to both `NonNegative` and
`NegativeAllowed` amounts.

* Add a constrain method for ValueBalance

* Derive Eq for ValueBalance

* impl Neg for ValueBalance

* Make some Amount arithmetic expectations explicit

* Explain why we use i128 for multiplication

And expand the overflow error details.

* Expand Amount::sum error details

* Make amount::Error field order consistent

* Rename an amount::Error variant to Constraint, so it's clearer

* Add specific pool variants to ValueBalanceError
This commit is contained in:
teor 2021-08-09 23:13:27 +10:00 committed by GitHub
parent 910f0ff5dc
commit fc68240fa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 218 additions and 96 deletions

View File

@ -79,7 +79,10 @@ where
type Output = Result<Amount<C>>;
fn add(self, rhs: Amount<C>) -> Self::Output {
let value = self.0 + rhs.0;
let value = self
.0
.checked_add(rhs.0)
.expect("adding two constrained Amounts is always within an i64");
value.try_into()
}
}
@ -125,7 +128,10 @@ where
type Output = Result<Amount<C>>;
fn sub(self, rhs: Amount<C>) -> Self::Output {
let value = self.0 - rhs.0;
let value = self
.0
.checked_sub(rhs.0)
.expect("subtracting two constrained Amounts is always within an i64");
value.try_into()
}
}
@ -172,7 +178,7 @@ impl<C> From<Amount<C>> for i64 {
impl From<Amount<NonNegative>> for u64 {
fn from(amount: Amount<NonNegative>) -> Self {
amount.0 as _
amount.0.try_into().expect("non-negative i64 fits in u64")
}
}
@ -180,9 +186,14 @@ impl<C> From<Amount<C>> for jubjub::Fr {
fn from(a: Amount<C>) -> jubjub::Fr {
// TODO: this isn't constant time -- does that matter?
if a.0 < 0 {
jubjub::Fr::from(a.0.abs() as u64).neg()
let abs_amount = i128::from(a.0)
.checked_abs()
.expect("absolute i64 fits in i128");
let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
jubjub::Fr::from(abs_amount).neg()
} else {
jubjub::Fr::from(a.0 as u64)
jubjub::Fr::from(u64::try_from(a.0).expect("non-negative i64 fits in u64"))
}
}
}
@ -191,13 +202,31 @@ impl<C> From<Amount<C>> for halo2::pasta::pallas::Scalar {
fn from(a: Amount<C>) -> halo2::pasta::pallas::Scalar {
// TODO: this isn't constant time -- does that matter?
if a.0 < 0 {
halo2::pasta::pallas::Scalar::from(a.0.abs() as u64).neg()
let abs_amount = i128::from(a.0)
.checked_abs()
.expect("absolute i64 fits in i128");
let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
halo2::pasta::pallas::Scalar::from(abs_amount).neg()
} else {
halo2::pasta::pallas::Scalar::from(a.0 as u64)
halo2::pasta::pallas::Scalar::from(
u64::try_from(a.0).expect("non-negative i64 fits in u64"),
)
}
}
}
impl<C> TryFrom<i32> for Amount<C>
where
C: Constraint,
{
type Error = Error;
fn try_from(value: i32) -> Result<Self, Self::Error> {
C::validate(value.into()).map(|v| Self(v, PhantomData))
}
}
impl<C> TryFrom<i64> for Amount<C>
where
C: Constraint,
@ -209,17 +238,6 @@ where
}
}
impl<C> TryFrom<i32> for Amount<C>
where
C: Constraint,
{
type Error = Error;
fn try_from(value: i32) -> Result<Self, Self::Error> {
C::validate(value as _).map(|v| Self(v, PhantomData))
}
}
impl<C> TryFrom<u64> for Amount<C>
where
C: Constraint,
@ -227,6 +245,25 @@ where
type Error = Error;
fn try_from(value: u64) -> Result<Self, Self::Error> {
let value = value.try_into().map_err(|source| Error::Convert {
value: value.into(),
source,
})?;
C::validate(value).map(|v| Self(v, PhantomData))
}
}
/// Conversion from `i128` to `Amount`.
///
/// Used to handle the result of multiplying negative `Amount`s by `u64`.
impl<C> TryFrom<i128> for Amount<C>
where
C: Constraint,
{
type Error = Error;
fn try_from(value: i128) -> Result<Self, Self::Error> {
let value = value
.try_into()
.map_err(|source| Error::Convert { value, source })?;
@ -260,8 +297,7 @@ impl<C> PartialEq<Amount<C>> for i64 {
}
}
impl Eq for Amount<NegativeAllowed> {}
impl Eq for Amount<NonNegative> {}
impl<C> Eq for Amount<C> {}
impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
fn partial_cmp(&self, other: &Amount<C2>) -> Option<Ordering> {
@ -269,47 +305,54 @@ impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
}
}
impl Ord for Amount<NegativeAllowed> {
fn cmp(&self, other: &Amount<NegativeAllowed>) -> Ordering {
impl<C> Ord for Amount<C> {
fn cmp(&self, other: &Amount<C>) -> Ordering {
self.0.cmp(&other.0)
}
}
impl Ord for Amount<NonNegative> {
fn cmp(&self, other: &Amount<NonNegative>) -> Ordering {
self.0.cmp(&other.0)
}
}
impl std::ops::Mul<u64> for Amount<NonNegative> {
type Output = Result<Amount<NonNegative>>;
impl<C> std::ops::Mul<u64> for Amount<C>
where
C: Constraint,
{
type Output = Result<Amount<C>>;
fn mul(self, rhs: u64) -> Self::Output {
let value = (self.0 as u64)
.checked_mul(rhs)
.ok_or(Error::MultiplicationOverflow {
amount: self.0,
multiplier: rhs,
})?;
value.try_into()
// use i128 for multiplication, so we can handle negative Amounts
let value = i128::from(self.0)
.checked_mul(i128::from(rhs))
.expect("multiplying i64 by u64 can't overflow i128");
value.try_into().map_err(|_| Error::MultiplicationOverflow {
amount: self.0,
multiplier: rhs,
overflowing_result: value,
})
}
}
impl std::ops::Mul<Amount<NonNegative>> for u64 {
type Output = Result<Amount<NonNegative>>;
impl<C> std::ops::Mul<Amount<C>> for u64
where
C: Constraint,
{
type Output = Result<Amount<C>>;
fn mul(self, rhs: Amount<NonNegative>) -> Self::Output {
fn mul(self, rhs: Amount<C>) -> Self::Output {
rhs.mul(self)
}
}
impl std::ops::Div<u64> for Amount<NonNegative> {
type Output = Result<Amount<NonNegative>>;
impl<C> std::ops::Div<u64> for Amount<C>
where
C: Constraint,
{
type Output = Result<Amount<C>>;
fn div(self, rhs: u64) -> Self::Output {
let quotient = (self.0 as u64)
.checked_div(rhs)
let quotient = i128::from(self.0)
.checked_div(i128::from(rhs))
.ok_or(Error::DivideByZero { amount: self.0 })?;
Ok(quotient
.try_into()
.expect("division by a positive integer always stays within the constraint"))
@ -320,14 +363,16 @@ impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
where
C: Constraint,
{
fn sum<I: Iterator<Item = Amount<C>>>(iter: I) -> Self {
let sum = iter
.map(|a| a.0)
.try_fold(0i64, |acc, amount| acc.checked_add(amount));
fn sum<I: Iterator<Item = Amount<C>>>(mut iter: I) -> Self {
let sum = iter.try_fold(Amount::zero(), |acc, amount| acc + amount);
match sum {
Some(sum) => Amount::try_from(sum),
None => Err(Error::SumOverflow),
Ok(sum) => Ok(sum),
Err(Error::Constraint { value, .. }) => Err(Error::SumOverflow {
partial_sum: value,
remaining_items: iter.count(),
}),
Err(unexpected_error) => unreachable!("unexpected Add error: {:?}", unexpected_error),
}
}
}
@ -352,27 +397,37 @@ where
}
}
#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)]
#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq, Eq)]
#[allow(missing_docs)]
/// Errors that can be returned when validating `Amount`s
pub enum Error {
/// input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}
Contains {
range: RangeInclusive<i64>,
Constraint {
value: i64,
range: RangeInclusive<i64>,
},
/// u64 {value} could not be converted to an i64 Amount
/// {value} could not be converted to an i64 Amount
Convert {
value: u64,
value: i128,
source: std::num::TryFromIntError,
},
/// i64 overflow when multiplying i64 non-negative amount {amount} by u64 {multiplier}
MultiplicationOverflow { amount: i64, multiplier: u64 },
/// i64 overflow when multiplying i64 amount {amount} by u64 {multiplier}, overflowing result {overflowing_result}
MultiplicationOverflow {
amount: i64,
multiplier: u64,
overflowing_result: i128,
},
/// cannot divide amount {amount} by zero
DivideByZero { amount: i64 },
/// i64 overflow when summing i64 amounts
SumOverflow,
/// i64 overflow when summing i64 amounts, partial_sum: {partial_sum}, remaining items: {remaining_items}
SumOverflow {
partial_sum: i64,
remaining_items: usize,
},
}
/// Marker type for `Amount` that allows negative values.
@ -427,7 +482,7 @@ pub trait Constraint {
let range = Self::valid_range();
if !range.contains(&value) {
Err(Error::Contains { range, value })
Err(Error::Constraint { value, range })
} else {
Ok(value)
}
@ -778,9 +833,9 @@ mod test {
assert_eq!(sum_ref, sum_value);
assert_eq!(
sum_ref,
Err(Error::Contains {
range: -MAX_MONEY..=MAX_MONEY,
value: integer_sum,
Err(Error::SumOverflow {
partial_sum: integer_sum,
remaining_items: 0
})
);
@ -795,9 +850,9 @@ mod test {
assert_eq!(sum_ref, sum_value);
assert_eq!(
sum_ref,
Err(Error::Contains {
range: -MAX_MONEY..=MAX_MONEY,
value: integer_sum,
Err(Error::SumOverflow {
partial_sum: integer_sum,
remaining_items: 0
})
);
@ -813,7 +868,13 @@ mod test {
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
assert_eq!(sum_ref, sum_value);
assert_eq!(sum_ref, Err(Error::SumOverflow));
assert_eq!(
sum_ref,
Err(Error::SumOverflow {
partial_sum: 4200000000000000,
remaining_items: 4391
})
);
// below min of i64 overflow
let times: usize = (i64::MAX / MAX_MONEY)
@ -827,7 +888,13 @@ mod test {
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
assert_eq!(sum_ref, sum_value);
assert_eq!(sum_ref, Err(Error::SumOverflow));
assert_eq!(
sum_ref,
Err(Error::SumOverflow {
partial_sum: -4200000000000000,
remaining_items: 4391
})
);
Ok(())
}

View File

@ -1,6 +1,6 @@
//! A type that can hold the four types of Zcash value pools.
use crate::amount::{Amount, Constraint, Error, NonNegative};
use crate::amount::{Amount, Constraint, Error, NegativeAllowed, NonNegative};
use std::convert::TryInto;
@ -10,8 +10,10 @@ mod arbitrary;
#[cfg(test)]
mod tests;
use ValueBalanceError::*;
/// An amount spread between different Zcash pools.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ValueBalance<C> {
transparent: Amount<C>,
sprout: Amount<C>,
@ -131,6 +133,20 @@ where
}
}
/// Convert this value balance to a different ValueBalance type,
/// if it satisfies the new constraint
pub fn constrain<C2>(self) -> Result<ValueBalance<C2>, ValueBalanceError>
where
C2: Constraint,
{
Ok(ValueBalance::<C2> {
transparent: self.transparent.constrain().map_err(Transparent)?,
sprout: self.sprout.constrain().map_err(Sprout)?,
sapling: self.sapling.constrain().map_err(Sapling)?,
orchard: self.orchard.constrain().map_err(Orchard)?,
})
}
/// To byte array
pub fn to_bytes(self) -> [u8; 32] {
let transparent = self.transparent.to_bytes();
@ -151,22 +167,29 @@ where
bytes[0..8]
.try_into()
.expect("Extracting the first quarter of a [u8; 32] should always succeed"),
)?;
)
.map_err(Transparent)?;
let sprout = Amount::from_bytes(
bytes[8..16]
.try_into()
.expect("Extracting the second quarter of a [u8; 32] should always succeed"),
)?;
)
.map_err(Sprout)?;
let sapling = Amount::from_bytes(
bytes[16..24]
.try_into()
.expect("Extracting the third quarter of a [u8; 32] should always succeed"),
)?;
)
.map_err(Sapling)?;
let orchard = Amount::from_bytes(
bytes[24..32]
.try_into()
.expect("Extracting the last quarter of a [u8; 32] should always succeed"),
)?;
)
.map_err(Orchard)?;
Ok(ValueBalance {
transparent,
@ -177,12 +200,20 @@ where
}
}
#[derive(thiserror::Error, Debug, Clone, PartialEq)]
/// Errors that can be returned when validating a [`ValueBalance`].
#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq, Eq)]
/// Errors that can be returned when validating a [`ValueBalance`]
pub enum ValueBalanceError {
#[error("value balance contains invalid amounts")]
/// Any error related to [`Amount`]s inside the [`ValueBalance`]
AmountError(#[from] Error),
/// transparent amount error {0}
Transparent(Error),
/// sprout amount error {0}
Sprout(Error),
/// sapling amount error {0}
Sapling(Error),
/// orchard amount error {0}
Orchard(Error),
}
impl<C> std::ops::Add for ValueBalance<C>
@ -192,10 +223,10 @@ where
type Output = Result<ValueBalance<C>, ValueBalanceError>;
fn add(self, rhs: ValueBalance<C>) -> Self::Output {
Ok(ValueBalance::<C> {
transparent: (self.transparent + rhs.transparent)?,
sprout: (self.sprout + rhs.sprout)?,
sapling: (self.sapling + rhs.sapling)?,
orchard: (self.orchard + rhs.orchard)?,
transparent: (self.transparent + rhs.transparent).map_err(Transparent)?,
sprout: (self.sprout + rhs.sprout).map_err(Sprout)?,
sapling: (self.sapling + rhs.sapling).map_err(Sapling)?,
orchard: (self.orchard + rhs.orchard).map_err(Orchard)?,
})
}
}
@ -216,10 +247,10 @@ where
type Output = Result<ValueBalance<C>, ValueBalanceError>;
fn sub(self, rhs: ValueBalance<C>) -> Self::Output {
Ok(ValueBalance::<C> {
transparent: (self.transparent - rhs.transparent)?,
sprout: (self.sprout - rhs.sprout)?,
sapling: (self.sapling - rhs.sapling)?,
orchard: (self.orchard - rhs.orchard)?,
transparent: (self.transparent - rhs.transparent).map_err(Transparent)?,
sprout: (self.sprout - rhs.sprout).map_err(Sprout)?,
sapling: (self.sapling - rhs.sapling).map_err(Sapling)?,
orchard: (self.orchard - rhs.orchard).map_err(Orchard)?,
})
}
}
@ -239,12 +270,7 @@ where
{
fn sum<I: Iterator<Item = ValueBalance<C>>>(mut iter: I) -> Self {
iter.try_fold(ValueBalance::zero(), |acc, value_balance| {
Ok(ValueBalance {
transparent: (acc.transparent + value_balance.transparent)?,
sprout: (acc.sprout + value_balance.sprout)?,
sapling: (acc.sapling + value_balance.sapling)?,
orchard: (acc.orchard + value_balance.orchard)?,
})
acc + value_balance
})
}
}
@ -257,3 +283,19 @@ where
iter.copied().sum()
}
}
impl<C> std::ops::Neg for ValueBalance<C>
where
C: Constraint,
{
type Output = ValueBalance<NegativeAllowed>;
fn neg(self) -> Self::Output {
ValueBalance::<NegativeAllowed> {
transparent: self.transparent.neg(),
sprout: self.sprout.neg(),
sapling: self.sapling.neg(),
orchard: self.orchard.neg(),
}
}
}

View File

@ -29,7 +29,11 @@ proptest! {
),
_ => prop_assert!(
matches!(
value_balance1 + value_balance2, Err(ValueBalanceError::AmountError(_))
value_balance1 + value_balance2,
Err(ValueBalanceError::Transparent(_)
| ValueBalanceError::Sprout(_)
| ValueBalanceError::Sapling(_)
| ValueBalanceError::Orchard(_))
)
),
}
@ -58,7 +62,11 @@ proptest! {
),
_ => prop_assert!(
matches!(
value_balance1 - value_balance2, Err(ValueBalanceError::AmountError(_))
value_balance1 - value_balance2,
Err(ValueBalanceError::Transparent(_)
| ValueBalanceError::Sprout(_)
| ValueBalanceError::Sapling(_)
| ValueBalanceError::Orchard(_))
)
),
}
@ -88,7 +96,12 @@ proptest! {
orchard,
})
),
_ => prop_assert!(matches!(collection.iter().sum(), Err(ValueBalanceError::AmountError(_))))
_ => prop_assert!(matches!(collection.iter().sum(),
Err(ValueBalanceError::Transparent(_)
| ValueBalanceError::Sprout(_)
| ValueBalanceError::Sapling(_)
| ValueBalanceError::Orchard(_))
))
}
}