Skip to content

Add in-place optimization for array map #75571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 46 additions & 26 deletions library/core/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,42 +425,62 @@ impl<T, const N: usize> [T; N] {
/// assert_eq!(y, [6, 9, 3, 3]);
/// ```
#[unstable(feature = "array_map", issue = "75243")]
pub fn map<F, U>(self, mut f: F) -> [U; N]
pub fn map<F, U>(self, f: F) -> [U; N]
where
F: FnMut(T) -> U,
{
use crate::mem::MaybeUninit;
struct Guard<T, const N: usize> {
dst: *mut T,
initialized: usize,
}
use crate::mem::{forget, ManuallyDrop, MaybeUninit};
use crate::ptr;

impl<T, const N: usize> Drop for Guard<T, N> {
union MaybeUninitArray<T, const N: usize> {
none: (),
partial: ManuallyDrop<[MaybeUninit<T>; N]>,
complete: ManuallyDrop<[T; N]>,
}
struct MapGuard<'a, T, const N: usize> {
arr: &'a mut MaybeUninitArray<T, N>,
len: usize,
}
impl<'a, T, const N: usize> MapGuard<'a, T, N> {
fn push(&mut self, value: T) {
// SAFETY: Since we know the input size is N, and the output is N,
// this will never exceed the capacity, and MaybeUninit is always in the
// structure of an array.
unsafe {
self.arr.partial[self.len].write(value);
self.len += 1;
}
}
}
impl<'a, T, const N: usize> Drop for MapGuard<'a, T, N> {
fn drop(&mut self) {
debug_assert!(self.initialized <= N);

let initialized_part =
crate::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
// SAFETY: this raw slice will contain only initialized objects
// that's why, it is allowed to drop it.
//debug_assert!(self.len <= N);
// SAFETY: already pushed `len` elements, but need to drop them now that
// `f` panicked.
unsafe {
crate::ptr::drop_in_place(initialized_part);
let p: *mut MaybeUninit<T> = self.arr.partial.as_mut_ptr();
let slice: *mut [T] = ptr::slice_from_raw_parts_mut(p.cast(), self.len);
ptr::drop_in_place(slice);
}
}
}
let mut dst = MaybeUninit::uninit_array::<N>();
let mut guard: Guard<U, N> =
Guard { dst: MaybeUninit::slice_as_mut_ptr(&mut dst), initialized: 0 };
for (src, dst) in IntoIter::new(self).zip(&mut dst) {
dst.write(f(src));
guard.initialized += 1;

fn map_guard_fn<T, const N: usize>(
buffer: &mut MaybeUninitArray<T, N>,
iter: impl Iterator<Item = T>,
) {
let mut guard = MapGuard { arr: buffer, len: 0 };
for v in iter {
guard.push(v);
}
forget(guard);
}
// FIXME: Convert to crate::mem::transmute once it works with generics.
// unsafe { crate::mem::transmute::<[MaybeUninit<U>; N], [U; N]>(dst) }
crate::mem::forget(guard);
// SAFETY: At this point we've properly initialized the whole array
// and we just need to cast it to the correct type.
unsafe { crate::mem::transmute_copy::<_, [U; N]>(&dst) }

let mut buffer = MaybeUninitArray::<U, N> { none: () };
map_guard_fn(&mut buffer, IntoIter::new(self).map(f));
// SAFETY: all elements have successfully initialized, don't run guard's drop code
// and take completed buffer out of MaybeUninitArray.
unsafe { ManuallyDrop::into_inner(buffer.complete) }
}

/// Returns a slice containing the entire array. Equivalent to `&s[..]`.
Expand Down
57 changes: 42 additions & 15 deletions library/core/tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,26 +323,53 @@ fn array_map() {
fn array_map_drop_safety() {
use core::sync::atomic::AtomicUsize;
use core::sync::atomic::Ordering;
static DROPPED: AtomicUsize = AtomicUsize::new(0);
struct DropCounter;
impl Drop for DropCounter {
static DROPPED: [AtomicUsize; 3] =
[AtomicUsize::new(0), AtomicUsize::new(0), AtomicUsize::new(0)];
struct DropCounter<const N: usize>;
impl<const N: usize> Drop for DropCounter<N> {
fn drop(&mut self) {
DROPPED.fetch_add(1, Ordering::SeqCst);
DROPPED[N].fetch_add(1, Ordering::SeqCst);
}
}

let num_to_create = 5;
let success = std::panic::catch_unwind(|| {
let items = [0; 10];
let mut nth = 0;
items.map(|_| {
assert!(nth < num_to_create);
nth += 1;
DropCounter
{
let num_to_create = 5;
let success = std::panic::catch_unwind(|| {
let items = [0; 10];
let mut nth = 0;
items.map(|_| {
assert!(nth < num_to_create);
nth += 1;
DropCounter::<0>
});
});
assert!(success.is_err());
assert_eq!(DROPPED[0].load(Ordering::SeqCst), num_to_create);
}

{
assert_eq!(DROPPED[1].load(Ordering::SeqCst), 0);
let num_to_create = 3;
const TOTAL: usize = 5;
let success = std::panic::catch_unwind(|| {
let items: [DropCounter<1>; TOTAL] = [
DropCounter::<1>,
DropCounter::<1>,
DropCounter::<1>,
DropCounter::<1>,
DropCounter::<1>,
];
let mut nth = 0;
items.map(|_| {
assert!(nth < num_to_create);
nth += 1;
DropCounter::<2>
});
});
});
assert!(success.is_err());
assert_eq!(DROPPED.load(Ordering::SeqCst), num_to_create);
assert!(success.is_err());
assert_eq!(DROPPED[2].load(Ordering::SeqCst), num_to_create);
assert_eq!(DROPPED[1].load(Ordering::SeqCst), TOTAL);
}
panic!("test succeeded")
}

Expand Down
1 change: 1 addition & 0 deletions library/core/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#![feature(raw)]
#![feature(sort_internals)]
#![feature(slice_partition_at_index)]
#![feature(min_const_generics)]
#![feature(min_specialization)]
#![feature(step_trait)]
#![feature(step_trait_ext)]
Expand Down
25 changes: 25 additions & 0 deletions src/test/codegen/array-map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// compile-flags: -C opt-level=3 -Zmir-opt-level=3
#![crate_type = "lib"]
#![feature(array_map)]

const SIZE: usize = 4;

// CHECK-LABEL: @array_cast_to_float
#[no_mangle]
pub fn array_cast_to_float(x: [u32; SIZE]) -> [f32; SIZE] {
// CHECK: cast
// CHECK: @llvm.memcpy
// CHECK: ret
// CHECK-EMPTY
x.map(|v| v as f32)
}

// CHECK-LABEL: @array_cast_to_u64
#[no_mangle]
pub fn array_cast_to_u64(x: [u32; SIZE]) -> [u64; SIZE] {
// CHECK: cast
// CHECK: @llvm.memcpy
// CHECK: ret
// CHECK-EMPTY
x.map(|v| v as u64)
}