@@ -162,21 +162,24 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
162
162
( r0, m0)
163
163
}
164
164
165
- /// Returns $\sum_{i = 0}^{n - 1} \lfloor \frac{a \times i + b}{m} \rfloor$.
165
+ /// Returns
166
+ ///
167
+ /// $$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor.$$
168
+ ///
169
+ /// It returns the answer in $\bmod 2^{\mathrm{64}}$, if overflowed.
166
170
///
167
171
/// # Constraints
168
172
///
169
- /// - $0 \leq n \leq 10^9$
170
- /// - $1 \leq m \leq 10^9$
171
- /// - $0 \leq a, b \leq m$
173
+ /// - $0 \leq n \lt 2^{32}$
174
+ /// - $1 \leq m \lt 2^{32}$
172
175
///
173
176
/// # Panics
174
177
///
175
178
/// Panics if the above constraints are not satisfied and overflow or division by zero occurred.
176
179
///
177
180
/// # Complexity
178
181
///
179
- /// - $O(\log(n + m + a + b) )$
182
+ /// - $O(\log{(m+a)} )$
180
183
///
181
184
/// # Example
182
185
///
@@ -185,25 +188,25 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
185
188
///
186
189
/// assert_eq!(math::floor_sum(6, 5, 4, 3), 13);
187
190
/// ```
188
- pub fn floor_sum ( n : i64 , m : i64 , mut a : i64 , mut b : i64 ) -> i64 {
189
- let mut ans = 0 ;
190
- if a >= m {
191
- ans += ( n - 1 ) * n * ( a / m) / 2 ;
192
- a %= m;
193
- }
194
- if b >= m {
195
- ans += n * ( b / m) ;
196
- b %= m;
191
+ #[ allow( clippy:: many_single_char_names) ]
192
+ pub fn floor_sum ( n : i64 , m : i64 , a : i64 , b : i64 ) -> i64 {
193
+ use std:: num:: Wrapping as W ;
194
+ assert ! ( ( 0 ..1i64 << 32 ) . contains( & n) ) ;
195
+ assert ! ( ( 1 ..1i64 << 32 ) . contains( & m) ) ;
196
+ let mut ans = W ( 0_u64 ) ;
197
+ let ( wn, wm, mut wa, mut wb) = ( W ( n as u64 ) , W ( m as u64 ) , W ( a as u64 ) , W ( b as u64 ) ) ;
198
+ if a < 0 {
199
+ let a2 = W ( internal_math:: safe_mod ( a, m) as u64 ) ;
200
+ ans -= wn * ( wn - W ( 1 ) ) / W ( 2 ) * ( ( a2 - wa) / wm) ;
201
+ wa = a2;
197
202
}
198
-
199
- let y_max = ( a * n + b) / m;
200
- let x_max = y_max * m - b;
201
- if y_max == 0 {
202
- return ans;
203
+ if b < 0 {
204
+ let b2 = W ( internal_math:: safe_mod ( b, m) as u64 ) ;
205
+ ans -= wn * ( ( b2 - wb) / wm) ;
206
+ wb = b2;
203
207
}
204
- ans += ( n - ( x_max + a - 1 ) / a) * y_max;
205
- ans += floor_sum ( y_max, a, m, ( a - x_max % a) % a) ;
206
- ans
208
+ let ret = ans + internal_math:: floor_sum_unsigned ( wn, wm, wa, wb) ;
209
+ ret. 0 as i64
207
210
}
208
211
209
212
#[ cfg( test) ]
@@ -306,5 +309,24 @@ mod tests {
306
309
499_999_999_500_000_000
307
310
) ;
308
311
assert_eq ! ( floor_sum( 332955 , 5590132 , 2231 , 999423 ) , 22014575 ) ;
312
+ for n in 0 ..20 {
313
+ for m in 1 ..20 {
314
+ for a in -20 ..20 {
315
+ for b in -20 ..20 {
316
+ assert_eq ! ( floor_sum( n, m, a, b) , floor_sum_naive( n, m, a, b) ) ;
317
+ }
318
+ }
319
+ }
320
+ }
321
+ }
322
+
323
+ #[ allow( clippy:: many_single_char_names) ]
324
+ fn floor_sum_naive ( n : i64 , m : i64 , a : i64 , b : i64 ) -> i64 {
325
+ let mut ans = 0 ;
326
+ for i in 0 ..n {
327
+ let z = a * i + b;
328
+ ans += ( z - internal_math:: safe_mod ( z, m) ) / m;
329
+ }
330
+ ans
309
331
}
310
332
}
0 commit comments