Skip to content

Commit 932f981

Browse files
committed
Redo the swap code for better tail & padding handling
1 parent 4e5fec2 commit 932f981

File tree

12 files changed

+443
-135
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

+17
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,23 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
498498
}
499499
}
500500

501+
sym::untyped_swap_nonoverlapping => {
502+
// The fallback impl uses memcpy, which leaves around allocas
503+
// that don't optimize out for certain widths, so force it to
504+
// use SSA registers instead.
505+
506+
let chunk_ty = fn_args.type_at(0);
507+
let layout = self.layout_of(chunk_ty).layout;
508+
let integer_ty = self.type_ix(layout.size().bits());
509+
let a = args[0].immediate();
510+
let b = args[1].immediate();
511+
let a_val = self.load(integer_ty, a, layout.align().abi);
512+
let b_val = self.load(integer_ty, b, layout.align().abi);
513+
self.store(b_val, a, layout.align().abi);
514+
self.store(a_val, b, layout.align().abi);
515+
return Ok(());
516+
}
517+
501518
sym::compare_bytes => {
502519
// Here we assume that the `memcmp` provided by the target is a NOP for size 0.
503520
let cmp = self.call_intrinsic("memcmp", &[

compiler/rustc_hir_analysis/src/check/intrinsic.rs

+6
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,12 @@ pub fn check_intrinsic_type(
504504
sym::typed_swap_nonoverlapping => {
505505
(1, 0, vec![Ty::new_mut_ptr(tcx, param(0)); 2], tcx.types.unit)
506506
}
507+
sym::untyped_swap_nonoverlapping => (
508+
1,
509+
0,
510+
vec![Ty::new_mut_ptr(tcx, Ty::new_maybe_uninit(tcx, param(0))); 2],
511+
tcx.types.unit,
512+
),
507513

508514
sym::discriminant_value => {
509515
let assoc_items = tcx.associated_item_def_ids(

compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -2142,6 +2142,7 @@ symbols! {
21422142
unstable location; did you mean to load this crate \
21432143
from crates.io via `Cargo.toml` instead?",
21442144
untagged_unions,
2145+
untyped_swap_nonoverlapping,
21452146
unused_imports,
21462147
unwind,
21472148
unwind_attributes,

library/core/src/intrinsics/mod.rs

+32-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666

6767
use crate::marker::{DiscriminantKind, Tuple};
6868
use crate::mem::SizedTypeProperties;
69-
use crate::{ptr, ub_checks};
69+
use crate::{mem, ptr, ub_checks};
7070

7171
pub mod fallback;
7272
pub mod mir;
@@ -4003,7 +4003,37 @@ pub use typed_swap as typed_swap_nonoverlapping;
40034003
pub const unsafe fn typed_swap_nonoverlapping<T>(x: *mut T, y: *mut T) {
40044004
// SAFETY: The caller provided single non-overlapping items behind
40054005
// pointers, so swapping them with `count: 1` is fine.
4006-
unsafe { ptr::swap_nonoverlapping(x, y, 1) };
4006+
unsafe { crate::swapping::swap_nonoverlapping(x, y, 1) };
4007+
}
4008+
4009+
/// Swaps the `N` untyped & non-overlapping bytes behind the two pointers.
4010+
///
4011+
/// Split out from `typed_swap` for the internal swaps in `swap_nonoverlapping`
4012+
/// which would otherwise cause cycles between the fallback implementations on
4013+
/// backends where neither is overridden.
4014+
///
4015+
/// # Safety
4016+
///
4017+
/// `x` and `y` are readable and writable as `MaybeUninit<C>` and non-overlapping.
4018+
#[inline]
4019+
#[rustc_nounwind]
4020+
#[cfg_attr(not(bootstrap), rustc_intrinsic)]
4021+
#[miri::intrinsic_fallback_is_spec]
4022+
#[rustc_const_stable_indirect]
4023+
pub const unsafe fn untyped_swap_nonoverlapping<C>(
4024+
x: *mut mem::MaybeUninit<C>,
4025+
y: *mut mem::MaybeUninit<C>,
4026+
) {
4027+
// This intentionally uses untyped memory copies, not reads/writes,
4028+
// to avoid any risk of losing padding in things like (u16, u8).
4029+
let mut temp = mem::MaybeUninit::<C>::uninit();
4030+
// SAFETY: Caller promised that x and y are non-overlapping & read/writeable,
4031+
// and our fresh local is always disjoint from anything otherwise readable.
4032+
unsafe {
4033+
(&raw mut temp).copy_from_nonoverlapping(x, 1);
4034+
x.copy_from_nonoverlapping(y, 1);
4035+
y.copy_from_nonoverlapping(&raw const temp, 1);
4036+
}
40074037
}
40084038

40094039
/// Returns whether we should perform some UB-checking at runtime. This eventually evaluates to

library/core/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ pub mod alloc;
376376
// note: does not need to be public
377377
mod bool;
378378
mod escape;
379+
pub(crate) mod swapping;
379380
mod tuple;
380381
mod unit;
381382

library/core/src/ptr/mod.rs

+2-79
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,6 @@
395395
#![allow(clippy::not_unsafe_ptr_arg_deref)]
396396

397397
use crate::cmp::Ordering;
398-
use crate::intrinsics::const_eval_select;
399398
use crate::marker::FnPtr;
400399
use crate::mem::{self, MaybeUninit, SizedTypeProperties};
401400
use crate::{fmt, hash, intrinsics, ub_checks};
@@ -1092,84 +1091,8 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
10921091
}
10931092
);
10941093

1095-
const_eval_select!(
1096-
@capture[T] { x: *mut T, y: *mut T, count: usize }:
1097-
if const {
1098-
// At compile-time we want to always copy this in chunks of `T`, to ensure that if there
1099-
// are pointers inside `T` we will copy them in one go rather than trying to copy a part
1100-
// of a pointer (which would not work).
1101-
// SAFETY: Same preconditions as this function
1102-
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
1103-
} else {
1104-
macro_rules! attempt_swap_as_chunks {
1105-
($ChunkTy:ty) => {
1106-
if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
1107-
&& mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
1108-
{
1109-
let x: *mut $ChunkTy = x.cast();
1110-
let y: *mut $ChunkTy = y.cast();
1111-
let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
1112-
// SAFETY: these are the same bytes that the caller promised were
1113-
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
1114-
// The `if` condition above ensures that we're not violating
1115-
// alignment requirements, and that the division is exact so
1116-
// that we don't lose any bytes off the end.
1117-
return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
1118-
}
1119-
};
1120-
}
1121-
1122-
// Split up the slice into small power-of-two-sized chunks that LLVM is able
1123-
// to vectorize (unless it's a special type with more-than-pointer alignment,
1124-
// because we don't want to pessimize things like slices of SIMD vectors.)
1125-
if mem::align_of::<T>() <= mem::size_of::<usize>()
1126-
&& (!mem::size_of::<T>().is_power_of_two()
1127-
|| mem::size_of::<T>() > mem::size_of::<usize>() * 2)
1128-
{
1129-
attempt_swap_as_chunks!(usize);
1130-
attempt_swap_as_chunks!(u8);
1131-
}
1132-
1133-
// SAFETY: Same preconditions as this function
1134-
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
1135-
}
1136-
)
1137-
}
1138-
1139-
/// Same behavior and safety conditions as [`swap_nonoverlapping`]
1140-
///
1141-
/// LLVM can vectorize this (at least it can for the power-of-two-sized types
1142-
/// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
1143-
#[inline]
1144-
const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, count: usize) {
1145-
let x = x.cast::<MaybeUninit<T>>();
1146-
let y = y.cast::<MaybeUninit<T>>();
1147-
let mut i = 0;
1148-
while i < count {
1149-
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
1150-
let x = unsafe { x.add(i) };
1151-
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
1152-
// and it's distinct from `x` since the ranges are non-overlapping
1153-
let y = unsafe { y.add(i) };
1154-
1155-
// If we end up here, it's because we're using a simple type -- like
1156-
// a small power-of-two-sized thing -- or a special type with particularly
1157-
// large alignment, particularly SIMD types.
1158-
// Thus, we're fine just reading-and-writing it, as either it's small
1159-
// and that works well anyway or it's special and the type's author
1160-
// presumably wanted things to be done in the larger chunk.
1161-
1162-
// SAFETY: we're only ever given pointers that are valid to read/write,
1163-
// including being aligned, and nothing here panics so it's drop-safe.
1164-
unsafe {
1165-
let a: MaybeUninit<T> = read(x);
1166-
let b: MaybeUninit<T> = read(y);
1167-
write(x, b);
1168-
write(y, a);
1169-
}
1170-
1171-
i += 1;
1172-
}
1094+
// SAFETY: Same preconditions as this function
1095+
unsafe { crate::swapping::swap_nonoverlapping(x, y, count) }
11731096
}
11741097

11751098
/// Moves `src` into the pointed `dst`, returning the previous `dst` value.

library/core/src/swapping.rs

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
use crate::{hint, intrinsics, mem, ptr};
2+
3+
//#[rustc_const_stable_indirect]
4+
//#[rustc_allow_const_fn_unstable(const_eval_select)]
5+
#[rustc_const_unstable(feature = "const_swap_nonoverlapping", issue = "133668")]
6+
#[inline]
7+
pub(crate) const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
8+
intrinsics::const_eval_select!(
9+
@capture[T] { x: *mut T, y: *mut T, count: usize }:
10+
if const {
11+
// At compile-time we want to always copy this in chunks of `T`, to ensure that if there
12+
// are pointers inside `T` we will copy them in one go rather than trying to copy a part
13+
// of a pointer (which would not work).
14+
// SAFETY: Same preconditions as this function
15+
unsafe { swap_nonoverlapping_const(x, y, count) }
16+
} else {
17+
// At runtime we want to make sure not to swap byte-for-byte for types like [u8; 15],
18+
// and swapping as `MaybeUninit<T>` doesn't actually work as untyped for things like
19+
// T = (u16, u8), so we type-erase to raw bytes and swap that way.
20+
// SAFETY: Same preconditions as this function
21+
unsafe { swap_nonoverlapping_runtime(x, y, count) }
22+
}
23+
)
24+
}
25+
26+
/// Same behavior and safety conditions as [`swap_nonoverlapping`]
27+
#[rustc_const_stable_indirect]
28+
#[inline]
29+
const unsafe fn swap_nonoverlapping_const<T>(x: *mut T, y: *mut T, count: usize) {
30+
let x = x.cast::<mem::MaybeUninit<T>>();
31+
let y = y.cast::<mem::MaybeUninit<T>>();
32+
let mut i = 0;
33+
while i < count {
34+
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
35+
// and because the two input ranges are non-overlapping and read/writeable,
36+
// these individual items inside them are too.
37+
unsafe {
38+
intrinsics::untyped_swap_nonoverlapping::<T>(x.add(i), y.add(i));
39+
}
40+
41+
i += 1;
42+
}
43+
}
44+
45+
// Scale the monomorphizations with the size of the machine, roughly.
46+
const MAX_ALIGN: usize = align_of::<usize>().pow(2);
47+
48+
/// Same behavior and safety conditions as [`swap_nonoverlapping`]
49+
#[inline]
50+
unsafe fn swap_nonoverlapping_runtime<T>(x: *mut T, y: *mut T, count: usize) {
51+
let bytes = {
52+
let slice = ptr::slice_from_raw_parts(x, count);
53+
// SAFETY: Because they both exist in memory and don't overlap, they
54+
// must be legal slice sizes (below `isize::MAX` bytes).
55+
unsafe { mem::size_of_val_raw(slice) }
56+
};
57+
58+
// Generating *untyped* loops for every type is silly, so we polymorphize away
59+
// the actual type, but we want to take advantage of alignment if possible,
60+
// so monomorphize for a restricted set of possible alignments.
61+
macro_rules! delegate_by_alignment {
62+
($($p:pat => $align:expr,)+) => {{
63+
#![allow(unreachable_patterns)]
64+
match const { align_of::<T>() } {
65+
$(
66+
$p => {
67+
swap_nonoverlapping_bytes::<$align>(x.cast(), y.cast(), bytes);
68+
}
69+
)+
70+
}
71+
}};
72+
}
73+
74+
// SAFETY:
75+
unsafe {
76+
delegate_by_alignment! {
77+
MAX_ALIGN.. => MAX_ALIGN,
78+
64.. => 64,
79+
32.. => 32,
80+
16.. => 16,
81+
8.. => 8,
82+
4.. => 4,
83+
2.. => 2,
84+
_ => 1,
85+
}
86+
}
87+
}
88+
89+
/// # Safety:
90+
/// - `x` and `y` must be aligned to `ALIGN`
91+
/// - `bytes` must be a multiple of `ALIGN`
92+
/// - They must be readable, writable, and non-overlapping for `bytes` bytes
93+
#[inline]
94+
unsafe fn swap_nonoverlapping_bytes<const ALIGN: usize>(
95+
x: *mut mem::MaybeUninit<u8>,
96+
y: *mut mem::MaybeUninit<u8>,
97+
bytes: usize,
98+
) {
99+
// SAFETY: Two legal non-overlapping regions can't be bigger than this.
100+
// (And they couldn't have made allocations any bigger either anyway.)
101+
// FIXME: Would be nice to have a type for this instead of the assume.
102+
unsafe { hint::assert_unchecked(bytes < isize::MAX as usize) };
103+
104+
let mut i = 0;
105+
macro_rules! swap_next_n {
106+
($n:expr) => {{
107+
let x: *mut mem::MaybeUninit<[u8; $n]> = x.add(i).cast();
108+
let y: *mut mem::MaybeUninit<[u8; $n]> = y.add(i).cast();
109+
swap_nonoverlapping_aligned_chunk::<ALIGN, [u8; $n]>(
110+
x.as_mut_unchecked(),
111+
y.as_mut_unchecked(),
112+
);
113+
i += $n;
114+
}};
115+
}
116+
117+
while bytes - i >= MAX_ALIGN {
118+
const { assert!(MAX_ALIGN >= ALIGN) };
119+
// SAFETY: the const-assert above confirms we're only ever called with
120+
// an alignment equal to or smaller than max align, so this is necessarily
121+
// aligned, and the while loop ensures there's enough read/write memory.
122+
unsafe {
123+
swap_next_n!(MAX_ALIGN);
124+
}
125+
}
126+
127+
macro_rules! handle_tail {
128+
($($n:literal)+) => {$(
129+
if const { $n % ALIGN == 0 } {
130+
// Checking this way simplifies the block end to just add+test,
131+
// rather than needing extra math before the check.
132+
if (bytes & $n) != 0 {
133+
// SAFETY: The above swaps were bigger, so could not have
134+
// impacted the `$n`-relevant bit, so checking `bytes & $n`
135+
// was equivalent to `bytes - i >= $n`, and thus we have
136+
// enough space left to swap another `$n` bytes.
137+
unsafe {
138+
swap_next_n!($n);
139+
}
140+
}
141+
}
142+
)+};
143+
}
144+
const { assert!(MAX_ALIGN <= 64) };
145+
handle_tail!(32 16 8 4 2 1);
146+
147+
debug_assert_eq!(i, bytes);
148+
}
149+
150+
/// Swaps the `C` behind `x` and `y` as untyped memory
151+
///
152+
/// # Safety
153+
///
154+
/// Both `x` and `y` must be aligned to `ALIGN`, in addition to their normal alignment.
155+
/// They must be readable and writeable for `sizeof(C)` bytes, as usual for `&mut`s.
156+
///
157+
/// (The actual instantiations are usually `C = [u8; _]`, so we get the alignment
158+
/// information from the loads by `assume`ing the passed-in alignment.)
159+
// Don't let MIR inline this, because we really want it to keep its noalias metadata
160+
#[rustc_no_mir_inline]
161+
#[inline]
162+
unsafe fn swap_nonoverlapping_aligned_chunk<const ALIGN: usize, C>(
163+
x: &mut mem::MaybeUninit<C>,
164+
y: &mut mem::MaybeUninit<C>,
165+
) {
166+
assert!(size_of::<C>() % ALIGN == 0);
167+
168+
let x = ptr::from_mut(x);
169+
let y = ptr::from_mut(y);
170+
171+
// SAFETY: One of our preconditions.
172+
unsafe {
173+
hint::assert_unchecked(x.is_aligned_to(ALIGN));
174+
hint::assert_unchecked(y.is_aligned_to(ALIGN));
175+
}
176+
177+
// SAFETY: The memory is readable and writable because these were passed to
178+
// us as mutable references, and the untyped swap doesn't need validity.
179+
unsafe {
180+
intrinsics::untyped_swap_nonoverlapping::<C>(x, y);
181+
}
182+
}

0 commit comments

Comments
 (0)