Skip to content

Commit aa3beb7

Browse files
Allow RangeBounds for fenwick and segtrees
1 parent d2b35ac commit aa3beb7

File tree

6 files changed

+127
-22
lines changed

6 files changed

+127
-22
lines changed

examples/library-checker-static-range-sum.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ fn main() {
2020
fenwick.add(i, a);
2121
}
2222
for (l, r) in lrs {
23-
println!("{}", fenwick.sum(l, r));
23+
println!("{}", fenwick.sum(l..r));
2424
}
2525
}

examples/practice2_j_segment_tree.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ fn main() {
2020
segtree.set(x, v);
2121
}
2222
2 => {
23-
let l = input.next().unwrap().parse().unwrap();
23+
let l: usize = input.next().unwrap().parse().unwrap();
2424
let r: usize = input.next().unwrap().parse().unwrap();
25-
println!("{}", segtree.prod(l, r + 1));
25+
println!("{}", segtree.prod(l..=r));
2626
}
2727
3 => {
2828
let x = input.next().unwrap().parse().unwrap();

examples/practice2_k_range_affine_range_sum.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ fn main() {
5757
segtree.apply_range(l, r, (b, c));
5858
}
5959
1 => {
60-
let l = input.next().unwrap().parse().unwrap();
61-
let r = input.next().unwrap().parse().unwrap();
62-
println!("{}", segtree.prod(l, r).0);
60+
let l: usize = input.next().unwrap().parse().unwrap();
61+
let r: usize = input.next().unwrap().parse().unwrap();
62+
println!("{}", segtree.prod(l..r).0);
6363
}
6464
_ => {}
6565
}

src/fenwicktree.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::ops::{Bound, RangeBounds};
2+
13
// Reference: https://en.wikipedia.org/wiki/Fenwick_tree
24
pub struct FenwickTree<T> {
35
n: usize,
@@ -34,17 +36,29 @@ impl<T: Clone + std::ops::AddAssign<T>> FenwickTree<T> {
3436
}
3537
}
3638
/// Returns data[l] + ... + data[r - 1].
37-
pub fn sum(&self, l: usize, r: usize) -> T
39+
pub fn sum<R>(&self, range: R) -> T
3840
where
3941
T: std::ops::Sub<Output = T>,
42+
R: RangeBounds<usize>,
4043
{
44+
let r = match range.end_bound() {
45+
Bound::Included(r) => r + 1,
46+
Bound::Excluded(r) => *r,
47+
Bound::Unbounded => self.n,
48+
};
49+
let l = match range.start_bound() {
50+
Bound::Included(l) => *l,
51+
Bound::Excluded(l) => l + 1,
52+
Bound::Unbounded => return self.accum(r),
53+
};
4154
self.accum(r) - self.accum(l)
4255
}
4356
}
4457

4558
#[cfg(test)]
4659
mod tests {
4760
use super::*;
61+
use std::ops::Bound::*;
4862

4963
#[test]
5064
fn fenwick_tree_works() {
@@ -53,8 +67,15 @@ mod tests {
5367
for i in 0..5 {
5468
bit.add(i, i as i64 + 1);
5569
}
56-
assert_eq!(bit.sum(0, 5), 15);
57-
assert_eq!(bit.sum(0, 4), 10);
58-
assert_eq!(bit.sum(1, 3), 5);
70+
assert_eq!(bit.sum(0..5), 15);
71+
assert_eq!(bit.sum(0..4), 10);
72+
assert_eq!(bit.sum(1..3), 5);
73+
74+
assert_eq!(bit.sum(..), 15);
75+
assert_eq!(bit.sum(..2), 3);
76+
assert_eq!(bit.sum(..=2), 6);
77+
assert_eq!(bit.sum(1..), 14);
78+
assert_eq!(bit.sum(1..=3), 9);
79+
assert_eq!(bit.sum((Excluded(0), Included(2))), 5);
5980
}
6081
}

src/lazysegtree.rs

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,27 @@ impl<F: MapMonoid> LazySegtree<F> {
7373
self.d[p].clone()
7474
}
7575

76-
pub fn prod(&mut self, mut l: usize, mut r: usize) -> <F::M as Monoid>::S {
76+
pub fn prod<R>(&mut self, range: R) -> <F::M as Monoid>::S
77+
where
78+
R: RangeBounds<usize>,
79+
{
80+
// Trivial optimization
81+
if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded {
82+
return self.all_prod();
83+
}
84+
85+
let mut r = match range.end_bound() {
86+
Bound::Included(r) => r + 1,
87+
Bound::Excluded(r) => *r,
88+
Bound::Unbounded => self.n,
89+
};
90+
let mut l = match range.start_bound() {
91+
Bound::Included(l) => *l,
92+
Bound::Excluded(l) => l + 1,
93+
// TODO: There are another way of optimizing [0..r)
94+
Bound::Unbounded => 0,
95+
};
96+
7797
assert!(l <= r && r <= self.n);
7898
if l == r {
7999
return F::identity_element();
@@ -287,7 +307,10 @@ where
287307
}
288308

289309
// TODO is it useful?
290-
use std::fmt::{Debug, Error, Formatter, Write};
310+
use std::{
311+
fmt::{Debug, Error, Formatter, Write},
312+
ops::{Bound, RangeBounds},
313+
};
291314
impl<F> Debug for LazySegtree<F>
292315
where
293316
F: MapMonoid,
@@ -314,6 +337,8 @@ where
314337

315338
#[cfg(test)]
316339
mod tests {
340+
use std::ops::{Bound::*, RangeBounds};
341+
317342
use crate::{LazySegtree, MapMonoid, Max};
318343

319344
struct MaxAdd;
@@ -373,12 +398,20 @@ mod tests {
373398
for i in 0..n {
374399
assert_eq!(segtree.get(i), base[i]);
375400
}
401+
402+
check(base, segtree, ..);
376403
for i in 0..=n {
404+
check(base, segtree, ..i);
405+
check(base, segtree, i..);
406+
if i < n {
407+
check(base, segtree, ..=i);
408+
}
377409
for j in i..=n {
378-
assert_eq!(
379-
segtree.prod(i, j),
380-
base[i..j].iter().max().copied().unwrap_or(i32::min_value())
381-
);
410+
check(base, segtree, i..j);
411+
if j < n {
412+
check(base, segtree, i..=j);
413+
check(base, segtree, (Excluded(i), Included(j)));
414+
}
382415
}
383416
}
384417
assert_eq!(
@@ -413,4 +446,15 @@ mod tests {
413446
}
414447
}
415448
}
449+
450+
fn check(base: &[i32], segtree: &mut LazySegtree<MaxAdd>, range: impl RangeBounds<usize>) {
451+
let expected = base
452+
.iter()
453+
.enumerate()
454+
.filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i)))
455+
.max()
456+
.copied()
457+
.unwrap_or(i32::min_value());
458+
assert_eq!(segtree.prod(range), expected);
459+
}
416460
}

