Skip to content

Commit 9b4e612

Browse files
SkiFire13Mark-Simulacrum
authored andcommitted
Document BinaryHeap unsafe functions
1 parent e7c23ab commit 9b4e612

File tree

1 file changed

+113
-49
lines changed

1 file changed

+113
-49
lines changed

library/alloc/src/collections/binary_heap.rs

+113-49
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
275275
impl<T: Ord> Drop for PeekMut<'_, T> {
276276
fn drop(&mut self) {
277277
if self.sift {
278-
self.heap.sift_down(0);
278+
// SAFETY: PeekMut is only instantiated for non-empty heaps.
279+
unsafe { self.heap.sift_down(0) };
279280
}
280281
}
281282
}
@@ -431,7 +432,8 @@ impl<T: Ord> BinaryHeap<T> {
431432
self.data.pop().map(|mut item| {
432433
if !self.is_empty() {
433434
swap(&mut item, &mut self.data[0]);
434-
self.sift_down_to_bottom(0);
435+
// SAFETY: !self.is_empty() means that self.len() > 0
436+
unsafe { self.sift_down_to_bottom(0) };
435437
}
436438
item
437439
})
@@ -473,7 +475,9 @@ impl<T: Ord> BinaryHeap<T> {
473475
pub fn push(&mut self, item: T) {
474476
let old_len = self.len();
475477
self.data.push(item);
476-
self.sift_up(0, old_len);
478+
// SAFETY: Since we pushed a new item it means that
479+
// old_len = self.len() - 1 < self.len()
480+
unsafe { self.sift_up(0, old_len) };
477481
}
478482

479483
/// Consumes the `BinaryHeap` and returns a vector in sorted
@@ -506,7 +510,10 @@ impl<T: Ord> BinaryHeap<T> {
506510
let ptr = self.data.as_mut_ptr();
507511
ptr::swap(ptr, ptr.add(end));
508512
}
509-
self.sift_down_range(0, end);
513+
// SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
514+
// 0 < 1 <= end <= self.len() - 1 < self.len()
515+
// Which means 0 < end and end < self.len().
516+
unsafe { self.sift_down_range(0, end) };
510517
}
511518
self.into_vec()
512519
}
@@ -519,78 +526,135 @@ impl<T: Ord> BinaryHeap<T> {
519526
// the hole is filled back at the end of its scope, even on panic.
520527
// Using a hole reduces the constant factor compared to using swaps,
521528
// which involves twice as many moves.
522-
fn sift_up(&mut self, start: usize, pos: usize) -> usize {
523-
unsafe {
524-
// Take out the value at `pos` and create a hole.
525-
let mut hole = Hole::new(&mut self.data, pos);
526-
527-
while hole.pos() > start {
528-
let parent = (hole.pos() - 1) / 2;
529-
if hole.element() <= hole.get(parent) {
530-
break;
531-
}
532-
hole.move_to(parent);
529+
530+
/// # Safety
531+
///
532+
/// The caller must guarantee that `pos < self.len()`.
533+
unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
534+
// Take out the value at `pos` and create a hole.
535+
// SAFETY: The caller guarantees that pos < self.len()
536+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
537+
538+
while hole.pos() > start {
539+
let parent = (hole.pos() - 1) / 2;
540+
541+
// SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
542+
// and so hole.pos() - 1 can't underflow.
543+
// This guarantees that parent < hole.pos() so
544+
// it's a valid index and also != hole.pos().
545+
if hole.element() <= unsafe { hole.get(parent) } {
546+
break;
533547
}
534-
hole.pos()
548+
549+
// SAFETY: Same as above
550+
unsafe { hole.move_to(parent) };
535551
}
552+
553+
hole.pos()
536554
}
537555

538556
/// Take an element at `pos` and move it down the heap,
539557
/// while its children are larger.
540-
fn sift_down_range(&mut self, pos: usize, end: usize) {
541-
unsafe {
542-
let mut hole = Hole::new(&mut self.data, pos);
543-
let mut child = 2 * pos + 1;
544-
while child < end - 1 {
545-
// compare with the greater of the two children
546-
child += (hole.get(child) <= hole.get(child + 1)) as usize;
547-
// if we are already in order, stop.
548-
if hole.element() >= hole.get(child) {
549-
return;
550-
}
551-
hole.move_to(child);
552-
child = 2 * hole.pos() + 1;
553-
}
554-
if child == end - 1 && hole.element() < hole.get(child) {
555-
hole.move_to(child);
558+
///
559+
/// # Safety
560+
///
561+
/// The caller must guarantee that `pos < end <= self.len()`.
562+
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
563+
// SAFETY: The caller guarantees that pos < end <= self.len().
564+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
565+
let mut child = 2 * hole.pos() + 1;
566+
567+
// Loop invariant: child == 2 * hole.pos() + 1.
568+
while child < end - 1 {
569+
// compare with the greater of the two children
570+
// SAFETY: child < end - 1 < self.len() and
571+
// child + 1 < end <= self.len(), so they're valid indexes.
572+
// child == 2 * hole.pos() + 1 != hole.pos() and
573+
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
574+
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
575+
576+
// if we are already in order, stop.
577+
// SAFETY: child is now either the old child or the old child+1
578+
// We already proven that both are < self.len() and != hole.pos()
579+
if hole.element() >= unsafe { hole.get(child) } {
580+
return;
556581
}
582+
583+
// SAFETY: same as above.
584+
unsafe { hole.move_to(child) };
585+
child = 2 * hole.pos() + 1;
586+
}
587+
588+
// SAFETY: && short circuit, which means that in the
589+
// second condition it's already true that child == end - 1 < self.len().
590+
if child == end - 1 && hole.element() < unsafe { hole.get(child) } {
591+
// SAFETY: child is already proven to be a valid index and
592+
// child == 2 * hole.pos() + 1 != hole.pos().
593+
unsafe { hole.move_to(child) };
557594
}
558595
}
559596

560-
fn sift_down(&mut self, pos: usize) {
597+
/// # Safety
598+
///
599+
/// The caller must guarantee that `pos < self.len()`.
600+
unsafe fn sift_down(&mut self, pos: usize) {
561601
let len = self.len();
562-
self.sift_down_range(pos, len);
602+
// SAFETY: pos < len is guaranteed by the caller and
603+
// obviously len = self.len() <= self.len().
604+
unsafe { self.sift_down_range(pos, len) };
563605
}
564606

565607
/// Take an element at `pos` and move it all the way down the heap,
566608
/// then sift it up to its position.
567609
///
568610
/// Note: This is faster when the element is known to be large / should
569611
/// be closer to the bottom.
570-
fn sift_down_to_bottom(&mut self, mut pos: usize) {
612+
///
613+
/// # Safety
614+
///
615+
/// The caller must guarantee that `pos < self.len()`.
616+
unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
571617
let end = self.len();
572618
let start = pos;
573-
unsafe {
574-
let mut hole = Hole::new(&mut self.data, pos);
575-
let mut child = 2 * pos + 1;
576-
while child < end - 1 {
577-
child += (hole.get(child) <= hole.get(child + 1)) as usize;
578-
hole.move_to(child);
579-
child = 2 * hole.pos() + 1;
580-
}
581-
if child == end - 1 {
582-
hole.move_to(child);
583-
}
584-
pos = hole.pos;
619+
620+
// SAFETY: The caller guarantees that pos < self.len().
621+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
622+
let mut child = 2 * hole.pos() + 1;
623+
624+
// Loop invariant: child == 2 * hole.pos() + 1.
625+
while child < end - 1 {
626+
// SAFETY: child < end - 1 < self.len() and
627+
// child + 1 < end <= self.len(), so they're valid indexes.
628+
// child == 2 * hole.pos() + 1 != hole.pos() and
629+
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
630+
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
631+
632+
// SAFETY: Same as above
633+
unsafe { hole.move_to(child) };
634+
child = 2 * hole.pos() + 1;
585635
}
586-
self.sift_up(start, pos);
636+
637+
if child == end - 1 {
638+
// SAFETY: child == end - 1 < self.len(), so it's a valid index
639+
// and child == 2 * hole.pos() + 1 != hole.pos().
640+
unsafe { hole.move_to(child) };
641+
}
642+
pos = hole.pos();
643+
drop(hole);
644+
645+
// SAFETY: pos is the position in the hole and was already proven
646+
// to be a valid index.
647+
unsafe { self.sift_up(start, pos) };
587648
}
588649

589650
fn rebuild(&mut self) {
590651
let mut n = self.len() / 2;
591652
while n > 0 {
592653
n -= 1;
593-
self.sift_down(n);
654+
// SAFETY: n starts from self.len() / 2 and goes down to 0.
655+
// The only case when !(n < self.len()) is if
656+
// self.len() == 0, but it's ruled out by the loop condition.
657+
unsafe { self.sift_down(n) };
594658
}
595659
}
596660

0 commit comments

Comments
 (0)