Skip to content

Commit b57b2bb

Browse files
committed
Move constants into the SwitchInt to be easier to see
1 parent a62e9f8 commit b57b2bb

10 files changed

+366
-340
lines changed

compiler/rustc_mir_transform/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ pub mod simplify;
110110
mod simplify_branches;
111111
mod simplify_comparison_integral;
112112
mod sroa;
113+
mod switch_const;
113114
mod uninhabited_enum_branching;
114115
mod unreachable_prop;
115116

@@ -600,6 +601,8 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
600601
&simplify::SimplifyLocals::AfterGVN,
601602
&dataflow_const_prop::DataflowConstProp,
602603
&const_debuginfo::ConstDebugInfo,
604+
// GVN & ConstProp often don't fixup unevaluatable constants
605+
&switch_const::SwitchConst,
603606
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
604607
&jump_threading::JumpThreading,
605608
&early_otherwise_branch::EarlyOtherwiseBranch,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//! A pass that makes `SwitchInt`-on-`const` more obvious to later code.
2+
3+
use rustc_middle::mir::*;
4+
use rustc_middle::ty::TyCtxt;
5+
6+
/// A `MirPass` for simplifying `if T::CONST`.
7+
///
8+
/// Today, MIR building for things like `if T::IS_ZST` introduce a constant
9+
/// for the copy of the bool, so it ends up in MIR as
10+
/// `_1 = CONST; switchInt (move _1)` or `_2 = CONST; switchInt (_2)`.
11+
///
12+
/// This pass is very specifically targeted at *exactly* those patterns.
13+
/// It can absolutely be replaced with a more general pass should we get one that
14+
/// we can run in low optimization levels, but at the time of writing even in
15+
/// optimized builds this wasn't simplified.
16+
#[derive(Default)]
17+
pub struct SwitchConst;
18+
19+
impl<'tcx> MirPass<'tcx> for SwitchConst {
20+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
21+
for block in body.basic_blocks.as_mut_preserves_cfg() {
22+
let switch_local = if let TerminatorKind::SwitchInt { discr, .. } =
23+
&block.terminator().kind
24+
&& let Some(place) = discr.place()
25+
&& let Some(local) = place.as_local()
26+
{
27+
local
28+
} else {
29+
continue;
30+
};
31+
32+
let new_operand = if let Some(statement) = block.statements.last()
33+
&& let StatementKind::Assign(place_and_rvalue) = &statement.kind
34+
&& let Some(local) = place_and_rvalue.0.as_local()
35+
&& local == switch_local
36+
&& let Rvalue::Use(operand) = &place_and_rvalue.1
37+
&& let Operand::Constant(_) = operand
38+
{
39+
operand.clone()
40+
} else {
41+
continue;
42+
};
43+
44+
if !tcx.consider_optimizing(|| format!("SwitchConst: switchInt(move {switch_local:?}"))
45+
{
46+
break;
47+
}
48+
49+
let TerminatorKind::SwitchInt { discr, .. } = &mut block.terminator_mut().kind else {
50+
bug!("Somehow wasn't a switchInt any more?")
51+
};
52+
*discr = new_operand;
53+
}
54+
}
55+
}

tests/mir-opt/pre-codegen/if_associated_const.check_bool.PreCodegen.after.mir

+1-5
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,9 @@
22