src/segtree.rs

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::internal_type_traits::{BoundedAbove, BoundedBelow, One, Zero};
33
use std::cmp::{max, min};
44
use std::convert::Infallible;
55
use std::marker::PhantomData;
6-
use std::ops::{Add, Mul};
6+
use std::ops::{Add, Bound, Mul, RangeBounds};
77

88
// TODO Should I split monoid-related traits to another module?
99
pub trait Monoid {
@@ -107,7 +107,27 @@ impl<M: Monoid> Segtree<M> {
107107
self.d[p + self.size].clone()
108108
}
109109

110-
pub fn prod(&self, mut l: usize, mut r: usize) -> M::S {
110+
pub fn prod<R>(&self, range: R) -> M::S
111+
where
112+
R: RangeBounds<usize>,
113+
{
114+
// Trivial optimization
115+
if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded {
116+
return self.all_prod();
117+
}
118+
119+
let mut r = match range.end_bound() {
120+
Bound::Included(r) => r + 1,
121+
Bound::Excluded(r) => *r,
122+
Bound::Unbounded => self.n,
123+
};
124+
let mut l = match range.start_bound() {
125+
Bound::Included(l) => *l,
126+
Bound::Excluded(l) => l + 1,
127+
// TODO: There are another way of optimizing [0..r)
128+
Bound::Unbounded => 0,
129+
};
130+
111131
assert!(l <= r && r <= self.n);
112132
let mut sml = M::identity();
113133
let mut smr = M::identity();
@@ -240,6 +260,7 @@ where
240260
mod tests {
241261
use crate::segtree::Max;
242262
use crate::Segtree;
263+
use std::ops::{Bound::*, RangeBounds};
243264

244265
#[test]
245266
fn test_max_segtree() {
@@ -272,12 +293,20 @@ mod tests {
272293
for i in 0..n {
273294
assert_eq!(segtree.get(i), base[i]);
274295
}
296+
297+
check(base, segtree, ..);
275298
for i in 0..=n {
299+
check(base, segtree, ..i);
300+
check(base, segtree, i..);
301+
if i < n {
302+
check(base, segtree, ..=i);
303+
}
276304
for j in i..=n {
277-
assert_eq!(
278-
segtree.prod(i, j),
279-
base[i..j].iter().max().copied().unwrap_or(i32::min_value())
280-
);
305+
check(base, segtree, i..j);
306+
if j < n {
307+
check(base, segtree, i..=j);
308+
check(base, segtree, (Excluded(i), Included(j)));
309+
}
281310
}
282311
}
283312
assert_eq!(
@@ -312,4 +341,15 @@ mod tests {
312341
}
313342
}
314343
}
344+
345+
fn check(base: &[i32], segtree: &Segtree<Max<i32>>, range: impl RangeBounds<usize>) {
346+
let expected = base
347+
.iter()
348+
.enumerate()
349+
.filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i)))
350+
.max()
351+
.copied()
352+
.unwrap_or(i32::min_value());
353+
assert_eq!(segtree.prod(range), expected);
354+
}
315355
}

0 commit comments

Comments
 (0)