Skip to content

Commit bbb36fe

Browse files
committed
Auto merge of #105851 - dtolnay:peekmutleak, r=Mark-Simulacrum
Leak amplification for peek_mut() to ensure BinaryHeap's invariant is always met In the libs-api team's discussion around #104210, some of the team had hesitations around exposing malformed BinaryHeaps of an element type whose Ord and Drop impls are trusted, and which does not contain interior mutability. For example in the context of this kind of code: ```rust use std::collections::BinaryHeap; use std::ops::Range; use std::slice; fn main() { let slice = &mut ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']; let cut_points = BinaryHeap::from(vec![4, 2, 7]); println!("{:?}", chop(slice, cut_points)); } // This is a souped up slice::split_at_mut to split in arbitrary many places. // // usize's Ord impl is trusted, so 1 single bounds check guarantees all those // output slices are non-overlapping and in-bounds fn chop<T>(slice: &mut [T], mut cut_points: BinaryHeap<usize>) -> Vec<&mut [T]> { let mut vec = Vec::with_capacity(cut_points.len() + 1); let max = match cut_points.pop() { Some(max) => max, None => { vec.push(slice); return vec; } }; assert!(max <= slice.len()); let len = slice.len(); let ptr: *mut T = slice.as_mut_ptr(); let get_unchecked_mut = unsafe { |range: Range<usize>| &mut *slice::from_raw_parts_mut(ptr.add(range.start), range.len()) }; vec.push(get_unchecked_mut(max..len)); let mut end = max; while let Some(start) = cut_points.pop() { vec.push(get_unchecked_mut(start..end)); end = start; } vec.push(get_unchecked_mut(0..end)); vec } ``` ```console [['7', '8', '9'], ['4', '5', '6'], ['2', '3'], ['0', '1']] ``` In the current BinaryHeap API, `peek_mut()` is the only thing that makes the above function unsound. ```rust let slice = &mut ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']; let mut cut_points = BinaryHeap::from(vec![4, 2, 7]); { let mut max = cut_points.peek_mut().unwrap(); *max = 0; std::mem::forget(max); } println!("{:?}", chop(slice, cut_points)); ``` ```console [['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], [], ['2', '3'], ['0', '1']] ``` Or worse: ```rust let slice = &mut ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']; let mut cut_points = BinaryHeap::from(vec![100, 100]); { let mut max = cut_points.peek_mut().unwrap(); *max = 0; std::mem::forget(max); } println!("{:?}", chop(slice, cut_points)); ``` ```console [['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], [], ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '\u{1}', '\0', '?', '翾', '?', '翾', '\0', '\0', '?', '翾', '?', '翾', '?', '啿', '?', '啿', '?', '啿', '?', '啿', '?', '啿', '?', '翾', '\0', '\0', '񤬐', '啿', '\u{5}', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\u{8}', '\0', '`@',` '\0', '\u{1}', '\0', '?', '翾', '?', '翾', '?', '翾', ' thread 'main' panicked at 'index out of bounds: the len is 33 but the index is 33', library/core/src/unicode/unicode_data.rs:319:9 note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace ``` --- This PR makes `peek_mut()` use leak amplification (https://doc.rust-lang.org/1.66.0/nomicon/leaking.html#drain) to preserve the heap's invariant even in the situation that `PeekMut` gets leaked. I'll also follow up in the tracking issue of unstable `drain_sorted()` (#59278) and `retain()` (#71503).
2 parents 754f6d4 + 2350170 commit bbb36fe

File tree

3 files changed

+75
-10
lines changed

3 files changed

+75
-10
lines changed

library/alloc/src/collections/binary_heap/mod.rs

+55-10
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@
146146
use core::fmt;
147147
use core::iter::{FromIterator, FusedIterator, InPlaceIterable, SourceIter, TrustedLen};
148148
use core::mem::{self, swap, ManuallyDrop};
149+
use core::num::NonZeroUsize;
149150
use core::ops::{Deref, DerefMut};
150151
use core::ptr;
151152

