Skip to content

Commit 341e550

Browse files
bors[bot] and Amanieu
committed
Merge #38
38: Fix memory leak in par_drain and par_into_iter r=Amanieu a=Amanieu Found by @cuviper in #37 Co-authored-by: Amanieu d'Antras <[email protected]>
2 parents 500b9bc + 5c9626f commit 341e550

File tree

2 files changed

+47
-18
lines changed

2 files changed

+47
-18
lines changed

src/external_trait_impls/rayon/map.rs

+12
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,12 @@ mod test_par_map {
496496
// By the way, ensure that cloning doesn't screw up the dropping.
497497
drop(hm.clone());
498498

499+
assert_eq!(key.load(Ordering::Relaxed), 100);
500+
assert_eq!(value.load(Ordering::Relaxed), 100);
501+
502+
// Ensure that dropping the iterator does not leak anything.
503+
drop(hm.clone().into_par_iter());
504+
499505
{
500506
assert_eq!(key.load(Ordering::Relaxed), 100);
501507
assert_eq!(value.load(Ordering::Relaxed), 100);
@@ -540,6 +546,12 @@ mod test_par_map {
540546
// By the way, ensure that cloning doesn't screw up the dropping.
541547
drop(hm.clone());
542548

549+
assert_eq!(key.load(Ordering::Relaxed), 100);
550+
assert_eq!(value.load(Ordering::Relaxed), 100);
551+
552+
// Ensure that dropping the drain iterator does not leak anything.
553+
drop(hm.clone().par_drain());
554+
543555
{
544556
assert_eq!(key.load(Ordering::Relaxed), 100);
545557
assert_eq!(value.load(Ordering::Relaxed), 100);

src/external_trait_impls/rayon/raw.rs

+35-18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use alloc::alloc::{dealloc, Layout};
1+
use alloc::alloc::dealloc;
2+
use core::marker::PhantomData;
23
use core::mem;
34
use core::ptr::NonNull;
45
use raw::Bucket;
@@ -58,8 +59,7 @@ impl<T> UnindexedProducer for ParIterProducer<T> {
5859

5960
/// Parallel iterator which consumes a table and returns elements.
6061
pub struct RawIntoParIter<T> {
61-
iter: RawIterRange<T>,
62-
alloc: Option<(NonNull<u8>, Layout)>,
62+
table: RawTable<T>,
6363
}
6464

6565
unsafe impl<T> Send for RawIntoParIter<T> {}
@@ -72,21 +72,25 @@ impl<T: Send> ParallelIterator for RawIntoParIter<T> {
7272
where
7373
C: UnindexedConsumer<Self::Item>,
7474
{
75-
let _guard = guard(self.alloc, |alloc| {
75+
let iter = unsafe { self.table.iter().iter };
76+
let _guard = guard(self.table.into_alloc(), |alloc| {
7677
if let Some((ptr, layout)) = *alloc {
7778
unsafe {
7879
dealloc(ptr.as_ptr(), layout);
7980
}
8081
}
8182
});
82-
let producer = ParDrainProducer { iter: self.iter };
83+
let producer = ParDrainProducer { iter };
8384
plumbing::bridge_unindexed(producer, consumer)
8485
}
8586
}
8687

8788
/// Parallel iterator which consumes elements without freeing the table storage.
8889
pub struct RawParDrain<'a, T> {
89-
table: &'a mut RawTable<T>,
90+
// We don't use a &'a RawTable<T> because we want RawParDrain to be
91+
// covariant over 'a.
92+
table: NonNull<RawTable<T>>,
93+
_marker: PhantomData<&'a RawTable<T>>,
9094
}
9195

9296
unsafe impl<'a, T> Send for RawParDrain<'a, T> {}
@@ -99,13 +103,23 @@ impl<'a, T: Send> ParallelIterator for RawParDrain<'a, T> {
99103
where
100104
C: UnindexedConsumer<Self::Item>,
101105
{
102-
let iter = unsafe { self.table.iter().iter };
103-
let _guard = guard(self.table, |table| table.clear_no_drop());
106+
let _guard = guard(self.table, |table| unsafe {
107+
table.as_mut().clear_no_drop()
108+
});
109+
let iter = unsafe { self.table.as_ref().iter().iter };
110+
mem::forget(self);
104111
let producer = ParDrainProducer { iter };
105112
plumbing::bridge_unindexed(producer, consumer)
106113
}
107114
}
108115

116+
impl<'a, T> Drop for RawParDrain<'a, T> {
117+
fn drop(&mut self) {
118+
// If drive_unindexed is not called then simply clear the table.
119+
unsafe { self.table.as_mut().clear() }
120+
}
121+
}
122+
109123
/// Producer which will consume all elements in the range, even if it is dropped
110124
/// halfway through.
111125
struct ParDrainProducer<T> {
@@ -136,20 +150,23 @@ impl<T: Send> UnindexedProducer for ParDrainProducer<T> {
136150
while let Some(item) = self.iter.next() {
137151
folder = folder.consume(unsafe { item.read() });
138152
if folder.full() {
139-
break;
153+
return folder;
140154
}
141155
}
156+
157+
// If we processed all elements then we don't need to run the drop.
158+
mem::forget(self);
142159
folder
143160
}
144161
}
145162

146163
impl<T> Drop for ParDrainProducer<T> {
147164
#[inline]
148165
fn drop(&mut self) {
149-
unsafe {
150-
// Drop all remaining elements
151-
if mem::needs_drop::<T>() {
152-
while let Some(item) = self.iter.next() {
166+
// Drop all remaining elements
167+
if mem::needs_drop::<T>() {
168+
while let Some(item) = self.iter.next() {
169+
unsafe {
153170
item.drop();
154171
}
155172
}
@@ -169,16 +186,16 @@ impl<T> RawTable<T> {
169186
/// Returns a parallel iterator over the elements in a `RawTable`.
170187
#[inline]
171188
pub fn into_par_iter(self) -> RawIntoParIter<T> {
172-
RawIntoParIter {
173-
iter: unsafe { self.iter().iter },
174-
alloc: self.into_alloc(),
175-
}
189+
RawIntoParIter { table: self }
176190
}
177191

178192
/// Returns a parallel iterator which consumes all elements of a `RawTable`
179193
/// without freeing its memory allocation.
180194
#[inline]
181195
pub fn par_drain(&mut self) -> RawParDrain<T> {
182-
RawParDrain { table: self }
196+
RawParDrain {
197+
table: NonNull::from(self),
198+
_marker: PhantomData,
199+
}
183200
}
184201
}

0 commit comments

Comments (0)