Skip to content

Commit 0cb2499

Browse files
committed
internal_math: mul_mod fix for 2^31<m<2^32
1 parent 0b92413 commit 0cb2499

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/internal_math.rs

+15-9
Original file line numberDiff line numberDiff line change
@@ -59,28 +59,26 @@ impl Barrett {
5959
///
6060
/// * `a` `0 <= a < m`
6161
/// * `b` `0 <= b < m`
62-
/// * `m` `1 <= m <= 2^31`
63-
/// * `im` = ceil(2^64 / `m`)
62+
/// * `m` `1 <= m < 2^32`
63+
/// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1
6464
#[allow(clippy::many_single_char_names)]
6565
pub(crate) fn mul_mod(a: u32, b: u32, m: u32, im: u64) -> u32 {
6666
// [1] m = 1
6767
// a = b = im = 0, so okay
6868

6969
// [2] m >= 2
70-
// im = ceil(2^64 / m)
70+
// im = ceil(2^64 / m) = floor((2^64 - 1) / m) + 1
7171
// -> im * m = 2^64 + r (0 <= r < m)
7272
// let z = a*b = c*m + d (0 <= c, d < m)
7373
// a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
7474
// c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
7575
// ((ab * im) >> 64) == c or c + 1
76-
let mut z = a as u64;
77-
z *= b as u64;
76+
let z = (a as u64) * (b as u64);
7877
let x = (((z as u128) * (im as u128)) >> 64) as u64;
79-
let mut v = z.wrapping_sub(x.wrapping_mul(m as u64)) as u32;
80-
if m <= v {
81-
v = v.wrapping_add(m);
78+
match z.overflowing_sub(x.wrapping_mul(m as u64)) {
79+
(v, true) => (v as u32).wrapping_add(m),
80+
(v, false) => v as u32,
8281
}
83-
v
8482
}
8583

8684
/// # Parameters
@@ -280,6 +278,14 @@ mod tests {
280278
let b = Barrett::new(2147483647);
281279
assert_eq!(b.umod(), 2147483647);
282280
assert_eq!(b.mul(1073741824, 2147483645), 2147483646);
281+
282+
// test `2^31 < self._m < 2^32` case.
283+
let b = Barrett::new(3221225471);
284+
assert_eq!(b.umod(), 3221225471);
285+
assert_eq!(b.mul(3188445886, 2844002853), 1840468257);
286+
assert_eq!(b.mul(2834869488, 2779159607), 2084027561);
287+
assert_eq!(b.mul(3032263594, 3039996727), 2130247251);
288+
assert_eq!(b.mul(3029175553, 3140869278), 1892378237);
283289
}
284290

285291
#[test]

0 commit comments

Comments
 (0)