Skip to content

Commit 11a44ad

Browse files
committed
Auto merge of #75600 - nagisa:improve_align_offset, r=KodrAus
Improve codegen for `align_offset` In this PR the `align_offset` implementation is changed/improved to produce better code in certain scenarios such as when pointer type is has a stride of 1 or when building for low optimisation levels. While these changes do not achieve the "ideal" codegen referenced in #75579, it gets significantly closer to it. I’m not actually sure if the codegen can actually be much better with this function returning the offset, rather than the aligned pointer. See the descriptions for separate commits for further information.
2 parents 5b04bbf + 5d22b18 commit 11a44ad

File tree

1 file changed

+47
-28
lines changed

1 file changed

+47
-28
lines changed

library/core/src/ptr/mod.rs

+47-28
Original file line numberDiff line numberDiff line change
@@ -1166,16 +1166,20 @@ pub unsafe fn write_volatile<T>(dst: *mut T, src: T) {
11661166
/// Any questions go to @nagisa.
11671167
#[lang = "align_offset"]
11681168
pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
1169+
// FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
1170+
// 1, where the method versions of these operations are not inlined.
1171+
use intrinsics::{unchecked_shl, unchecked_shr, unchecked_sub, wrapping_mul, wrapping_sub};
1172+
11691173
/// Calculate multiplicative modular inverse of `x` modulo `m`.
11701174
///
1171-
/// This implementation is tailored for align_offset and has following preconditions:
1175+
/// This implementation is tailored for `align_offset` and has following preconditions:
11721176
///
11731177
/// * `m` is a power-of-two;
11741178
/// * `x < m`; (if `x ≥ m`, pass in `x % m` instead)
11751179
///
11761180
/// Implementation of this function shall not panic. Ever.
11771181
#[inline]
1178-
fn mod_inv(x: usize, m: usize) -> usize {
1182+
unsafe fn mod_inv(x: usize, m: usize) -> usize {
11791183
/// Multiplicative modular inverse table modulo 2⁴ = 16.
11801184
///
11811185
/// Note, that this table does not contain values where inverse does not exist (i.e., for
@@ -1187,8 +1191,10 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
11871191
const INV_TABLE_MOD_SQUARED: usize = INV_TABLE_MOD * INV_TABLE_MOD;
11881192

11891193
let table_inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
1194+
// SAFETY: `m` is required to be a power-of-two, hence non-zero.
1195+
let m_minus_one = unsafe { unchecked_sub(m, 1) };
11901196
if m <= INV_TABLE_MOD {
1191-
table_inverse & (m - 1)
1197+
table_inverse & m_minus_one
11921198
} else {
11931199
// We iterate "up" using the following formula:
11941200
//
@@ -1204,49 +1210,50 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
12041210
// uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
12051211
// usize::MAX` instead, because we take the result `mod n` at the end
12061212
// anyway.
1207-
inverse = inverse.wrapping_mul(2usize.wrapping_sub(x.wrapping_mul(inverse)));
1213+
inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
12081214
if going_mod >= m {
1209-
return inverse & (m - 1);
1215+
return inverse & m_minus_one;
12101216
}
1211-
going_mod = going_mod.wrapping_mul(going_mod);
1217+
going_mod = wrapping_mul(going_mod, going_mod);
12121218
}
12131219
}
12141220
}
12151221

12161222
let stride = mem::size_of::<T>();
1217-
let a_minus_one = a.wrapping_sub(1);
1218-
let pmoda = p as usize & a_minus_one;
1223+
// SAFETY: `a` is a power-of-two, therefore non-zero.
1224+
let a_minus_one = unsafe { unchecked_sub(a, 1) };
1225+
if stride == 1 {
1226+
// `stride == 1` case can be computed more efficiently through `-p (mod a)`.
1227+
return wrapping_sub(0, p as usize) & a_minus_one;
1228+
}
12191229

1230+
let pmoda = p as usize & a_minus_one;
12201231
if pmoda == 0 {
12211232
// Already aligned. Yay!
12221233
return 0;
1223-
}
1224-
1225-
if stride <= 1 {
1226-
return if stride == 0 {
1227-
// If the pointer is not aligned, and the element is zero-sized, then no amount of
1228-
// elements will ever align the pointer.
1229-
!0
1230-
} else {
1231-
a.wrapping_sub(pmoda)
1232-
};
1234+
} else if stride == 0 {
1235+
// If the pointer is not aligned, and the element is zero-sized, then no amount of
1236+
// elements will ever align the pointer.
1237+
return usize::MAX;
12331238
}
12341239

12351240
let smoda = stride & a_minus_one;
1236-
// SAFETY: a is power-of-two so cannot be 0. stride = 0 is handled above.
1241+
// SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
12371242
let gcdpow = unsafe { intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a)) };
1238-
let gcd = 1usize << gcdpow;
1243+
// SAFETY: gcdpow has an upper-bound that’s at most the number of bits in an usize.
1244+
let gcd = unsafe { unchecked_shl(1usize, gcdpow) };
12391245

1240-
if p as usize & (gcd.wrapping_sub(1)) == 0 {
1246+
// SAFETY: gcd is always greater or equal to 1.
1247+
if p as usize & unsafe { unchecked_sub(gcd, 1) } == 0 {
12411248
// This branch solves for the following linear congruence equation:
12421249
//
12431250
// ` p + so = 0 mod a `
12441251
//
12451252
// `p` here is the pointer value, `s` - stride of `T`, `o` offset in `T`s, and `a` - the
12461253
// requested alignment.
12471254
//
1248-
// With `g = gcd(a, s)`, and the above asserting that `p` is also divisible by `g`, we can
1249-
// denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
1255+
// With `g = gcd(a, s)`, and the above condition asserting that `p` is also divisible by
1256+
// `g`, we can denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
12501257
//
12511258
// ` p' + s'o = 0 mod a' `
12521259
// ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
@@ -1259,11 +1266,23 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
12591266
//
12601267
// Furthermore, the result produced by this solution is not "minimal", so it is necessary
12611268
// to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just a `a'`.
1262-
let a2 = a >> gcdpow;
1263-
let a2minus1 = a2.wrapping_sub(1);
1264-
let s2 = smoda >> gcdpow;
1265-
let minusp2 = a2.wrapping_sub(pmoda >> gcdpow);
1266-
return (minusp2.wrapping_mul(mod_inv(s2, a2))) & a2minus1;
1269+
1270+
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
1271+
// `a`.
1272+
let a2 = unsafe { unchecked_shr(a, gcdpow) };
1273+
// SAFETY: `a2` is non-zero. Shifting `a` by `gcdpow` cannot shift out any of the set bits
1274+
// in `a` (of which it has exactly one).
1275+
let a2minus1 = unsafe { unchecked_sub(a2, 1) };
1276+
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
1277+
// `a`.
1278+
let s2 = unsafe { unchecked_shr(smoda, gcdpow) };
1279+
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
1280+
// `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
1281+
// always be strictly greater than `(p % a) >> gcdpow`.
1282+
let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(pmoda, gcdpow)) };
1283+
// SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
1284+
// because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
1285+
return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;
12671286
}
12681287

12691288
// Cannot be aligned at all.

0 commit comments

Comments
 (0)