Skip to content

Commit 2e0edc0

Browse files
committed
Auto merge of #75119 - simonvandel:early-otherwise, r=oli-obk
New MIR optimization pass to reduce branches on match of tuples of enums Fixes #68867 by adding a new pass that turns something like ```rust let x: Option<()>; let y: Option<()>; match (x,y) { (Some(_), Some(_)) => {0}, _ => {1} } ``` into something like ```rust let x: Option<()>; let y: Option<()>; let discriminant_x = // get discriminant of x let discriminant_y = // get discriminant of x if discriminant_x != discriminant_y {1} else {0} ``` The opt-diffs still have the old basic blocks like ``` bb3: { _8 = discriminant((*(_4.1: &ViewportPercentageLength))); // scope 0 at $DIR/early-otherwise-branch-68867.rs:21:21: 21:30 switchInt(move _8) -> [1_isize: bb7, otherwise: bb2]; // scope 0 at $DIR/early-otherwise-branch-68867.rs:21:21: 21:30 } bb4: { _9 = discriminant((*(_4.1: &ViewportPercentageLength))); // scope 0 at $DIR/early-otherwise-branch-68867.rs:22:23: 22:34 switchInt(move _9) -> [2_isize: bb8, otherwise: bb2]; // scope 0 at $DIR/early-otherwise-branch-68867.rs:22:23: 22:34 } bb5: { _10 = discriminant((*(_4.1: &ViewportPercentageLength))); // scope 0 at $DIR/early-otherwise-branch-68867.rs:23:23: 23:34 switchInt(move _10) -> [3_isize: bb9, otherwise: bb2]; // scope 0 at $DIR/early-otherwise-branch-68867.rs:23:23: 23:34 } ``` These do get removed on later passes. I'm not sure if I should include those passes in the test to make it clear?
2 parents 81e0270 + 0363694 commit 2e0edc0

13 files changed

