Skip to content

Commit 4c6f87d

Browse files
Add support for exact division (#1256)
This adds `div_exact` and `div_exact_vartime` methods for `Uint`, `BoxedUint`, and `Limb`. There is currently no trait support. For cases where it applies, exact division can be much faster than div/rem. In my tests this is about 40% faster for a 4096-bit `BoxedUint`. The algorithm is similar to `DivideByWord` in Modern Computer Arithmetic (https://arxiv.org/pdf/1004.4710) but I don't have a reference for the adjustments to support a multi-word divisor. --------- Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
1 parent 9bda2fd commit 4c6f87d

15 files changed

Lines changed: 508 additions & 52 deletions

File tree

benches/boxed_uint.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,42 @@ fn bench_division(c: &mut Criterion) {
232232
);
233233
});
234234

235+
group.bench_function("boxed_div_exact", |b| {
236+
b.iter_batched(
237+
|| {
238+
(
239+
BoxedUint::max(UINT_BITS),
240+
NonZero::new(BoxedUint::random_bits_with_precision(
241+
&mut rng,
242+
UINT_BITS / 2,
243+
UINT_BITS,
244+
))
245+
.unwrap(),
246+
)
247+
},
248+
|(x, y)| black_box(x.div_exact(&y)),
249+
BatchSize::SmallInput,
250+
);
251+
});
252+
253+
group.bench_function("boxed_div_exact_vartime", |b| {
254+
b.iter_batched(
255+
|| {
256+
(
257+
BoxedUint::max(UINT_BITS),
258+
NonZero::new(BoxedUint::random_bits_with_precision(
259+
&mut rng,
260+
UINT_BITS / 2,
261+
UINT_BITS,
262+
))
263+
.unwrap(),
264+
)
265+
},
266+
|(x, y)| black_box(x.div_exact_vartime(&y)),
267+
BatchSize::SmallInput,
268+
);
269+
});
270+
235271
group.finish();
236272
}
237273

benches/uint.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,44 @@ fn bench_division(c: &mut Criterion) {
393393
);
394394
});
395395

