diff --git a/src/cheats.rs b/src/cheats.rs index c54b3a2..618ff69 100644 --- a/src/cheats.rs +++ b/src/cheats.rs @@ -4,6 +4,7 @@ pub trait Cheats { const TWO: Self; const TEN: Self; const SCALING_FACTOR: Self; + const TWO_SCALING_FACTOR: Self; } macro_rules! impl_primitive { @@ -13,6 +14,7 @@ macro_rules! impl_primitive { const TEN: Self = 10; paste! { const SCALING_FACTOR: Self = [<10 $primitive>].pow(D as u32); + const TWO_SCALING_FACTOR: Self = 2 * [<10 $primitive>].pow(D as u32); } } }; diff --git a/src/decimal.rs b/src/decimal.rs index e750ea7..492f3bd 100644 --- a/src/decimal.rs +++ b/src/decimal.rs @@ -1,5 +1,5 @@ use std::cmp::Ordering; -use std::ops::{Add, Div, Mul, Neg, Sub}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use crate::integer::{Integer, SignedInteger}; @@ -14,6 +14,7 @@ where { pub const ZERO: Decimal = Decimal(I::ZERO); pub const ONE: Decimal = Decimal(I::SCALING_FACTOR); + pub const TWO: Decimal = Decimal(I::TWO_SCALING_FACTOR); pub const DECIMALS: u8 = D; pub const SCALING_FACTOR: I = I::SCALING_FACTOR; @@ -101,77 +102,181 @@ where } } +impl AddAssign for Decimal +where + I: Integer, +{ + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = Decimal(self.0.checked_add(&rhs.0).unwrap()); + } +} + +impl SubAssign for Decimal +where + I: Integer, +{ + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = Decimal(self.0.checked_sub(&rhs.0).unwrap()); + } +} + +impl MulAssign for Decimal +where + I: Integer, +{ + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = Decimal(I::full_mul_div(self.0, rhs.0, I::SCALING_FACTOR)); + } +} + +impl DivAssign for Decimal +where + I: Integer, +{ + #[inline] + fn div_assign(&mut self, rhs: Self) { + *self = Decimal(I::full_mul_div(self.0, I::SCALING_FACTOR, rhs.0)); + } +} + #[cfg(test)] mod tests { use std::ops::Shr; + use malachite::num::basic::traits::Zero; + use malachite::Rational; use paste::paste; use proptest::prelude::*; use super::*; - use crate::{Int128_18, Int64_9, Uint128_18, Uint64_9}; + + macro_rules! apply_to_common_variants { + ($macro:ident) => { + $macro!(u8, 1); + $macro!(i8, 1); + $macro!(u16, 2); + $macro!(i16, 2); + $macro!(u32, 5); + $macro!(i32, 5); + $macro!(u64, 9); + $macro!(i64, 9); + $macro!(u128, 18); + $macro!(i128, 18); + }; + } macro_rules! test_basic_ops { - ($variant:ty) => { + ($underlying:ty, $decimals:literal) => { paste! { #[test] - fn [<$variant:lower _add>]() { + fn [<$underlying _ $decimals _add>]() { assert_eq!( - $variant::ONE + $variant::ONE, - Decimal($variant::SCALING_FACTOR * 2), + Decimal::<$underlying, $decimals>::ONE + Decimal::ONE, + Decimal::TWO, ); } #[test] - fn [<$variant:lower _sub>]() { - assert_eq!($variant::ONE - $variant::ONE, Decimal(0)); + fn [<$underlying _ $decimals _sub>]() { + assert_eq!( + Decimal::<$underlying, $decimals>::ONE - Decimal::ONE, + Decimal::ZERO, + ) } #[test] - fn [<$variant:lower _mul>]() { - assert_eq!($variant::ONE * $variant::ONE, $variant::ONE); + fn [<$underlying _ $decimals _mul>]() { + assert_eq!( + Decimal::<$underlying, $decimals>::ONE * Decimal::ONE, + Decimal::ONE, + ); } #[test] - fn [<$variant:lower _div>]() { - assert_eq!($variant::ONE / $variant::ONE, $variant::ONE); + fn [<$underlying _ $decimals _div>]() { + assert_eq!( + Decimal::<$underlying, $decimals>::ONE / Decimal::ONE, + Decimal::ONE, + ); } #[test] - fn [<$variant:lower _mul_min_by_one>]() { - assert_eq!($variant::min() * $variant::ONE, $variant::min()); + fn [<$underlying _ $decimals _mul_min_by_one>]() { + assert_eq!( + Decimal::min() * Decimal::<$underlying, $decimals>::ONE, + Decimal::min() + ); } #[test] - fn [<$variant:lower _div_min_by_one>]() { - assert_eq!($variant::min() / $variant::ONE, $variant::min()); + fn [<$underlying _ $decimals _div_min_by_one>]() { + assert_eq!( + Decimal::min() / Decimal::<$underlying, $decimals>::ONE, + Decimal::min() + ); } #[test] - fn [<$variant:lower _mul_max_by_one>]() { - assert_eq!($variant::max() * $variant::ONE, $variant::max()); + fn [<$underlying _ $decimals _mul_max_by_one>]() { + assert_eq!( + Decimal::max() * Decimal::<$underlying, $decimals>::ONE, + Decimal::max(), + ); } #[test] - fn [<$variant:lower _div_max_by_one>]() { - assert_eq!($variant::max() / $variant::ONE, $variant::max()); + fn [<$underlying _ $decimals _div_max_by_one>]() { + assert_eq!( + Decimal::max() / Decimal::<$underlying, $decimals>::ONE, + Decimal::max(), + ); + } + + #[test] + fn [<$underlying _ $decimals _add_assign>]() { + let mut out = Decimal::<$underlying, $decimals>::ONE; + out += Decimal::ONE; + + assert_eq!(out, Decimal::ONE + Decimal::ONE); + } + + #[test] + fn [<$underlying _ $decimals _sub_assign>]() { + let mut out = Decimal::<$underlying, $decimals>::ONE; + out -= Decimal::<$underlying, $decimals>::ONE; + + assert_eq!(out, Decimal::ZERO); + } + + #[test] + fn [<$underlying _ $decimals _mul_assign>]() { + let mut out = Decimal::<$underlying, $decimals>::ONE; + out *= Decimal::TWO; + + assert_eq!(out, Decimal::ONE + Decimal::ONE); + } + + #[test] + fn [<$underlying _ $decimals _div_assign>]() { + let mut out = Decimal::<$underlying, $decimals>::ONE; + out /= Decimal::TWO; + + assert_eq!(out, Decimal::ONE / Decimal::TWO); } } }; } - test_basic_ops!(Uint64_9); - test_basic_ops!(Uint128_18); - test_basic_ops!(Int64_9); - test_basic_ops!(Int128_18); - macro_rules! fuzz_against_primitive { ($primitive:tt, $decimals:literal) => { paste! { proptest! { /// Addition functions the same as regular unsigned integer addition. #[test] - fn [<$primitive _ $decimals _add>]( + fn []( x in $primitive::MIN..$primitive::MAX, y in $primitive::MIN..$primitive::MAX, ) { @@ -191,7 +296,7 @@ mod tests { /// Subtraction functions the same as regular unsigned integer addition. #[test] - fn [<$primitive _ $decimals _sub>]( + fn []( x in $primitive::MIN..$primitive::MAX, y in $primitive::MIN..$primitive::MAX, ) { @@ -211,7 +316,7 @@ mod tests { /// Multiplication requires the result to be divided by the scaling factor. #[test] - fn [<$primitive _ $decimals _mul>]( + fn []( x in ($primitive::MIN.shr($primitive::BITS / 2)) ..($primitive::MAX.shr($primitive::BITS / 2)), y in ($primitive::MIN.shr($primitive::BITS / 2)) @@ -239,7 +344,7 @@ mod tests { /// Division requires the numerator to first be scaled by the scaling factor. #[test] - fn [<$primitive _ $decimals _div>]( + fn []( x in ($primitive::MIN / $primitive::pow(10, $decimals)) ..($primitive::MAX / $primitive::pow(10, $decimals)), y in ($primitive::MIN / $primitive::pow(10, $decimals)) @@ -269,14 +374,237 @@ mod tests { }; } - fuzz_against_primitive!(u8, 1); - fuzz_against_primitive!(i8, 1); - fuzz_against_primitive!(u16, 2); - fuzz_against_primitive!(i16, 2); - fuzz_against_primitive!(u32, 5); - fuzz_against_primitive!(i32, 5); - fuzz_against_primitive!(u64, 9); - fuzz_against_primitive!(i64, 9); - fuzz_against_primitive!(u128, 18); - fuzz_against_primitive!(i128, 18); + macro_rules! differential_fuzz { + ($underlying:ty, $decimals:literal) => { + paste! { + #[test] + fn []() { + differential_fuzz_add::<$underlying, $decimals>(); + } + + #[test] + fn []() { + differential_fuzz_sub::<$underlying, $decimals>(); + } + + #[test] + fn []() { + differential_fuzz_mul::<$underlying, $decimals>(); + } + + #[test] + fn []() { + differential_fuzz_div::<$underlying, $decimals>(); + } + + #[test] + fn []() { + differential_fuzz_add_assign::<$underlying, $decimals>(); + } + + #[test] + fn []() { + differential_fuzz_sub_assign::<$underlying, $decimals>(); + } + + #[test] + fn []() { + differential_fuzz_mul_assign::<$underlying, $decimals>(); + } + + #[test] + fn []() { + differential_fuzz_div_assign::<$underlying, $decimals>(); + } + } + }; + } + + fn differential_fuzz_add() + where + I: Integer + Arbitrary + std::panic::RefUnwindSafe, + Rational: From>, + { + proptest!(|(a: Decimal, b: Decimal)| { + let out = match std::panic::catch_unwind(|| a + b) { + Ok(out) => out, + Err(_) => return Ok(()), + }; + let reference_out = Rational::from(a) + Rational::from(b); + + assert_eq!(Rational::from(out), reference_out); + }); + } + + fn differential_fuzz_sub() + where + I: Integer + Arbitrary + std::panic::RefUnwindSafe, + Rational: From>, + { + proptest!(|(a: Decimal, b: Decimal)| { + let out = match std::panic::catch_unwind(|| a - b) { + Ok(out) => out, + Err(_) => return Ok(()), + }; + let reference_out = Rational::from(a) - Rational::from(b); + + assert_eq!(Rational::from(out), reference_out); + }); + } + + fn differential_fuzz_mul() + where + I: Integer + Arbitrary + std::panic::RefUnwindSafe + Into, + Rational: From>, + { + proptest!(|(a: Decimal, b: Decimal)| { + let out = match std::panic::catch_unwind(|| a * b) { + Ok(out) => out, + Err(_) => return Ok(()), + }; + let reference_out = Rational::from(a) * Rational::from(b); + + // If the multiplication contains truncation ignore it. + let scaling: malachite::Integer = Decimal::::SCALING_FACTOR.into(); + let divisor = malachite::Integer::from(reference_out.denominator_ref()); + if scaling % divisor != malachite::Integer::ZERO { + // TODO: Can we assert they are within N of each other? + return Ok(()); + } + + assert_eq!(Rational::from(out), reference_out, "{} {a:?} {b:?} {out:?} {reference_out:?}", I::SCALING_FACTOR); + }); + } + + fn differential_fuzz_div() + where + I: Integer + Arbitrary + std::panic::RefUnwindSafe + Into, + Rational: From>, + { + proptest!(|(a: Decimal, b: Decimal)| { + if b == Decimal::ZERO { + return Ok(()); + } + + let out = match std::panic::catch_unwind(|| a / b) { + Ok(out) => out, + Err(_) => return Ok(()), + }; + let reference_out = Rational::from(a) / Rational::from(b); + + // If the division contains truncation ignore it. + let scaling: malachite::Integer = Decimal::::SCALING_FACTOR.into(); + let divisor = malachite::Integer::from(reference_out.denominator_ref()); + if scaling % divisor != malachite::Integer::ZERO { + // TODO: Can we assert they are within N of each other? + return Ok(()); + } + + assert_eq!(Rational::from(out), reference_out); + }); + } + + fn differential_fuzz_add_assign() + where + I: Integer + Arbitrary + std::panic::RefUnwindSafe, + Rational: From>, + { + proptest!(|(a: Decimal, b: Decimal)| { + let out = match std::panic::catch_unwind(|| { + let mut out = a; + out += b; + + out + }) { + Ok(out) => out, + Err(_) => return Ok(()), + }; + let reference_out = Rational::from(a) + Rational::from(b); + + assert_eq!(Rational::from(out), reference_out); + }); + } + + fn differential_fuzz_sub_assign() + where + I: Integer + Arbitrary + std::panic::RefUnwindSafe, + Rational: From>, + { + proptest!(|(a: Decimal, b: Decimal)| { + let out = match std::panic::catch_unwind(|| { + let mut out = a; + out -= b; + + out + }) { + Ok(out) => out, + Err(_) => return Ok(()), + }; + let reference_out = Rational::from(a) - Rational::from(b); + + assert_eq!(Rational::from(out), reference_out); + }); + } + + fn differential_fuzz_mul_assign() + where + I: Integer + Arbitrary + std::panic::RefUnwindSafe + Into, + Rational: From>, + { + proptest!(|(a: Decimal, b: Decimal)| { + let out = match std::panic::catch_unwind(|| { + let mut out = a; + out *= b; + + out + }) { + Ok(out) => out, + Err(_) => return Ok(()), + }; + let reference_out = Rational::from(a) * Rational::from(b); + + // If the multiplication contains truncation ignore it. + let scaling: malachite::Integer = Decimal::::SCALING_FACTOR.into(); + let divisor = malachite::Integer::from(reference_out.denominator_ref()); + if scaling % divisor != malachite::Integer::ZERO { + // TODO: Can we assert they are within N of each other? + return Ok(()); + } + + assert_eq!(Rational::from(out), reference_out); + }); + } + + fn differential_fuzz_div_assign() + where + I: Integer + Arbitrary + std::panic::RefUnwindSafe + Into, + Rational: From>, + { + proptest!(|(a: Decimal, b: Decimal)| { + let out = match std::panic::catch_unwind(|| { + let mut out = a; + out /= b; + + out + }) { + Ok(out) => out, + Err(_) => return Ok(()), + }; + let reference_out = Rational::from(a) / Rational::from(b); + + // If the division contains truncation ignore it. + let scaling: malachite::Integer = Decimal::::SCALING_FACTOR.into(); + let divisor = malachite::Integer::from(reference_out.denominator_ref()); + if scaling % divisor != malachite::Integer::ZERO { + // TODO: Can we assert they are within N of each other? + return Ok(()); + } + + assert_eq!(Rational::from(out), reference_out); + }); + } + + apply_to_common_variants!(test_basic_ops); + apply_to_common_variants!(fuzz_against_primitive); + apply_to_common_variants!(differential_fuzz); } diff --git a/src/foreign_traits/malachite.rs b/src/foreign_traits/malachite.rs new file mode 100644 index 0000000..7fc3883 --- /dev/null +++ b/src/foreign_traits/malachite.rs @@ -0,0 +1,17 @@ +use malachite::num::basic::integers::PrimitiveInt; +use malachite::Rational; + +use crate::{Decimal, Integer}; + +impl From> for Rational +where + I: Integer + PrimitiveInt, + malachite::Integer: From, +{ + fn from(value: Decimal) -> Self { + Rational::from_integers( + malachite::Integer::from(value.0), + malachite::Integer::from(I::SCALING_FACTOR), + ) + } +} diff --git a/src/foreign_traits/mod.rs b/src/foreign_traits/mod.rs index fc14df3..662ce3d 100644 --- a/src/foreign_traits/mod.rs +++ b/src/foreign_traits/mod.rs @@ -1,2 +1,4 @@ #[cfg(test)] -mod arbitrary; +mod malachite; +#[cfg(test)] +mod proptest; diff --git a/src/foreign_traits/arbitrary.rs b/src/foreign_traits/proptest.rs similarity index 100% rename from src/foreign_traits/arbitrary.rs rename to src/foreign_traits/proptest.rs diff --git a/src/full_mul_div.rs b/src/full_mul_div.rs index 527d3d1..61c55c1 100644 --- a/src/full_mul_div.rs +++ b/src/full_mul_div.rs @@ -87,9 +87,6 @@ impl FullMulDiv for i128 { } } -// TODO: Fuzz test u128 & i128 full mul div implementations against a reference -// implementation. - #[cfg(test)] mod tests { use malachite::Integer; @@ -106,7 +103,6 @@ mod tests { // Compute reference value. let reference = Integer::from(a) * Integer::from(b) / Integer::from(div); - println!("{reference}"); // If the output fits in a u128 then ours should match. match u128::try_from(&reference) { @@ -125,7 +121,6 @@ mod tests { // Compute reference value. let reference = Integer::from(a) * Integer::from(b) / Integer::from(div); - println!("{reference}"); // If the output fits in an i128 then ours should match. match i128::try_from(&reference) {