33
fn check_bool() -> u32 {
44
let mut _0: u32;
5-
let mut _1: bool;
65

76
bb0: {
8-
StorageLive(_1);
9-
_1 = const <T as TraitWithBool>::FLAG;
10-
switchInt(move _1) -> [0: bb1, otherwise: bb2];
7+
switchInt(const <T as TraitWithBool>::FLAG) -> [0: bb1, otherwise: bb2];
118
}
129

1310
bb1: {
@@ -21,7 +18,6 @@ fn check_bool() -> u32 {
2118
}
2219

2320
bb3: {
24-
StorageDead(_1);
2521
return;
2622
}
2723
}

tests/mir-opt/pre-codegen/if_associated_const.check_int.PreCodegen.after.mir

+1-5
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,9 @@
22

33
fn check_int() -> u32 {
44
let mut _0: u32;
5-
let mut _1: i32;
65

76
bb0: {
8-
StorageLive(_1);
9-
_1 = const <T as TraitWithInt>::VALUE;
10-
switchInt(_1) -> [1: bb1, 2: bb2, 3: bb3, otherwise: bb4];
7+
switchInt(const <T as TraitWithInt>::VALUE) -> [1: bb1, 2: bb2, 3: bb3, otherwise: bb4];
118
}
129

1310
bb1: {
@@ -31,7 +28,6 @@ fn check_int() -> u32 {
3128
}
3229

3330
bb5: {
34-
StorageDead(_1);
3531
return;
3632
}
3733
}

tests/mir-opt/pre-codegen/slice_iter.enumerated_loop.PreCodegen.after.panic-abort.mir

+54-58
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,22 @@ fn enumerated_loop(_1: &[T], _2: impl Fn(usize, &T)) -> () {
44
debug slice => _1;
55
debug f => _2;
66
let mut _0: ();
7-
let mut _13: std::slice::Iter<'_, T>;
7+
let mut _12: std::slice::Iter<'_, T>;
8+
let mut _13: std::iter::Enumerate<std::slice::Iter<'_, T>>;
89
let mut _14: std::iter::Enumerate<std::slice::Iter<'_, T>>;
9-
let mut _15: std::iter::Enumerate<std::slice::Iter<'_, T>>;
10-
let mut _16: &mut std::iter::Enumerate<std::slice::Iter<'_, T>>;
11-
let mut _17: std::option::Option<(usize, &T)>;
12-
let mut _18: isize;
13-
let mut _21: &impl Fn(usize, &T);
14-
let mut _22: (usize, &T);
15-
let _23: ();
10+
let mut _15: &mut std::iter::Enumerate<std::slice::Iter<'_, T>>;
11+
let mut _16: std::option::Option<(usize, &T)>;
12+
let mut _17: isize;
13+
let mut _20: &impl Fn(usize, &T);
14+
let mut _21: (usize, &T);
15+
let _22: ();
1616
scope 1 {
17-
debug iter => _15;
18-
let _19: usize;
19-
let _20: &T;
17+
debug iter => _14;
18+
let _18: usize;
19+
let _19: &T;
2020
scope 2 {
21-
debug i => _19;
22-
debug x => _20;
21+
debug i => _18;
22+
debug x => _19;
2323
}
2424
}
2525
scope 3 (inlined core::slice::<impl [T]>::iter) {
@@ -28,19 +28,18 @@ fn enumerated_loop(_1: &[T], _2: impl Fn(usize, &T)) -> () {
2828
debug slice => _1;
2929
let _3: usize;
3030
let mut _5: std::ptr::NonNull<[T]>;
31-
let mut _8: bool;
31+
let mut _8: *mut T;
3232
let mut _9: *mut T;
33-
let mut _10: *mut T;
34-
let mut _12: *const T;
33+
let mut _11: *const T;
3534
scope 5 {
3635
debug len => _3;
3736
let _7: std::ptr::NonNull<T>;
3837
scope 6 {
3938
debug ptr => _7;
4039
scope 7 {
41-
let _11: *const T;
40+
let _10: *const T;
4241
scope 8 {
43-
debug end_or_len => _11;
42+
debug end_or_len => _10;
4443
}
4544
scope 14 (inlined without_provenance::<T>) {
4645
debug addr => _3;
@@ -51,7 +50,7 @@ fn enumerated_loop(_1: &[T], _2: impl Fn(usize, &T)) -> () {
5150
debug self => _7;
5251
}
5352
scope 17 (inlined std::ptr::mut_ptr::<impl *mut T>::add) {
54-
debug self => _9;
53+
debug self => _8;
5554
debug count => _3;
5655
scope 18 {
5756
}
@@ -77,17 +76,17 @@ fn enumerated_loop(_1: &[T], _2: impl Fn(usize, &T)) -> () {
7776
}
7877
}
7978
scope 19 (inlined <std::slice::Iter<'_, T> as Iterator>::enumerate) {
80-
debug self => _13;
79+
debug self => _12;
8180
scope 20 (inlined Enumerate::<std::slice::Iter<'_, T>>::new) {
82-
debug iter => _13;
81+
debug iter => _12;
8382
}
8483
}
8584
scope 21 (inlined <Enumerate<std::slice::Iter<'_, T>> as IntoIterator>::into_iter) {
86-
debug self => _14;
85+
debug self => _13;
8786
}
8887

8988
bb0: {
90-
StorageLive(_13);
89+
StorageLive(_12);
9190
StorageLive(_3);
9291
StorageLive(_7);
9392
StorageLive(_4);
@@ -99,62 +98,59 @@ fn enumerated_loop(_1: &[T], _2: impl Fn(usize, &T)) -> () {
9998
_6 = _4 as *const T (PtrToPtr);
10099
_7 = NonNull::<T> { pointer: _6 };
101100
StorageDead(_5);
102-
StorageLive(_11);
103-
StorageLive(_8);
104-
_8 = const <T as std::mem::SizedTypeProperties>::IS_ZST;
105-
switchInt(move _8) -> [0: bb1, otherwise: bb2];
101+
StorageLive(_10);
102+
switchInt(const <T as std::mem::SizedTypeProperties>::IS_ZST) -> [0: bb1, otherwise: bb2];
106103
}
107104

108105
bb1: {
109-
StorageLive(_10);
110106
StorageLive(_9);
111-
_9 = _4 as *mut T (PtrToPtr);
112-
_10 = Offset(_9, _3);
107+
StorageLive(_8);
108+
_8 = _4 as *mut T (PtrToPtr);
109+
_9 = Offset(_8, _3);
110+
StorageDead(_8);
111+
_10 = move _9 as *const T (PointerCoercion(MutToConstPointer));
113112
StorageDead(_9);
114-
_11 = move _10 as *const T (PointerCoercion(MutToConstPointer));
115-
StorageDead(_10);
116113
goto -> bb3;
117114
}
118115

119116
bb2: {
120-
_11 = _3 as *const T (Transmute);
117+
_10 = _3 as *const T (Transmute);
121118
goto -> bb3;
122119
}
123120

124121
bb3: {
125-
StorageDead(_8);
126-
StorageLive(_12);
127-
_12 = _11;
128-
_13 = std::slice::Iter::<'_, T> { ptr: _7, end_or_len: move _12, _marker: const ZeroSized: PhantomData<&T> };
129-
StorageDead(_12);
122+
StorageLive(_11);
123+
_11 = _10;
124+
_12 = std::slice::Iter::<'_, T> { ptr: _7, end_or_len: move _11, _marker: const ZeroSized: PhantomData<&T> };
130125
StorageDead(_11);
126+
StorageDead(_10);
131127
StorageDead(_6);
132128
StorageDead(_4);
133129
StorageDead(_7);
134130
StorageDead(_3);
135-
_14 = Enumerate::<std::slice::Iter<'_, T>> { iter: _13, count: const 0_usize };
136-
StorageDead(_13);
137-
StorageLive(_15);
138-
_15 = _14;
131+
_13 = Enumerate::<std::slice::Iter<'_, T>> { iter: _12, count: const 0_usize };
132+
StorageDead(_12);
133+
StorageLive(_14);
134+
_14 = _13;
139135
goto -> bb4;
140136
}
141137

142138
bb4: {
143-
StorageLive(_17);
144139
StorageLive(_16);
145-
_16 = &mut _15;
146-
_17 = <Enumerate<std::slice::Iter<'_, T>> as Iterator>::next(move _16) -> [return: bb5, unwind unreachable];
140+
StorageLive(_15);
141+
_15 = &mut _14;
142+
_16 = <Enumerate<std::slice::Iter<'_, T>> as Iterator>::next(move _15) -> [return: bb5, unwind unreachable];
147143
}
148144

149145
bb5: {
150-
StorageDead(_16);
151-
_18 = discriminant(_17);
152-
switchInt(move _18) -> [0: bb6, 1: bb8, otherwise: bb10];
146+
StorageDead(_15);
147+
_17 = discriminant(_16);
148+
switchInt(move _17) -> [0: bb6, 1: bb8, otherwise: bb10];
153149
}
154150

155151
bb6: {
156-
StorageDead(_17);
157-
StorageDead(_15);
152+
StorageDead(_16);
153+
StorageDead(_14);
158154
drop(_2) -> [return: bb7, unwind unreachable];
159155
}
160156

@@ -163,19 +159,19 @@ fn enumerated_loop(_1: &[T], _2: impl Fn(usize, &T)) -> () {
163159
}
164160

165161
bb8: {
166-
_19 = (((_17 as Some).0: (usize, &T)).0: usize);
167-
_20 = (((_17 as Some).0: (usize, &T)).1: &T);
162+
_18 = (((_16 as Some).0: (usize, &T)).0: usize);
163+
_19 = (((_16 as Some).0: (usize, &T)).1: &T);
164+
StorageLive(_20);
165+
_20 = &_2;
168166
StorageLive(_21);
169-
_21 = &_2;
170-
StorageLive(_22);
171-
_22 = (_19, _20);
172-
_23 = <impl Fn(usize, &T) as Fn<(usize, &T)>>::call(move _21, move _22) -> [return: bb9, unwind unreachable];
167+
_21 = (_18, _19);
168+
_22 = <impl Fn(usize, &T) as Fn<(usize, &T)>>::call(move _20, move _21) -> [return: bb9, unwind unreachable];
173169
}
174170

175171
bb9: {
176-
StorageDead(_22);
177172
StorageDead(_21);
178-
StorageDead(_17);
173+
StorageDead(_20);
174+
StorageDead(_16);
179175
goto -> bb4;
180176
}
181177

0 commit comments

Comments
 (0)