Skip to content

Commit 00756e0

Browse files
bors[bot]nbraudSkiFire13
authored
Merge #473
473: Add helper method for taking the k smallest elements in an iterator r=jswrenn a=nbraud Co-authored-by: nicoo <[email protected]> Co-authored-by: Giacomo Stevanato <[email protected]>
2 parents 130ffd3 + f28ffd0 commit 00756e0

File tree

4 files changed

+147
-1
lines changed

4 files changed

+147
-1
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ either = { version = "1.0", default-features = false }
2828
[dev-dependencies]
2929
rand = "0.7"
3030
criterion = "=0" # TODO how could this work with our minimum supported rust version?
31+
paste = "1.0.0" # Used in test_std to instantiate generic tests
3132

3233
[dev-dependencies.quickcheck]
3334
version = "0.9"

src/k_smallest.rs

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use alloc::collections::BinaryHeap;
2+
use core::cmp::Ord;
3+
4+
/// Collects the `k` smallest items yielded by `iter` into a max-heap.
///
/// The returned heap holds at most `k` items, and fewer only when `iter`
/// yields fewer than `k` items before its first `None`. Because it is a
/// max-heap, its root is the largest of the retained items, and
/// `into_sorted_vec()` returns them in ascending order.
///
/// The input is fused, so consumption stops at the first `None`. Without
/// this, an iterator that resumes after returning `None` would reach the
/// replacement loop with a partially filled (or empty) heap.
pub(crate) fn k_smallest<T: Ord, I: Iterator<Item = T>>(iter: I, k: usize) -> BinaryHeap<T> {
    if k == 0 {
        return BinaryHeap::new();
    }

    // Fuse so that a `None` seen while filling the heap also ends the loop below.
    let mut iter = iter.fuse();

    // Seed the heap with the first `k` items (or all of them, if fewer).
    let mut heap = iter.by_ref().take(k).collect::<BinaryHeap<_>>();

    // Only reached with items left over, i.e. when the heap is full (k >= 1).
    for i in iter {
        debug_assert_eq!(heap.len(), k);
        // Equivalent to heap.push(min(i, heap.pop())) but more efficient.
        // This should be done with a single `.peek_mut().unwrap()` but
        // `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior.
        if *heap.peek().unwrap() > i {
            *heap.peek_mut().unwrap() = i;
        }
    }

    heap
}

src/lib.rs

+39
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ mod group_map;
196196
mod groupbylazy;
197197
mod intersperse;
198198
#[cfg(feature = "use_alloc")]
199+
mod k_smallest;
200+
#[cfg(feature = "use_alloc")]
199201
mod kmerge_impl;
200202
#[cfg(feature = "use_alloc")]
201203
mod lazy_buffer;
@@ -2419,6 +2421,43 @@ pub trait Itertools : Iterator {
24192421
v.into_iter()
24202422
}
24212423

2424+
/// Return the `k` smallest elements of the iterator as a new iterator,
/// sorted in ascending order.
///
/// **Note:** The entire input is consumed, and the result is a new
/// iterator that owns its elements. If the input contains fewer than
/// `k` elements, the result is equivalent to `self.sorted()`.
///
/// This is guaranteed to use `k * sizeof(Self::Item) + O(1)` memory
/// and `O(n log k)` time, with `n` the number of elements in the input.
///
/// The sorted iterator, if directly collected to a `Vec`, is converted
/// without any extra copying or allocation cost.
///
/// **Note:** This is functionally-equivalent to `self.sorted().take(k)`
/// but much more efficient.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_smallest = numbers
///     .into_iter()
///     .k_smallest(5);
///
/// itertools::assert_equal(five_smallest, 0..5);
/// ```
#[cfg(feature = "use_alloc")]
fn k_smallest(self, k: usize) -> VecIntoIter<Self::Item>
    where Self: Sized,
          Self::Item: Ord
{
    // The helper returns a max-heap of the retained elements.
    let smallest = crate::k_smallest::k_smallest(self, k);
    let ascending = smallest.into_sorted_vec();
    ascending.into_iter()
}
2460+
24222461
/// Collect all iterator elements into one of two
24232462
/// partitions. Unlike `Iterator::partition`, each partition may
24242463
/// have a distinct type.

