Skip to content

Commit 42c1384

Browse files
authored
Merge pull request #133 from TonalidadeHidrica/dev/floor_sum
Relax the constraints of `floor_sum` (revised)
2 parents 1954461 + 46caad4 commit 42c1384

File tree

2 files changed

+85
-23
lines changed

2 files changed

+85
-23
lines changed

src/internal_math.rs

+41-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// remove this after dependencies have been added
22
#![allow(dead_code)]
3-
use std::mem::swap;
3+
use std::{mem::swap, num::Wrapping as W};
44

55
/// # Arguments
66
/// * `m` `1 <= m`
@@ -235,6 +235,46 @@ pub(crate) fn primitive_root(m: i32) -> i32 {
235235
// omitted
236236
// template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
237237

238+
/// # Arguments
239+
/// * `n` `n < 2^32`
240+
/// * `m` `1 <= m < 2^32`
241+
///
242+
/// # Returns
243+
/// `sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)`
244+
/* const */
245+
#[allow(clippy::many_single_char_names)]
246+
pub(crate) fn floor_sum_unsigned(
247+
mut n: W<u64>,
248+
mut m: W<u64>,
249+
mut a: W<u64>,
250+
mut b: W<u64>,
251+
) -> W<u64> {
252+
let mut ans = W(0);
253+
loop {
254+
if a >= m {
255+
if n > W(0) {
256+
ans += n * (n - W(1)) / W(2) * (a / m);
257+
}
258+
a %= m;
259+
}
260+
if b >= m {
261+
ans += n * (b / m);
262+
b %= m;
263+
}
264+
265+
let y_max = a * n + b;
266+
if y_max < m {
267+
break;
268+
}
269+
// y_max < m * (n + 1)
270+
// floor(y_max / m) <= n
271+
n = y_max / m;
272+
b = y_max % m;
273+
std::mem::swap(&mut m, &mut a);
274+
}
275+
ans
276+
}
277+
238278
#[cfg(test)]
239279
mod tests {
240280
#![allow(clippy::unreadable_literal)]

src/math.rs

+44-22
Original file line numberDiff line numberDiff line change
@@ -162,21 +162,24 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
162162
(r0, m0)
163163
}
164164

165-
/// Returns $\sum_{i = 0}^{n - 1} \lfloor \frac{a \times i + b}{m} \rfloor$.
165+
/// Returns
166+
///
167+
/// $$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor.$$
168+
///
169+
/// If the result overflows, it is returned modulo $2^{64}$.
166170
///
167171
/// # Constraints
168172
///
169-
/// - $0 \leq n \leq 10^9$
170-
/// - $1 \leq m \leq 10^9$
171-
/// - $0 \leq a, b \leq m$
173+
/// - $0 \leq n \lt 2^{32}$
174+
/// - $1 \leq m \lt 2^{32}$
172175
///
173176
/// # Panics
174177
///
175178
/// Panics if the above constraints are not satisfied.
176179
///
177180
/// # Complexity
178181
///
179-
/// - $O(\log(n + m + a + b))$
182+
/// - $O(\log{(m+a)})$
180183
///
181184
/// # Example
182185
///
@@ -185,25 +188,25 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
185188
///
186189
/// assert_eq!(math::floor_sum(6, 5, 4, 3), 13);
187190
/// ```
188-
pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 {
189-
let mut ans = 0;
190-
if a >= m {
191-
ans += (n - 1) * n * (a / m) / 2;
192-
a %= m;
193-
}
194-
if b >= m {
195-
ans += n * (b / m);
196-
b %= m;
191+
#[allow(clippy::many_single_char_names)]
192+
pub fn floor_sum(n: i64, m: i64, a: i64, b: i64) -> i64 {
193+
use std::num::Wrapping as W;
194+
assert!((0..1i64 << 32).contains(&n));
195+
assert!((1..1i64 << 32).contains(&m));
196+
let mut ans = W(0_u64);
197+
let (wn, wm, mut wa, mut wb) = (W(n as u64), W(m as u64), W(a as u64), W(b as u64));
198+
if a < 0 {
199+
let a2 = W(internal_math::safe_mod(a, m) as u64);
200+
ans -= wn * (wn - W(1)) / W(2) * ((a2 - wa) / wm);
201+
wa = a2;
197202
}
198-
199-
let y_max = (a * n + b) / m;
200-
let x_max = y_max * m - b;
201-
if y_max == 0 {
202-
return ans;
203+
if b < 0 {
204+
let b2 = W(internal_math::safe_mod(b, m) as u64);
205+
ans -= wn * ((b2 - wb) / wm);
206+
wb = b2;
203207
}
204-
ans += (n - (x_max + a - 1) / a) * y_max;
205-
ans += floor_sum(y_max, a, m, (a - x_max % a) % a);
206-
ans
208+
let ret = ans + internal_math::floor_sum_unsigned(wn, wm, wa, wb);
209+
ret.0 as i64
207210
}
208211

209212
#[cfg(test)]
@@ -306,5 +309,24 @@ mod tests {
306309
499_999_999_500_000_000
307310
);
308311
assert_eq!(floor_sum(332955, 5590132, 2231, 999423), 22014575);
312+
for n in 0..20 {
313+
for m in 1..20 {
314+
for a in -20..20 {
315+
for b in -20..20 {
316+
assert_eq!(floor_sum(n, m, a, b), floor_sum_naive(n, m, a, b));
317+
}
318+
}
319+
}
320+
}
321+
}
322+
323+
#[allow(clippy::many_single_char_names)]
324+
fn floor_sum_naive(n: i64, m: i64, a: i64, b: i64) -> i64 {
325+
let mut ans = 0;
326+
for i in 0..n {
327+
let z = a * i + b;
328+
ans += (z - internal_math::safe_mod(z, m)) / m;
329+
}
330+
ans
309331
}
310332
}

0 commit comments

Comments
 (0)