Commit 98ab2de

Auto merge of rust-lang#139729 - scottmcm:more-enum-tweaks, r=<try>
Allow matching on 3+ variant niche-encoded enums to optimize better

While the two-variant case is the most common (and already special-cased), it's pretty unusual to actually need the *fully-general* niche-decoding algorithm (the one that handles things like 200+ variants wrapping around the encoding space). Layout puts the niche-encoded variants at one end of the natural values, and because enums don't have that many variants, it's quite common that there's no wrapping at all: the handful of variants just ends up after the end of the `bool` or `char` or `newtype_index!` or whatever.

This PR thus looks for those cases: situations where the tag's range doesn't actually wrap, so we can check for niche-vs-untagged with one simple `icmp` without needing to adjust the tag value, and, by picking between zero- and sign-extension based on *which* kind of non-wrapping it is, also help LLVM understand by not forcing it to think about wrapping arithmetic either.

It also emits the operations in a more optimization-friendly order. While the MIR Rvalue calculates a discriminant, so that's what we emit, code normally doesn't care about the actual discriminant of these niche-encoded enums. Rather, the discriminant is just getting passed to an equality check (for something like `matches!(foo, TerminatorKind::Goto { .. })`) or a `SwitchInt` (when it's being matched on).

So while the old code would emit, roughly,

```rust
if is_niche { tag + ADJUSTMENT } else { UNTAGGED_DISCR }
```

this PR changes it instead to

```rust
(if is_niche { tag } else { UNTAGGED_ADJ_DISCR }) + ADJUSTMENT
```

which on its own might seem odd, but it's actually easier to optimize, because what we're really computing is

```rust
complicated_stuff() + ADJUSTMENT == 4
```

or

```rust
match complicated_stuff() + ADJUSTMENT {
    0 => …, 1 => …, 2 => …,
    _ => unreachable,
}
```

or, in the generated `PartialEq` for enums with fieldless variants,

```rust
complicated_stuff(a) + ADJUSTMENT == complicated_stuff(b) + ADJUSTMENT
```

and thus it's easy for the optimizer to eliminate the additions:

```rust
complicated_stuff() == 2
```

```rust
match complicated_stuff() {
    7 => …, 8 => …, 9 => …,
    _ => unreachable,
}
```

```rust
complicated_stuff(a) == complicated_stuff(b)
```

For good measure I went and made sure that cranelift can do this optimization too 🙂 bytecodealliance/wasmtime#10489

r? WaffleLapkin

Follow-up to rust-lang#139098

EDIT later: I happened to notice rust-lang#110197 (comment) -- it looks like there used to be some optimizations in this code, but they got removed for being wrong. I've added lots of tests here; let's hope I can avoid that fate 😬 (Certainly it would be possible to save some complexity by restricting this to the easy case, where it's unsigned-nowrap, the niches are after the natural payload, and all the variant indexes are small.)
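The old-vs-new formulations can be modeled on plain integers. A minimal sketch under an assumed layout (a 3-variant enum like `enum E { A(bool), B, C }`: the untagged variant `A` occupies tags 0..=1, and the niche variants `B` and `C` sit at tags 2..=3, after the end of the `bool`); all constants and names here are illustrative, not rustc's:

```rust
// Hypothetical tag layout: untagged variant A (discr 0) at tags 0..=1,
// niche variants B (discr 1) at tag 2 and C (discr 2) at tag 3.
const NICHE_START: u8 = 2;
const NICHE_END: u8 = 3;
const UNTAGGED_DISCR: u8 = 0;
const FIRST_NICHE_DISCR: u8 = 1;
// For niche tags, discr = tag + ADJUSTMENT.
const ADJUSTMENT: i8 = FIRST_NICHE_DISCR as i8 - NICHE_START as i8; // -1

/// Old scheme: adjust inside the select.
fn discr_old(tag: u8) -> u8 {
    let is_niche = tag >= NICHE_START && tag <= NICHE_END;
    if is_niche { (tag as i8 + ADJUSTMENT) as u8 } else { UNTAGGED_DISCR }
}

/// New scheme: select first, then do the single adjustment last. Since
/// the tag range doesn't wrap, `is_niche` is one unsigned comparison.
fn discr_new(tag: u8) -> u8 {
    let is_niche = tag >= NICHE_START;
    // The tag-space value that the trailing add maps to UNTAGGED_DISCR.
    let untagged_adj = (UNTAGGED_DISCR as i8 - ADJUSTMENT) as u8;
    let selected = if is_niche { tag } else { untagged_adj };
    (selected as i8 + ADJUSTMENT) as u8
}

fn main() {
    for tag in 0..=NICHE_END {
        assert_eq!(discr_old(tag), discr_new(tag));
    }
    println!("schemes agree");
}
```

Because both forms agree, a later `discr == 2` can fold the trailing `+ ADJUSTMENT` into the constant, which is exactly the `complicated_stuff() == 2` simplification described above.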
2 parents 65fa0ab + 66ddcbf commit 98ab2de

File tree: 6 files changed (+1029 −72 lines)


compiler/rustc_abi/src/lib.rs (+34 −1)
```diff
@@ -43,7 +43,7 @@ use std::fmt;
 #[cfg(feature = "nightly")]
 use std::iter::Step;
 use std::num::{NonZeroUsize, ParseIntError};
-use std::ops::{Add, AddAssign, Mul, RangeInclusive, Sub};
+use std::ops::{Add, AddAssign, Mul, RangeFull, RangeInclusive, Sub};
 use std::str::FromStr;
 
 use bitflags::bitflags;
@@ -1162,12 +1162,45 @@ impl WrappingRange {
     }
 
     /// Returns `true` if `size` completely fills the range.
+    ///
+    /// Note that this is *not* the same as `self == WrappingRange::full(size)`.
+    /// Niche calculations can produce full ranges which are not the canonical one;
+    /// for example `Option<NonZero<u16>>` gets `valid_range: (..=0) | (1..)`.
     #[inline]
     fn is_full_for(&self, size: Size) -> bool {
         let max_value = size.unsigned_int_max();
         debug_assert!(self.start <= max_value && self.end <= max_value);
         self.start == (self.end.wrapping_add(1) & max_value)
     }
+
+    /// Checks whether this range is considered non-wrapping when the values are
+    /// interpreted as *unsigned* numbers of width `size`.
+    ///
+    /// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
+    /// and `Err(..)` if the range is full so it depends how you think about it.
+    #[inline]
+    pub fn no_unsigned_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
+        if self.is_full_for(size) { Err(..) } else { Ok(self.start <= self.end) }
+    }
+
+    /// Checks whether this range is considered non-wrapping when the values are
+    /// interpreted as *signed* numbers of width `size`.
+    ///
+    /// This is heavily dependent on the `size`, as `100..=200` does wrap when
+    /// interpreted as `i8`, but doesn't when interpreted as `i16`.
+    ///
+    /// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
+    /// and `Err(..)` if the range is full so it depends how you think about it.
+    #[inline]
+    pub fn no_signed_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
+        if self.is_full_for(size) {
+            Err(..)
+        } else {
+            let start: i128 = size.sign_extend(self.start);
+            let end: i128 = size.sign_extend(self.end);
+            Ok(start <= end)
+        }
+    }
 }
 
 impl fmt::Debug for WrappingRange {
```
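The two new predicates can be exercised outside rustc. A standalone sketch specialized to `u8`-sized ranges, using `Err(())` in place of the `RangeFull` error type; the `*_u8` method names are illustrative, not the `rustc_abi` API:

```rust
/// Simplified stand-in for rustc_abi's WrappingRange: the inclusive range
/// start..=end, which "wraps" through MAX when start > end.
#[derive(Clone, Copy)]
struct WrappingRange {
    start: u128,
    end: u128,
}

const U8_MAX: u128 = u8::MAX as u128;

impl WrappingRange {
    /// True if the range covers every u8 value.
    fn is_full_for_u8(self) -> bool {
        self.start == (self.end.wrapping_add(1) & U8_MAX)
    }

    /// `Ok(true)` iff the range, read as unsigned u8 values, doesn't wrap.
    fn no_unsigned_wraparound_u8(self) -> Result<bool, ()> {
        if self.is_full_for_u8() { Err(()) } else { Ok(self.start <= self.end) }
    }

    /// `Ok(true)` iff the range, read as signed i8 values, doesn't wrap.
    fn no_signed_wraparound_u8(self) -> Result<bool, ()> {
        if self.is_full_for_u8() {
            Err(())
        } else {
            Ok((self.start as u8 as i8) <= (self.end as u8 as i8))
        }
    }
}

fn main() {
    // 1..=255 (e.g. `NonZero<u8>`): no unsigned wrap, but it wraps as
    // signed, because 1 <= -1 is false.
    let nz = WrappingRange { start: 1, end: 255 };
    assert_eq!(nz.no_unsigned_wraparound_u8(), Ok(true));
    assert_eq!(nz.no_signed_wraparound_u8(), Ok(false));

    // 254..=10 wraps as unsigned but not as signed (-2 <= 10).
    let wrap = WrappingRange { start: 254, end: 10 };
    assert_eq!(wrap.no_unsigned_wraparound_u8(), Ok(false));
    assert_eq!(wrap.no_signed_wraparound_u8(), Ok(true));

    // 0..=255 is full: the answer depends on interpretation.
    let full = WrappingRange { start: 0, end: 255 };
    assert_eq!(full.no_unsigned_wraparound_u8(), Err(()));
}
```

The `NonZero<u8>` case shows why the codegen change tries the unsigned check first and falls back to the signed one: each check succeeds on ranges the other rejects.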

compiler/rustc_codegen_ssa/src/mir/operand.rs (+181 −41)
```diff
@@ -3,7 +3,9 @@ use std::fmt;
 use arrayvec::ArrayVec;
 use either::Either;
 use rustc_abi as abi;
-use rustc_abi::{Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, Variants};
+use rustc_abi::{
+    Align, BackendRepr, FIRST_VARIANT, Primitive, Size, TagEncoding, VariantIdx, Variants,
+};
 use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
 use rustc_middle::mir::{self, ConstValue};
 use rustc_middle::ty::Ty;
@@ -510,6 +512,8 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
         );
 
         let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
+        let tag_range = tag_scalar.valid_range(&dl);
+        let tag_size = tag_scalar.size(&dl);
 
         // We have a subrange `niche_start..=niche_end` inside `range`.
         // If the value of the tag is inside this subrange, it's a
@@ -525,53 +529,189 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
         //     untagged_variant
         // }
         // However, we will likely be able to emit simpler code.
-        let (is_niche, tagged_discr, delta) = if relative_max == 0 {
-            // Best case scenario: only one tagged variant. This will
-            // likely become just a comparison and a jump.
-            // The algorithm is:
-            // is_niche = tag == niche_start
-            // discr = if is_niche {
-            //     niche_start
-            // } else {
-            //     untagged_variant
-            // }
+
+        // First, the incredibly-common case of a two-variant enum (like
+        // `Option` or `Result`) where we only need one check.
+        if relative_max == 0 {
             let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
-            let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
-            let tagged_discr =
-                bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
-            (is_niche, tagged_discr, 0)
-        } else {
-            // The special cases don't apply, so we'll have to go with
-            // the general algorithm.
-            let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
-            let cast_tag = bx.intcast(relative_discr, cast_to, false);
-            let is_niche = bx.icmp(
-                IntPredicate::IntULE,
-                relative_discr,
-                bx.cx().const_uint(tag_llty, relative_max as u64),
-            );
-
-            // Thanks to parameter attributes and load metadata, LLVM already knows
-            // the general valid range of the tag. It's possible, though, for there
-            // to be an impossible value *in the middle*, which those ranges don't
-            // communicate, so it's worth an `assume` to let the optimizer know.
-            if niche_variants.contains(&untagged_variant)
-                && bx.cx().sess().opts.optimize != OptLevel::No
+            let is_natural = bx.icmp(IntPredicate::IntNE, tag, niche_start);
+            return if untagged_variant == VariantIdx::from_u32(1)
+                && *niche_variants.start() == VariantIdx::from_u32(0)
             {
-                let impossible =
-                    u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
-                let impossible = bx.cx().const_uint(tag_llty, impossible);
-                let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible);
-                bx.assume(ne);
+                // The polarity of the comparison above is picked so we can
+                // just extend for `Option<T>`, which has these variants.
+                bx.zext(is_natural, cast_to)
+            } else {
+                let tagged_discr =
+                    bx.cx().const_uint(cast_to, u64::from(niche_variants.start().as_u32()));
+                let untagged_discr =
+                    bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));
+                bx.select(is_natural, untagged_discr, tagged_discr)
+            };
+        }
+
+        let niche_end =
+            tag_size.truncate(u128::from(relative_max).wrapping_add(niche_start));
+
+        // Next, the layout algorithm prefers to put the niches at one end,
+        // so look for cases where we don't need to calculate a relative_tag
+        // at all and can just look at the original tag value directly.
+        // This also lets us move any possibly-wrapping addition to the end
+        // where it's easiest to get rid of in the normal uses: it's easy
+        // to optimize `COMPLICATED + 2 == 7` to `COMPLICATED == (7 - 2)`.
+        {
+            // Work in whichever size is wider, because it's possible for
+            // the untagged variant to be further away from the niches than
+            // is possible to represent in the smaller type.
+            let (wide_size, wide_ibty) = if cast_to_layout.size > tag_size {
+                (cast_to_layout.size, cast_to)
+            } else {
+                (tag_size, tag_llty)
+            };
+
+            struct NoWrapData<V> {
+                wide_tag: V,
+                is_niche: V,
+                needs_assume: bool,
+                wide_niche_to_variant: u128,
+                wide_niche_untagged: u128,
             }
 
-            (is_niche, cast_tag, niche_variants.start().as_u32() as u128)
-        };
+            let first_variant = u128::from(niche_variants.start().as_u32());
+            let untagged_variant = u128::from(untagged_variant.as_u32());
+
+            let opt_data = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) {
+                let wide_tag = bx.zext(tag, wide_ibty);
+                let extend = |x| x;
+                let wide_niche_start = extend(niche_start);
+                let wide_niche_end = extend(niche_end);
+                debug_assert!(wide_niche_start <= wide_niche_end);
+                let wide_first_variant = extend(first_variant);
+                let wide_untagged_variant = extend(untagged_variant);
+                let wide_niche_to_variant =
+                    wide_first_variant.wrapping_sub(wide_niche_start);
+                let wide_niche_untagged = wide_size
+                    .truncate(wide_untagged_variant.wrapping_sub(wide_niche_to_variant));
+                let (is_niche, needs_assume) = if tag_range.start == niche_start {
+                    let end = bx.cx().const_uint_big(tag_llty, niche_end);
+                    (
+                        bx.icmp(IntPredicate::IntULE, tag, end),
+                        wide_niche_untagged <= wide_niche_end,
+                    )
+                } else if tag_range.end == niche_end {
+                    let start = bx.cx().const_uint_big(tag_llty, niche_start);
+                    (
+                        bx.icmp(IntPredicate::IntUGE, tag, start),
+                        wide_niche_untagged >= wide_niche_start,
+                    )
+                } else {
+                    bug!()
+                };
+                Some(NoWrapData {
+                    wide_tag,
+                    is_niche,
+                    needs_assume,
+                    wide_niche_to_variant,
+                    wide_niche_untagged,
+                })
+            } else if tag_range.no_signed_wraparound(tag_size) == Ok(true) {
+                let wide_tag = bx.sext(tag, wide_ibty);
+                let extend = |x| tag_size.sign_extend(x);
+                let wide_niche_start = extend(niche_start);
+                let wide_niche_end = extend(niche_end);
+                debug_assert!(wide_niche_start <= wide_niche_end);
+                let wide_first_variant = extend(first_variant);
+                let wide_untagged_variant = extend(untagged_variant);
+                let wide_niche_to_variant =
+                    wide_first_variant.wrapping_sub(wide_niche_start);
+                let wide_niche_untagged = wide_size.sign_extend(
+                    wide_untagged_variant
+                        .wrapping_sub(wide_niche_to_variant)
+                        .cast_unsigned(),
+                );
+                let (is_niche, needs_assume) = if tag_range.start == niche_start {
+                    let end = bx.cx().const_uint_big(tag_llty, niche_end);
+                    (
+                        bx.icmp(IntPredicate::IntSLE, tag, end),
+                        wide_niche_untagged <= wide_niche_end,
+                    )
+                } else if tag_range.end == niche_end {
+                    let start = bx.cx().const_uint_big(tag_llty, niche_start);
+                    (
+                        bx.icmp(IntPredicate::IntSGE, tag, start),
+                        wide_niche_untagged >= wide_niche_start,
+                    )
+                } else {
+                    bug!()
+                };
+                Some(NoWrapData {
+                    wide_tag,
+                    is_niche,
+                    needs_assume,
+                    wide_niche_to_variant: wide_niche_to_variant.cast_unsigned(),
+                    wide_niche_untagged: wide_niche_untagged.cast_unsigned(),
+                })
+            } else {
+                None
+            };
+            if let Some(NoWrapData {
+                wide_tag,
+                is_niche,
+                needs_assume,
+                wide_niche_to_variant,
+                wide_niche_untagged,
+            }) = opt_data
+            {
+                let wide_niche_untagged =
+                    bx.cx().const_uint_big(wide_ibty, wide_niche_untagged);
+                if needs_assume && bx.cx().sess().opts.optimize != OptLevel::No {
+                    let not_untagged =
+                        bx.icmp(IntPredicate::IntNE, wide_tag, wide_niche_untagged);
+                    bx.assume(not_untagged);
+                }
+
+                let wide_niche = bx.select(is_niche, wide_tag, wide_niche_untagged);
+                let cast_niche = bx.trunc(wide_niche, cast_to);
+                let discr = if wide_niche_to_variant == 0 {
+                    cast_niche
+                } else {
+                    let niche_to_variant =
+                        bx.cx().const_uint_big(cast_to, wide_niche_to_variant);
+                    bx.add(cast_niche, niche_to_variant)
+                };
+                return discr;
+            }
+        }
+
+        // Otherwise the special cases don't apply,
+        // so we'll have to go with the general algorithm.
+        let relative_tag = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
+        let relative_discr = bx.intcast(relative_tag, cast_to, false);
+        let is_niche = bx.icmp(
+            IntPredicate::IntULE,
+            relative_tag,
+            bx.cx().const_uint(tag_llty, u64::from(relative_max)),
+        );
+
+        // Thanks to parameter attributes and load metadata, LLVM already knows
+        // the general valid range of the tag. It's possible, though, for there
+        // to be an impossible value *in the middle*, which those ranges don't
+        // communicate, so it's worth an `assume` to let the optimizer know.
+        if niche_variants.contains(&untagged_variant)
+            && bx.cx().sess().opts.optimize != OptLevel::No
+        {
+            let impossible =
+                u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
+            let impossible = bx.cx().const_uint(tag_llty, impossible);
+            let ne = bx.icmp(IntPredicate::IntNE, relative_tag, impossible);
+            bx.assume(ne);
+        }
 
+        let delta = niche_variants.start().as_u32();
         let tagged_discr = if delta == 0 {
-            tagged_discr
+            relative_discr
         } else {
-            bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
+            bx.add(relative_discr, bx.cx().const_uint(cast_to, u64::from(delta)))
         };
 
         let discr = bx.select(
```
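The non-wrapping fast path in this diff can be traced on concrete numbers. A sketch under an assumed layout (untagged variant 0 wrapping a `bool` at tags 0..=1, niche variants 1..=3 at tags 2..=4, so the tag range 0..=4 is unsigned-nowrap with the niches at the top); every constant here is made up for illustration and mirrors the diff's naming, not rustc's actual values:

```rust
// Tag layout: variant 0 (untagged, carries a bool) uses tags 0..=1;
// niche variants 1..=3 use tags 2..=4.
const NICHE_START: i64 = 2;
const FIRST_VARIANT: i64 = 1; // discriminant of the first niche variant
const UNTAGGED_VARIANT: i64 = 0;
// For niche tags, discr = tag + NICHE_TO_VARIANT.
const NICHE_TO_VARIANT: i64 = FIRST_VARIANT - NICHE_START; // -1
// Tag-space value that the final add maps onto the untagged discriminant
// (the diff's wide_niche_untagged).
const NICHE_UNTAGGED: i64 = UNTAGGED_VARIANT - NICHE_TO_VARIANT; // 1

fn get_discr(tag: i64) -> i64 {
    // One comparison suffices: the range doesn't wrap and the niches sit
    // at its top (the diff's `tag_range.end == niche_end` / IntUGE case).
    let is_niche = tag >= NICHE_START;
    // Select first (the diff's bx.select on wide values)...
    let niche = if is_niche { tag } else { NICHE_UNTAGGED };
    // ...then do the single adjustment last, where a following
    // `== constant` or `match` can fold it away.
    // No `assume` is needed here: NICHE_UNTAGGED (1) < NICHE_START (2),
    // so it can't collide with a real niche tag.
    niche + NICHE_TO_VARIANT
}

fn main() {
    assert_eq!(get_discr(0), 0); // untagged, payload false
    assert_eq!(get_discr(1), 0); // untagged, payload true
    assert_eq!(get_discr(2), 1); // first niche variant
    assert_eq!(get_discr(4), 3); // last niche variant
}
```

Compared with the general algorithm at the bottom of the diff, no `relative_tag` subtraction happens before the comparison; the only arithmetic is the trailing add, which downstream equality checks and switches can absorb.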