tests/test_std.rs

+87-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
use paste;
12
use permutohedron;
3+
use quickcheck as qc;
4+
use rand::{distributions::{Distribution, Standard}, Rng, SeedableRng, rngs::StdRng};
5+
use rand::{seq::SliceRandom, thread_rng};
6+
use std::{cmp::min, fmt::Debug, marker::PhantomData};
27
use itertools as it;
38
use crate::it::Itertools;
49
use crate::it::ExactlyOneError;
@@ -374,6 +379,88 @@ fn sorted_by() {
374379
it::assert_equal(v, vec![4, 3, 2, 1, 0]);
375380
}
376381

382+
qc::quickcheck! {
    // Property: the k smallest items of a shuffled range n..n+m are
    // exactly n..n+min(k, m), in ascending order.
    fn k_smallest_range(n: u64, m: u16, k: u16) -> () {
        // u16 is used to constrain k and m to 0..2¹⁶,
        // otherwise the test could use too much memory.
        let k = u64::from(k);
        let m = u64::from(m);

        // Build a random permutation of n..n+m (saturating at u64::MAX).
        let mut shuffled: Vec<u64> = (n..n.saturating_add(m)).collect();
        shuffled.shuffle(&mut thread_rng());
        let input = shuffled.into_iter();

        // Taking the k smallest elements must yield n..n+min(k, m).
        let smallest = input.k_smallest(k as usize);
        let expected = n..n.saturating_add(min(k, m));
        it::assert_equal(smallest, expected);
    }
}
402+
403+
#[derive(Clone, Debug)]
404+
struct RandIter<T: 'static + Clone + Send, R: 'static + Clone + Rng + SeedableRng + Send = StdRng> {
405+
idx: usize,
406+
len: usize,
407+
rng: R,
408+
_t: PhantomData<T>
409+
}
410+
411+
impl<T: Clone + Send, R: Clone + Rng + SeedableRng + Send> Iterator for RandIter<T, R>
412+
where Standard: Distribution<T> {
413+
type Item = T;
414+
fn next(&mut self) -> Option<T> {
415+
if self.idx == self.len {
416+
None
417+
} else {
418+
self.idx += 1;
419+
Some(self.rng.gen())
420+
}
421+
}
422+
}
423+
424+
impl<T: Clone + Send, R: Clone + Rng + SeedableRng + Send> qc::Arbitrary for RandIter<T, R> {
425+
fn arbitrary<G: qc::Gen>(g: &mut G) -> Self {
426+
RandIter {
427+
idx: 0,
428+
len: g.size(),
429+
rng: R::seed_from_u64(g.next_u64()),
430+
_t : PhantomData{},
431+
}
432+
}
433+
}
434+
435+
// Check that taking the k smallest is the same as
436+
// sorting then taking the k first elements
437+
fn k_smallest_sort<I>(i: I, k: u16) -> ()
438+
where
439+
I: Iterator + Clone,
440+
I::Item: Ord + Debug,
441+
{
442+
let j = i.clone();
443+
let k = k as usize;
444+
it::assert_equal(
445+
i.k_smallest(k),
446+
j.sorted().take(k)
447+
)
448+
}
449+
450+
macro_rules! generic_test {
451+
($f:ident, $($t:ty),+) => {
452+
$(paste::item! {
453+
qc::quickcheck! {
454+
fn [< $f _ $t >](i: RandIter<$t>, k: u16) -> () {
455+
$f(i, k)
456+
}
457+
}
458+
})+
459+
};
460+
}
461+
462+
generic_test!(k_smallest_sort, u8, u16, u32, u64, i8, i16, i32, i64);
463+
377464
#[test]
378465
fn sorted_by_key() {
379466
let sc = [3, 4, 1, 2].iter().cloned().sorted_by_key(|&x| x);
@@ -407,7 +494,6 @@ fn test_multipeek() {
407494
assert_eq!(mp.next(), Some(5));
408495
assert_eq!(mp.next(), None);
409496
assert_eq!(mp.peek(), None);
410-
411497
}
412498

413499
#[test]

0 commit comments

Comments
 (0)