diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 4adb95f85d665..cd57a077283ac 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -363,15 +363,24 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { discr: &mir::Operand<'tcx>, targets: &SwitchTargets, ) { - let discr = self.codegen_operand(bx, discr); - let discr_value = discr.immediate(); - let switch_ty = discr.layout.ty; // If our discriminant is a constant we can branch directly - if let Some(const_discr) = bx.const_to_opt_u128(discr_value, false) { + if let Some(const_op) = discr.constant() { + let const_value = self.eval_mir_constant(const_op); + let Some(const_discr) = const_value.try_to_bits_for_ty( + self.cx.tcx(), + ty::ParamEnv::reveal_all(), + const_op.ty(), + ) else { + bug!("Failed to evaluate constant {discr:?} for SwitchInt terminator") + }; let target = targets.target_for_value(const_discr); bx.br(helper.llbb_with_cleanup(self, target)); return; - }; + } + + let discr = self.codegen_operand(bx, discr); + let discr_value = discr.immediate(); + let switch_ty = discr.layout.ty; let mut target_iter = targets.iter(); if target_iter.len() == 1 { diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index ef88b253864bd..be0fb6796962f 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -711,6 +711,7 @@ impl<'tcx> Body<'tcx> { }; // If this is a SwitchInt(const _), then we can just evaluate the constant and return. + // (The `SwitchConst` transform pass tries to ensure this.) let discr = match discr { Operand::Constant(constant) => { let bits = eval_mono_const(constant)?; @@ -719,24 +720,18 @@ impl<'tcx> Body<'tcx> { Operand::Move(place) | Operand::Copy(place) => place, }; - // MIR for `if false` actually looks like this: - // _1 = const _ - // SwitchInt(_1) - // // And MIR for if intrinsics::ub_checks() looks like this: // _1 = UbChecks() // SwitchInt(_1) // // So we're going to try to recognize this pattern. // - // If we have a SwitchInt on a non-const place, we find the most recent statement that - // isn't a storage marker. If that statement is an assignment of a const to our - // discriminant place, we evaluate and return the const, as if we've const-propagated it - // into the SwitchInt. + // If we have a SwitchInt on a non-const place, we look at the last statement + // in the block. If that statement is an assignment of UbChecks to our + // discriminant place, we evaluate its value, as if we've + // const-propagated it into the SwitchInt. - let last_stmt = block.statements.iter().rev().find(|stmt| { - !matches!(stmt.kind, StatementKind::StorageDead(_) | StatementKind::StorageLive(_)) - })?; + let last_stmt = block.statements.last()?; let (place, rvalue) = last_stmt.kind.as_assign()?; @@ -746,10 +741,6 @@ impl<'tcx> Body<'tcx> { match rvalue { Rvalue::NullaryOp(NullOp::UbChecks, _) => Some((tcx.sess.ub_checks() as u128, targets)), - Rvalue::Use(Operand::Constant(constant)) => { - let bits = eval_mono_const(constant)?; - Some((bits, targets)) - } _ => None, } } diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 5d253d7384df4..8c36e63eb7d06 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -108,6 +108,7 @@ mod simplify_branches; mod simplify_comparison_integral; mod single_use_consts; mod sroa; +mod switch_const; mod unreachable_enum_branching; mod unreachable_prop; mod validate; @@ -598,6 +599,8 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &simplify::SimplifyLocals::AfterGVN, &dataflow_const_prop::DataflowConstProp, &single_use_consts::SingleUseConsts, + // GVN & ConstProp often don't fixup unevaluatable constants + &switch_const::SwitchConst, &o1(simplify_branches::SimplifyConstCondition::AfterConstProp), &jump_threading::JumpThreading, &early_otherwise_branch::EarlyOtherwiseBranch, diff --git a/compiler/rustc_mir_transform/src/switch_const.rs b/compiler/rustc_mir_transform/src/switch_const.rs new file mode 100644 index 0000000000000..4153e4b736dbf --- /dev/null +++ b/compiler/rustc_mir_transform/src/switch_const.rs @@ -0,0 +1,56 @@ +//! A pass that makes `SwitchInt`-on-`const` more obvious to later code. + +use rustc_middle::bug; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +/// A `MirPass` for simplifying `if T::CONST`. +/// +/// Today, MIR building for things like `if T::IS_ZST` introduce a constant +/// for the copy of the bool, so it ends up in MIR as +/// `_1 = CONST; switchInt (move _1)` or `_2 = CONST; switchInt (_2)`. +/// +/// This pass is very specifically targeted at *exactly* those patterns. +/// It can absolutely be replaced with a more general pass should we get one that +/// we can run in low optimization levels, but at the time of writing even in +/// optimized builds this wasn't simplified. +#[derive(Default)] +pub struct SwitchConst; + +impl<'tcx> MirPass<'tcx> for SwitchConst { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + for block in body.basic_blocks.as_mut_preserves_cfg() { + let switch_local = if let TerminatorKind::SwitchInt { discr, .. } = + &block.terminator().kind + && let Some(place) = discr.place() + && let Some(local) = place.as_local() + { + local + } else { + continue; + }; + + let new_operand = if let Some(statement) = block.statements.last() + && let StatementKind::Assign(place_and_rvalue) = &statement.kind + && let Some(local) = place_and_rvalue.0.as_local() + && local == switch_local + && let Rvalue::Use(operand) = &place_and_rvalue.1 + && let Operand::Constant(_) = operand + { + operand.clone() + } else { + continue; + }; + + if !tcx.consider_optimizing(|| format!("SwitchConst: switchInt(move {switch_local:?}")) + { + break; + } + + let TerminatorKind::SwitchInt { discr, .. } = &mut block.terminator_mut().kind else { + bug!("Somehow wasn't a switchInt any more?") + }; + *discr = new_operand; + } + } +} diff --git a/tests/mir-opt/pre-codegen/if_associated_const.check_bool.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/if_associated_const.check_bool.PreCodegen.after.mir new file mode 100644 index 0000000000000..7b0f114d06d2f --- /dev/null +++ b/tests/mir-opt/pre-codegen/if_associated_const.check_bool.PreCodegen.after.mir @@ -0,0 +1,23 @@ +// MIR for `check_bool` after PreCodegen + +fn check_bool() -> u32 { + let mut _0: u32; + + bb0: { + switchInt(const ::FLAG) -> [0: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const 456_u32; + goto -> bb3; + } + + bb2: { + _0 = const 123_u32; + goto -> bb3; + } + + bb3: { + return; + } +} diff --git a/tests/mir-opt/pre-codegen/if_associated_const.check_int.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/if_associated_const.check_int.PreCodegen.after.mir new file mode 100644 index 0000000000000..1d367d1c593f1 --- /dev/null +++ b/tests/mir-opt/pre-codegen/if_associated_const.check_int.PreCodegen.after.mir @@ -0,0 +1,33 @@ +// MIR for `check_int` after PreCodegen + +fn check_int() -> u32 { + let mut _0: u32; + + bb0: { + switchInt(const ::VALUE) -> [1: bb1, 2: bb2, 3: bb3, otherwise: bb4]; + } + + bb1: { + _0 = const 123_u32; + goto -> bb5; + } + + bb2: { + _0 = const 456_u32; + goto -> bb5; + } + + bb3: { + _0 = const 789_u32; + goto -> bb5; + } + + bb4: { + _0 = const 0_u32; + goto -> bb5; + } + + bb5: { + return; + } +} diff --git a/tests/mir-opt/pre-codegen/if_associated_const.rs b/tests/mir-opt/pre-codegen/if_associated_const.rs new file mode 100644 index 0000000000000..86b99f9405f88 --- /dev/null +++ b/tests/mir-opt/pre-codegen/if_associated_const.rs @@ -0,0 +1,27 @@ +// skip-filecheck +//@ compile-flags: -O -Zmir-opt-level=2 -Cdebuginfo=2 + +#![crate_type = "lib"] + +pub trait TraitWithBool { + const FLAG: bool; +} + +// EMIT_MIR if_associated_const.check_bool.PreCodegen.after.mir +pub fn check_bool() -> u32 { + if T::FLAG { 123 } else { 456 } +} + +pub trait TraitWithInt { + const VALUE: i32; +} + +// EMIT_MIR if_associated_const.check_int.PreCodegen.after.mir +pub fn check_int() -> u32 { + match T::VALUE { + 1 => 123, + 2 => 456, + 3 => 789, + _ => 0, + } +} diff --git a/tests/mir-opt/single_use_consts.rs b/tests/mir-opt/single_use_consts.rs index ecb602c647a50..5623d6283df39 100644 --- a/tests/mir-opt/single_use_consts.rs +++ b/tests/mir-opt/single_use_consts.rs @@ -29,9 +29,7 @@ fn match_const() -> &'static str { fn if_const_debug() -> i32 { // CHECK-LABEL: fn if_const_debug( // CHECK: my_bool => const ::ASSOC_BOOL; - // FIXME: `if` forces a temporary (unlike `match`), so the const isn't direct - // CHECK: _3 = const ::ASSOC_BOOL; - // CHECK: switchInt(move _3) + // CHECK: switchInt(const ::ASSOC_BOOL) let my_bool = T::ASSOC_BOOL; do_whatever(); if my_bool { 7 } else { 42 }