+1387
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
use crate::{
2+
transform::{MirPass, MirSource},
3+
util::patch::MirPatch,
4+
};
5+
use rustc_middle::mir::*;
6+
use rustc_middle::ty::{Ty, TyCtxt};
7+
use std::{borrow::Cow, fmt::Debug};
8+
9+
use super::simplify::simplify_cfg;
10+
11+
/// This pass optimizes something like
12+
/// ```text
13+
/// let x: Option<()>;
14+
/// let y: Option<()>;
15+
/// match (x,y) {
16+
/// (Some(_), Some(_)) => {0},
17+
/// _ => {1}
18+
/// }
19+
/// ```
20+
/// into something like
21+
/// ```text
22+
/// let x: Option<()>;
23+
/// let y: Option<()>;
24+
/// let discriminant_x = // get discriminant of x
25+
/// let discriminant_y = // get discriminant of y
26+
/// if discriminant_x != discriminant_y || discriminant_x == None {1} else {0}
27+
/// ```
28+
pub struct EarlyOtherwiseBranch;
29+
30+
impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
31+
fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
32+
if tcx.sess.opts.debugging_opts.mir_opt_level < 1 {
33+
return;
34+
}
35+
trace!("running EarlyOtherwiseBranch on {:?}", source);
36+
// we are only interested in this bb if the terminator is a switchInt
37+
let bbs_with_switch =
38+
body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator()));
39+
40+
let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch
41+
.flat_map(|(bb_idx, bb)| {
42+
let switch = bb.terminator();
43+
let helper = Helper { body, tcx };
44+
let infos = helper.go(bb, switch)?;
45+
Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx })
46+
})
47+
.collect();
48+
49+
let should_cleanup = !opts_to_apply.is_empty();
50+
51+
for opt_to_apply in opts_to_apply {
52+
trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply);
53+
54+
let statements_before =
55+
body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len();
56+
let end_of_block_location = Location {
57+
block: opt_to_apply.basic_block_first_switch,
58+
statement_index: statements_before,
59+
};
60+
61+
let mut patch = MirPatch::new(body);
62+
63+
// create temp to store second discriminant in
64+
let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty;
65+
let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span;
66+
let second_discriminant_temp = patch.new_temp(discr_type, discr_span);
67+
68+
patch.add_statement(
69+
end_of_block_location,
70+
StatementKind::StorageLive(second_discriminant_temp),
71+
);
72+
73+
// create assignment of discriminant
74+
let place_of_adt_to_get_discriminant_of =
75+
opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read;
76+
patch.add_assign(
77+
end_of_block_location,
78+
Place::from(second_discriminant_temp),
79+
Rvalue::Discriminant(place_of_adt_to_get_discriminant_of),
80+
);
81+
82+
// create temp to store NotEqual comparison between the two discriminants
83+
let not_equal = BinOp::Ne;
84+
let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type);
85+
let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span);
86+
patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp));
87+
88+
// create NotEqual comparison between the two discriminants
89+
let first_descriminant_place =
90+
opt_to_apply.infos[0].first_switch_info.discr_used_in_switch;
91+
let not_equal_rvalue = Rvalue::BinaryOp(
92+
not_equal,
93+
Operand::Copy(Place::from(second_discriminant_temp)),
94+
Operand::Copy(Place::from(first_descriminant_place)),
95+
);
96+
patch.add_statement(
97+
end_of_block_location,
98+
StatementKind::Assign(box (Place::from(not_equal_temp), not_equal_rvalue)),
99+
);
100+
101+
let (mut targets_to_jump_to, values_to_jump_to): (Vec<_>, Vec<_>) = opt_to_apply
102+
.infos
103+
.iter()
104+
.flat_map(|x| x.second_switch_info.targets_with_values.iter())
105+
.cloned()
106+
.unzip();
107+
108+
// add otherwise case in the end
109+
targets_to_jump_to.push(opt_to_apply.infos[0].first_switch_info.otherwise_bb);
110+
// new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal
111+
let new_switch_data = BasicBlockData::new(Some(Terminator {
112+
source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info,
113+
kind: TerminatorKind::SwitchInt {
114+
// the first and second discriminants are equal, so just pick one
115+
discr: Operand::Copy(first_descriminant_place),
116+
switch_ty: discr_type,
117+
values: Cow::from(values_to_jump_to),
118+
targets: targets_to_jump_to,
119+
},
120+
}));
121+
122+
let new_switch_bb = patch.new_block(new_switch_data);
123+
124+
// switch on the NotEqual. If true, then jump to the `otherwise` case.
125+
// If false, then jump to a basic block that then jumps to the correct disciminant case
126+
let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb;
127+
let false_case = new_switch_bb;
128+
patch.patch_terminator(
129+
opt_to_apply.basic_block_first_switch,
130+
TerminatorKind::if_(
131+
tcx,
132+
Operand::Move(Place::from(not_equal_temp)),
133+
true_case,
134+
false_case,
135+
),
136+
);
137+
138+
// generate StorageDead for the second_discriminant_temp not in use anymore
139+
patch.add_statement(
140+
end_of_block_location,
141+
StatementKind::StorageDead(second_discriminant_temp),
142+
);
143+
144+
// Generate a StorageDead for not_equal_temp in each of the targets, since we moved it into the switch
145+
for bb in [false_case, true_case].iter() {
146+
patch.add_statement(
147+
Location { block: *bb, statement_index: 0 },
148+
StatementKind::StorageDead(not_equal_temp),
149+
);
150+
}
151+
152+
patch.apply(body);
153+
}
154+
155+
// Since this optimization adds new basic blocks and invalidates others,
156+
// clean up the cfg to make it nicer for other passes
157+
if should_cleanup {
158+
simplify_cfg(body);
159+
}
160+
}
161+
}
162+
163+
fn is_switch<'tcx>(terminator: &Terminator<'tcx>) -> bool {
164+
match terminator.kind {
165+
TerminatorKind::SwitchInt { .. } => true,
166+
_ => false,
167+
}
168+
}
169+
170+
struct Helper<'a, 'tcx> {
171+
body: &'a Body<'tcx>,
172+
tcx: TyCtxt<'tcx>,
173+
}
174+
175+
#[derive(Debug, Clone)]
176+
struct SwitchDiscriminantInfo<'tcx> {
177+
/// Type of the discriminant being switched on
178+
discr_ty: Ty<'tcx>,
179+
/// The basic block that the otherwise branch points to
180+
otherwise_bb: BasicBlock,
181+
/// Target along with the value being branched from. Otherwise is not included
182+
targets_with_values: Vec<(BasicBlock, u128)>,
183+
discr_source_info: SourceInfo,
184+
/// The place of the discriminant used in the switch
185+
discr_used_in_switch: Place<'tcx>,
186+
/// The place of the adt that has its discriminant read
187+
place_of_adt_discr_read: Place<'tcx>,
188+
/// The type of the adt that has its discriminant read
189+
type_adt_matched_on: Ty<'tcx>,
190+
}
191+
192+
#[derive(Debug)]
193+
struct OptimizationToApply<'tcx> {
194+
infos: Vec<OptimizationInfo<'tcx>>,
195+
/// Basic block of the original first switch
196+
basic_block_first_switch: BasicBlock,
197+
}
198+
199+
#[derive(Debug)]
200+
struct OptimizationInfo<'tcx> {
201+
/// Info about the first switch and discriminant
202+
first_switch_info: SwitchDiscriminantInfo<'tcx>,
203+
/// Info about the second switch and discriminant
204+
second_switch_info: SwitchDiscriminantInfo<'tcx>,
205+
}
206+
207+
impl<'a, 'tcx> Helper<'a, 'tcx> {
208+
pub fn go(
209+
&self,
210+
bb: &BasicBlockData<'tcx>,
211+
switch: &Terminator<'tcx>,
212+
) -> Option<Vec<OptimizationInfo<'tcx>>> {
213+
// try to find the statement that defines the discriminant that is used for the switch
214+
let discr = self.find_switch_discriminant_info(bb, switch)?;
215+
216+
// go through each target, finding a discriminant read, and a switch
217+
let results = discr.targets_with_values.iter().map(|(target, value)| {
218+
self.find_discriminant_switch_pairing(&discr, target.clone(), value.clone())
219+
});
220+
221+
// if the optimization did not apply for one of the targets, then abort
222+
if results.clone().any(|x| x.is_none()) || results.len() == 0 {
223+
trace!("NO: not all of the targets matched the pattern for optimization");
224+
return None;
225+
}
226+
227+
Some(results.flatten().collect())
228+
}
229+
230+
fn find_discriminant_switch_pairing(
231+
&self,
232+
discr_info: &SwitchDiscriminantInfo<'tcx>,
233+
target: BasicBlock,
234+
value: u128,
235+
) -> Option<OptimizationInfo<'tcx>> {
236+
let bb = &self.body.basic_blocks()[target];
237+
// find switch
238+
let terminator = bb.terminator();
239+
if is_switch(terminator) {
240+
let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?;
241+
242+
// the types of the two adts matched on have to be equalfor this optimization to apply
243+
if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on {
244+
trace!(
245+
"NO: types do not match. LHS: {:?}, RHS: {:?}",
246+
discr_info.type_adt_matched_on,
247+
this_bb_discr_info.type_adt_matched_on
248+
);
249+
return None;
250+
}
251+
252+
// the otherwise branch of the two switches have to point to the same bb
253+
if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb {
254+
trace!("NO: otherwise target is not the same");
255+
return None;
256+
}
257+
258+
// check that the value being matched on is the same. The
259+
if this_bb_discr_info.targets_with_values.iter().find(|x| x.1 == value).is_none() {
260+
trace!("NO: values being matched on are not the same");
261+
return None;
262+
}
263+
264+
// only allow optimization if the left and right of the tuple being matched are the same variants.
265+
// so the following should not optimize
266+
// ```rust
267+
// let x: Option<()>;
268+
// let y: Option<()>;
269+
// match (x,y) {
270+
// (Some(_), None) => {},
271+
// _ => {}
272+
// }
273+
// ```
274+
// We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch
275+
if !(this_bb_discr_info.targets_with_values.len() == 1
276+
&& this_bb_discr_info.targets_with_values[0].1 == value)
277+
{
278+
trace!(
279+
"NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here"
280+
);
281+
return None;
282+
}
283+
284+
// if we reach this point, the optimization applies, and we should be able to optimize this case
285+
// store the info that is needed to apply the optimization
286+
287+
Some(OptimizationInfo {
288+
first_switch_info: discr_info.clone(),
289+
second_switch_info: this_bb_discr_info,
290+
})
291+
} else {
292+
None
293+
}
294+
}
295+
296+
fn find_switch_discriminant_info(
297+
&self,
298+
bb: &BasicBlockData<'tcx>,
299+
switch: &Terminator<'tcx>,
300+
) -> Option<SwitchDiscriminantInfo<'tcx>> {
301+
match &switch.kind {
302+
TerminatorKind::SwitchInt { discr, targets, values, .. } => {
303+
let discr_local = discr.place()?.as_local()?;
304+
// the declaration of the discriminant read. Place of this read is being used in the switch
305+
let discr_decl = &self.body.local_decls()[discr_local];
306+
let discr_ty = discr_decl.ty;
307+
// the otherwise target lies as the last element
308+
let otherwise_bb = targets.get(values.len())?.clone();
309+
let targets_with_values = targets
310+
.iter()
311+
.zip(values.iter())
312+
.map(|(t, v)| (t.clone(), v.clone()))
313+
.collect();
314+
315+
// find the place of the adt where the discriminant is being read from
316+
// assume this is the last statement of the block
317+
let place_of_adt_discr_read = match bb.statements.last()?.kind {
318+
StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => {
319+
Some(adt_place)
320+
}
321+
_ => None,
322+
}?;
323+
324+
let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty;
325+
326+
Some(SwitchDiscriminantInfo {
327+
discr_used_in_switch: discr.place()?,
328+
discr_ty,
329+
otherwise_bb,
330+
targets_with_values,
331+
discr_source_info: discr_decl.source_info,
332+
place_of_adt_discr_read,
333+
type_adt_matched_on,
334+
})
335+
}
336+
_ => unreachable!("must only be passed terminator that is a switch"),
337+
}
338+
}
339+
}

compiler/rustc_mir/src/transform/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod copy_prop;
2626
pub mod deaggregator;
2727
pub mod dest_prop;
2828
pub mod dump_mir;
29+
pub mod early_otherwise_branch;
2930
pub mod elaborate_drops;
3031
pub mod generator;
3132
pub mod inline;
@@ -465,6 +466,7 @@ fn run_optimization_passes<'tcx>(
465466
&instcombine::InstCombine,
466467
&const_prop::ConstProp,
467468
&simplify_branches::SimplifyBranches::new("after-const-prop"),
469+
&early_otherwise_branch::EarlyOtherwiseBranch,
468470
&simplify_comparison_integral::SimplifyComparisonIntegral,
469471
&simplify_try::SimplifyArmIdentity,
470472
&simplify_try::SimplifyBranchSame,

0 commit comments

Comments
 (0)