396+
group.bench_function("div exact, U256/U128", |b| {
397+
b.iter_batched(
398+
|| {
399+
(
400+
U256::random_from_rng(&mut rng),
401+
NonZero::<U128>::random_from_rng(&mut rng),
402+
)
403+
},
404+
|(x, y)| x.div_exact(&y),
405+
BatchSize::SmallInput,
406+
);
407+
});
408+
409+
group.bench_function("div exact, U256/U128 (in U256)", |b| {
410+
b.iter_batched(
411+
|| {
412+
let x = U256::random_from_rng(&mut rng);
413+
let y_half = U128::random_from_rng(&mut rng);
414+
let y: U256 = (y_half, U128::ZERO).into();
415+
(x, NonZero::new(y).unwrap())
416+
},
417+
|(x, y)| x.div_exact(&y),
418+
BatchSize::SmallInput,
419+
);
420+
});
421+
422+
group.bench_function("div exact, U256/U128 (in U512)", |b| {
423+
b.iter_batched(
424+
|| {
425+
let x = U256::random_from_rng(&mut rng);
426+
let y: U512 = U128::random_from_rng(&mut rng).resize();
427+
(x, NonZero::new(y).unwrap())
428+
},
429+
|(x, y)| x.div_exact(&y),
430+
BatchSize::SmallInput,
431+
);
432+
});
433+
396434
group.bench_function("div/rem_vartime, U256/U128, full size", |b| {
397435
b.iter_batched(
398436
|| {
@@ -405,6 +443,18 @@ fn bench_division(c: &mut Criterion) {
405443
);
406444
});
407445

446+
group.bench_function("div exact vartime, U256/U128, full size", |b| {
447+
b.iter_batched(
448+
|| {
449+
let x = U256::random_from_rng(&mut rng);
450+
let y = U256::from((NonZero::<U128>::random_from_rng(&mut rng).get(), U128::ZERO));
451+
(x, NonZero::new(y).unwrap())
452+
},
453+
|(x, y)| x.div_exact_vartime(&y),
454+
BatchSize::SmallInput,
455+
);
456+
});
457+
408458
group.bench_function("rem, U256/U128", |b| {
409459
b.iter_batched(
410460
|| {
@@ -510,6 +560,19 @@ fn bench_division(c: &mut Criterion) {
510560
);
511561
});
512562

563+
group.bench_function("div exact vartime, U256/Limb, full size", |b| {
564+
b.iter_batched(
565+
|| {
566+
let x = U256::random_from_rng(&mut rng);
567+
let y_small = Limb::random_from_rng(&mut rng);
568+
let y = U256::from_word(y_small.0);
569+
(x, NonZero::new(y).unwrap())
570+
},
571+
|(x, y)| x.div_exact_vartime(&y),
572+
BatchSize::SmallInput,
573+
);
574+
});
575+
513576
group.bench_function("div/rem, U256/Limb, single limb", |b| {
514577
b.iter_batched(
515578
|| {

src/limb.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ mod div;
1313
mod encoding;
1414
mod from;
1515
mod gcd;
16+
mod invert_mod;
1617
mod mul;
1718
mod neg;
1819
mod shl;

src/limb/div.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,30 @@ impl Limb {
4242
let rem = self.div_rem(rhs_nz).1;
4343
CtOption::new(rem, is_nz)
4444
}
45+
46+
/// Exactly divides `self` by `rhs`, returning `CtOption::none()` if `self` is not divisible by `rhs`.
47+
#[must_use]
48+
pub const fn div_exact(&self, rhs: NonZero<Limb>) -> CtOption<Self> {
49+
let mut quo = *self;
50+
let mut div = rhs.get_copy();
51+
let exact = UintRef::new_mut(slice::from_mut(&mut quo))
52+
.div_exact(UintRef::new_mut(slice::from_mut(&mut div)));
53+
CtOption::new(quo, exact)
54+
}
55+
56+
/// Exactly divides `self` by `rhs`, returning `CtOption::none()` if `self` is not divisible by `rhs`.
57+
///
58+
/// This is variable-time only with respect to `rhs`.
59+
///
60+
/// When used with a fixed `rhs`, this function is constant-time with respect to `self`.
61+
#[must_use]
62+
pub const fn div_exact_vartime(&self, rhs: NonZero<Limb>) -> CtOption<Self> {
63+
let mut quo = *self;
64+
let mut div = rhs.get_copy();
65+
let exact = UintRef::new_mut(slice::from_mut(&mut quo))
66+
.div_exact_vartime(UintRef::new_mut(slice::from_mut(&mut div)));
67+
CtOption::new(quo, exact)
68+
}
4569
}
4670

4771
impl CheckedDiv for Limb {
@@ -285,6 +309,17 @@ mod tests {
285309
let n = Limb::from_u32(0xffff_ffff);
286310
let d = NonZero::new(Limb::from_u32(0xfffe)).expect("ensured non-zero");
287311
assert_eq!(n.div_rem(d), (Limb::from_u32(0x10002), Limb::from_u32(0x3)));
312+
313+
assert_eq!(n.div_exact(d).into_option(), None);
314+
assert_eq!(n.div_exact_vartime(d).into_option(), None);
315+
316+
let d = NonZero::new(Limb::from_u32(0xffff)).expect("ensured non-zero");
317+
assert_eq!(n.div_rem(d), (Limb::from_u32(0x10001), Limb::from_u32(0)));
318+
assert_eq!(n.div_exact(d).into_option(), Some(Limb::from_u32(0x10001)));
319+
assert_eq!(
320+
n.div_exact_vartime(d).into_option(),
321+
Some(Limb::from_u32(0x10001))
322+
);
288323
}
289324

290325
#[test]

src/limb/invert_mod.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
use super::Limb;
2+
use crate::{Odd, primitives};
3+
4+
impl Odd<Limb> {
5+
/// Returns the multiplicative inverse of the argument modulo 2^N, where
6+
/// 2^N is the capacity of a [`Limb`].
7+
pub(crate) const fn multiplicative_inverse(self) -> Limb {
8+
cpubits::cpubits! {
9+
32 => {
10+
Limb(primitives::u32_invert_odd(self.as_ref().0))
11+
}
12+
64 => {
13+
Limb(primitives::u64_invert_odd(self.as_ref().0))
14+
}
15+
}
16+
}
17+
}

src/primitives.rs

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,62 @@ pub(crate) const fn u32_bits(n: u32) -> u32 {
9494
u32::BITS - n.leading_zeros()
9595
}
9696

97+
/// Return a `Choice` representing whether `a < b`.
98+
#[allow(clippy::cast_possible_truncation)]
99+
#[cfg(target_pointer_width = "32")]
100+
#[inline]
101+
pub(crate) const fn usize_lt(a: usize, b: usize) -> Choice {
102+
Choice::from_u32_lt(a as u32, b as u32)
103+
}
104+
105+
/// Return a `Choice` representing whether `a < b`.
106+
#[allow(clippy::cast_possible_truncation)]
107+
#[cfg(target_pointer_width = "64")]
108+
#[inline]
109+
pub(crate) const fn usize_lt(a: usize, b: usize) -> Choice {
110+
Choice::from_u64_lt(a as u64, b as u64)
111+
}
112+
113+
cpubits::cpubits! {
114+
32 => {
115+
/// Returns the multiplicative inverse of the argument modulo 2^32.
116+
///
117+
/// For correct results, the input `value` must be odd.
118+
#[must_use]
119+
pub(crate) const fn u32_invert_odd(value: u32) -> u32 {
120+
debug_assert!(value & 1 == 1, "value must be odd");
121+
let x = value.wrapping_mul(3) ^ 2;
122+
let y = 1u32.wrapping_sub(x.wrapping_mul(value));
123+
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
124+
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
125+
x.wrapping_mul(y.wrapping_add(1))
126+
}
127+
}
128+
}
129+
130+
/// Returns the multiplicative inverse of the argument modulo 2^64. The implementation is based
131+
/// on Hurchalla's method for computing the multiplicative inverse modulo a power of two, and
132+
/// is essentially an optimized Newton iteration.
133+
///
134+
/// For correct results, the input `value` must be odd.
135+
///
136+
/// For better understanding the implementation, the following paper is recommended:
137+
/// J. Hurchalla, "An Improved Integer Multiplicative Inverse (modulo 2^w)",
138+
/// <https://arxiv.org/abs/2204.04342>
139+
#[must_use]
140+
pub(crate) const fn u64_invert_odd(value: u64) -> u64 {
141+
debug_assert!(value & 1 == 1, "value must be odd");
142+
let x = value.wrapping_mul(3) ^ 2;
143+
let y = 1u64.wrapping_sub(x.wrapping_mul(value));
144+
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
145+
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
146+
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
147+
x.wrapping_mul(y.wrapping_add(1))
148+
}
149+
97150
#[cfg(test)]
98151
mod tests {
99-
use super::{u32_max, u32_min, u32_rem};
152+
use super::{u32_max, u32_min, u32_rem, usize_lt};
100153
use crate::Word;
101154

102155
#[test]
@@ -133,4 +186,25 @@ mod tests {
133186
assert_eq!(u32_rem(7, 5), 2);
134187
assert_eq!(u32_rem(101, 5), 1);
135188
}
189+
190+
#[test]
191+
fn test_usize_const_lt() {
192+
assert!(usize_lt(0, 5).to_bool_vartime());
193+
assert!(!usize_lt(7, 0).to_bool_vartime());
194+
assert!(!usize_lt(7, 5).to_bool_vartime());
195+
assert!(!usize_lt(7, 7).to_bool_vartime());
196+
}
197+
198+
cpubits::cpubits! {
199+
32 => {
200+
#[test]
201+
fn test_u32_invert_odd() {
202+
use super::u32_invert_odd;
203+
204+
assert_eq!(u32_invert_odd(1), 1);
205+
assert_eq!(u32_invert_odd(5).wrapping_mul(5), 1);
206+
assert_eq!(u32_invert_odd(u32::MAX).wrapping_mul(u32::MAX), 1);
207+
}
208+
}
209+
}
136210
}

src/uint/boxed/div.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ impl BoxedUint {
6363
rem
6464
}
6565

66+
/// Exactly divides `self` by `rhs`, returning `CtOption::none()` if `self` is not divisible by `rhs`.
67+
#[must_use]
68+
pub fn div_exact<Rhs: ToUnsigned + ?Sized>(&self, rhs: &NonZero<Rhs>) -> CtOption<Self> {
69+
let mut quo = self.clone();
70+
let mut div = rhs.to_unsigned().get();
71+
let exact = quo.as_mut_uint_ref().div_exact(div.as_mut_uint_ref());
72+
CtOption::new(quo, exact)
73+
}
74+
6675
/// Computes self / rhs, returns the quotient and remainder.
6776
///
6877
/// Variable-time with respect to `rhs`
@@ -121,6 +130,20 @@ impl BoxedUint {
121130
rem
122131
}
123132

133+
/// Exactly divides `self` by `rhs`, returning `CtOption::none()` if `self` is not divisible by `rhs`.
134+
#[must_use]
135+
pub fn div_exact_vartime<Rhs: ToUnsigned + ?Sized>(
136+
&self,
137+
rhs: &NonZero<Rhs>,
138+
) -> CtOption<Self> {
139+
let mut quo = self.clone();
140+
let mut div = rhs.to_unsigned().get();
141+
let exact = quo
142+
.as_mut_uint_ref()
143+
.div_exact_vartime(div.as_mut_uint_ref());
144+
CtOption::new(quo, exact)
145+
}
146+
124147
/// Wrapped division is just normal division i.e. `self` / `rhs`
125148
/// There’s no way wrapping could ever happen.
126149
///

src/uint/boxed/lcm.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@ impl Lcm for BoxedUint {
88
fn lcm(&self, rhs: &Self) -> Self {
99
let (lhs_nz, _) = self.to_nz_or_one();
1010
let gcd_nz = lhs_nz.gcd(rhs);
11-
self.wrapping_div(&gcd_nz).concatenating_mul(rhs)
11+
self.div_exact(&gcd_nz)
12+
.expect("invalid gcd")
13+
.concatenating_mul(rhs)
1214
}
1315

1416
fn lcm_vartime(&self, rhs: &Self) -> Self {
1517
let (Some(lhs_nz), false) = (self.as_nz_vartime(), rhs.is_zero_vartime()) else {
1618
return BoxedUint::zero_with_precision(self.bits_precision() + rhs.bits_precision());
1719
};
1820
let gcd_nz = lhs_nz.gcd_vartime(rhs);
19-
self.wrapping_div_vartime(&gcd_nz).concatenating_mul(rhs)
21+
self.div_exact_vartime(&gcd_nz)
22+
.expect("invalid gcd")
23+
.concatenating_mul(rhs)
2024
}
2125
}
2226

0 commit comments

Comments
 (0)