@@ -165,12 +166,20 @@ mod tests;
165166
/// It is a logic error for an item to be modified in such a way that the
166167
/// item's ordering relative to any other item, as determined by the [`Ord`]
167168
/// trait, changes while it is in the heap. This is normally only possible
168-
/// through [`Cell`], [`RefCell`], global state, I/O, or unsafe code. The
169+
/// through interior mutability, global state, I/O, or unsafe code. The
169170
/// behavior resulting from such a logic error is not specified, but will
170171
/// be encapsulated to the `BinaryHeap` that observed the logic error and not
171172
/// result in undefined behavior. This could include panics, incorrect results,
172173
/// aborts, memory leaks, and non-termination.
173174
///
175+
/// As long as no elements change their relative order while being in the heap
176+
/// as described above, the API of `BinaryHeap` guarantees that the heap
177+
/// invariant remains intact i.e. its methods all behave as documented. For
178+
/// example if a method is documented as iterating in sorted order, that's
179+
/// guaranteed to work as long as elements in the heap have not changed order,
180+
/// even in the presence of closures getting unwinded out of, iterators getting
181+
/// leaked, and similar foolishness.
182+
///
174183
/// # Examples
175184
///
176185
/// ```
@@ -279,7 +288,9 @@ pub struct BinaryHeap<T> {
279288
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
280289
pub struct PeekMut<'a, T: 'a + Ord> {
281290
heap: &'a mut BinaryHeap<T>,
282-
sift: bool,
291+
// If a set_len + sift_down are required, this is Some. If a &mut T has not
292+
// yet been exposed to peek_mut()'s caller, it's None.
293+
original_len: Option<NonZeroUsize>,
283294
}
284295

285296
#[stable(feature = "collection_debug", since = "1.17.0")]
@@ -292,7 +303,14 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
292303
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
293304
impl<T: Ord> Drop for PeekMut<'_, T> {
294305
fn drop(&mut self) {
295-
if self.sift {
306+
if let Some(original_len) = self.original_len {
307+
// SAFETY: That's how many elements were in the Vec at the time of
308+
// the PeekMut::deref_mut call, and therefore also at the time of
309+
// the BinaryHeap::peek_mut call. Since the PeekMut did not end up
310+
// getting leaked, we are now undoing the leak amplification that
311+
// the DerefMut prepared for.
312+
unsafe { self.heap.data.set_len(original_len.get()) };
313+
296314
// SAFETY: PeekMut is only instantiated for non-empty heaps.
297315
unsafe { self.heap.sift_down(0) };
298316
}
@@ -313,7 +331,26 @@ impl<T: Ord> Deref for PeekMut<'_, T> {
313331
impl<T: Ord> DerefMut for PeekMut<'_, T> {
314332
fn deref_mut(&mut self) -> &mut T {
315333
debug_assert!(!self.heap.is_empty());
316-
self.sift = true;
334+
335+
let len = self.heap.len();
336+
if len > 1 {
337+
// Here we preemptively leak all the rest of the underlying vector
338+
// after the currently max element. If the caller mutates the &mut T
339+
// we're about to give them, and then leaks the PeekMut, all these
340+
// elements will remain leaked. If they don't leak the PeekMut, then
341+
// either Drop or PeekMut::pop will un-leak the vector elements.
342+
//
343+
// This is technique is described throughout several other places in
344+
// the standard library as "leak amplification".
345+
unsafe {
346+
// SAFETY: len > 1 so len != 0.
347+
self.original_len = Some(NonZeroUsize::new_unchecked(len));
348+
// SAFETY: len > 1 so all this does for now is leak elements,
349+
// which is safe.
350+
self.heap.data.set_len(1);
351+
}
352+
}
353+
317354
// SAFE: PeekMut is only instantiated for non-empty heaps
318355
unsafe { self.heap.data.get_unchecked_mut(0) }
319356
}
@@ -323,9 +360,16 @@ impl<'a, T: Ord> PeekMut<'a, T> {
323360
/// Removes the peeked value from the heap and returns it.
324361
#[stable(feature = "binary_heap_peek_mut_pop", since = "1.18.0")]
325362
pub fn pop(mut this: PeekMut<'a, T>) -> T {
326-
let value = this.heap.pop().unwrap();
327-
this.sift = false;
328-
value
363+
if let Some(original_len) = this.original_len.take() {
364+
// SAFETY: This is how many elements were in the Vec at the time of
365+
// the BinaryHeap::peek_mut call.
366+
unsafe { this.heap.data.set_len(original_len.get()) };
367+
368+
// Unlike in Drop, here we don't also need to do a sift_down even if
369+
// the caller could've mutated the element. It is removed from the
370+
// heap on the next line and pop() is not sensitive to its value.
371+
}
372+
this.heap.pop().unwrap()
329373
}
330374
}
331375

@@ -398,8 +442,9 @@ impl<T: Ord> BinaryHeap<T> {
398442
/// Returns a mutable reference to the greatest item in the binary heap, or
399443
/// `None` if it is empty.
400444
///
401-
/// Note: If the `PeekMut` value is leaked, the heap may be in an
402-
/// inconsistent state.
445+
/// Note: If the `PeekMut` value is leaked, some heap elements might get
446+
/// leaked along with it, but the remaining elements will remain a valid
447+
/// heap.
403448
///
404449
/// # Examples
405450
///
@@ -426,7 +471,7 @@ impl<T: Ord> BinaryHeap<T> {
426471
/// otherwise it's *O*(1).
427472
#[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
428473
pub fn peek_mut(&mut self) -> Option<PeekMut<'_, T>> {
429-
if self.is_empty() { None } else { Some(PeekMut { heap: self, sift: false }) }
474+
if self.is_empty() { None } else { Some(PeekMut { heap: self, original_len: None }) }
430475
}
431476

432477
/// Removes the greatest item from the binary heap and returns it, or `None` if it

library/alloc/src/collections/binary_heap/tests.rs

+19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use super::*;
22
use crate::boxed::Box;
33
use crate::testing::crash_test::{CrashTestDummy, Panic};
4+
use core::mem;
45
use std::iter::TrustedLen;
56
use std::panic::{catch_unwind, AssertUnwindSafe};
67

@@ -146,6 +147,24 @@ fn test_peek_mut() {
146147
assert_eq!(heap.peek(), Some(&9));
147148
}
148149

150+
#[test]
151+
fn test_peek_mut_leek() {
152+
let data = vec![4, 2, 7];
153+
let mut heap = BinaryHeap::from(data);
154+
let mut max = heap.peek_mut().unwrap();
155+
*max = -1;
156+
157+
// The PeekMut object's Drop impl would have been responsible for moving the
158+
// -1 out of the max position of the BinaryHeap, but we don't run it.
159+
mem::forget(max);
160+
161+
// Absent some mitigation like leak amplification, the -1 would incorrectly
162+
// end up in the last position of the returned Vec, with the rest of the
163+
// heap's original contents in front of it in sorted order.
164+
let sorted_vec = heap.into_sorted_vec();
165+
assert!(sorted_vec.is_sorted(), "{:?}", sorted_vec);
166+
}
167+
149168
#[test]
150169
fn test_peek_mut_pop() {
151170
let data = vec![2, 4, 6, 2, 1, 8, 10, 3, 5, 7, 0, 9, 1];

library/alloc/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
#![feature(hasher_prefixfree_extras)]
126126
#![feature(inline_const)]
127127
#![feature(inplace_iteration)]
128+
#![cfg_attr(test, feature(is_sorted))]
128129
#![feature(iter_advance_by)]
129130
#![feature(iter_next_chunk)]
130131
#![feature(iter_repeat_n)]

0 commit comments

Comments
 (0)