Skip to content

Commit 23c586d

Browse files
committed
Transforms match into an assignment statement
1 parent 38050f8 commit 23c586d

9 files changed

+361
-115
lines changed

compiler/rustc_middle/src/mir/terminator.rs

+6
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ impl SwitchTargets {
7474
pub fn target_for_value(&self, value: u128) -> BasicBlock {
7575
self.iter().find_map(|(v, t)| (v == value).then_some(t)).unwrap_or_else(|| self.otherwise())
7676
}
77+
78+
/// Returns true if all targets (including the fallback target) are distinct.
79+
#[inline]
80+
pub fn is_distinct(&self) -> bool {
81+
self.targets.iter().collect::<FxHashSet<_>>().len() == self.targets.len()
82+
}
7783
}
7884

7985
pub struct SwitchTargetsIter<'a> {

compiler/rustc_mir_transform/src/match_branches.rs

+214-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use rustc_index::IndexVec;
22
use rustc_middle::mir::*;
3-
use rustc_middle::ty::{ParamEnv, Ty, TyCtxt};
3+
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
44
use std::iter;
55

66
use super::simplify::simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
3838
should_cleanup = true;
3939
continue;
4040
}
41+
if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env)
42+
{
43+
should_cleanup = true;
44+
continue;
45+
}
4146
}
4247

4348
if should_cleanup {
@@ -48,7 +53,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
4853

4954
trait SimplifyMatch<'tcx> {
5055
fn simplify(
51-
&self,
56+
&mut self,
5257
tcx: TyCtxt<'tcx>,
5358
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
5459
bbs: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
@@ -72,7 +77,7 @@ trait SimplifyMatch<'tcx> {
7277
let source_info = bbs[switch_bb_idx].terminator().source_info;
7378
let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span));
7479

75-
// We already checked that first and second are different blocks,
80+
// We already checked that targets are different blocks,
7681
// and bb_idx has a different terminator from both of them.
7782
let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty);
7883
let (_, first) = targets.iter().next().unwrap();
@@ -91,7 +96,7 @@ trait SimplifyMatch<'tcx> {
9196
}
9297

9398
fn can_simplify(
94-
&self,
99+
&mut self,
95100
tcx: TyCtxt<'tcx>,
96101
targets: &SwitchTargets,
97102
param_env: ParamEnv<'tcx>,
@@ -144,7 +149,7 @@ struct SimplifyToIf;
144149
/// ```
145150
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
146151
fn can_simplify(
147-
&self,
152+
&mut self,
148153
tcx: TyCtxt<'tcx>,
149154
targets: &SwitchTargets,
150155
param_env: ParamEnv<'tcx>,
@@ -250,3 +255,207 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250255
new_stmts.collect()
251256
}
252257
}
258+
259+
#[derive(Default)]
260+
struct SimplifyToExp {
261+
transfrom_types: Vec<TransfromType>,
262+
}
263+
264+
#[derive(Clone, Copy)]
265+
enum CompareType<'tcx, 'a> {
266+
Same(&'a StatementKind<'tcx>),
267+
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
268+
Discr(&'a Place<'tcx>, Ty<'tcx>),
269+
}
270+
271+
enum TransfromType {
272+
Same,
273+
Eq,
274+
Discr,
275+
}
276+
277+
impl From<CompareType<'_, '_>> for TransfromType {
278+
fn from(compare_type: CompareType<'_, '_>) -> Self {
279+
match compare_type {
280+
CompareType::Same(_) => TransfromType::Same,
281+
CompareType::Eq(_, _, _) => TransfromType::Eq,
282+
CompareType::Discr(_, _) => TransfromType::Discr,
283+
}
284+
}
285+
}
286+
287+
/// If we find that the value of match is the same as the assignment,
288+
/// merge a target block statements into the source block,
289+
/// using cast to transform different integer types.
290+
///
291+
/// For example:
292+
///
293+
/// ```ignore (MIR)
294+
/// bb0: {
295+
/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
296+
/// }
297+
///
298+
/// bb1: {
299+
/// unreachable;
300+
/// }
301+
///
302+
/// bb2: {
303+
/// _0 = const 1_i16;
304+
/// goto -> bb5;
305+
/// }
306+
///
307+
/// bb3: {
308+
/// _0 = const 2_i16;
309+
/// goto -> bb5;
310+
/// }
311+
///
312+
/// bb4: {
313+
/// _0 = const 3_i16;
314+
/// goto -> bb5;
315+
/// }
316+
/// ```
317+
///
318+
/// into:
319+
///
320+
/// ```ignore (MIR)
321+
/// bb0: {
322+
/// _0 = _3 as i16 (IntToInt);
323+
/// goto -> bb5;
324+
/// }
325+
/// ```
326+
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
327+
fn can_simplify(
328+
&mut self,
329+
tcx: TyCtxt<'tcx>,
330+
targets: &SwitchTargets,
331+
param_env: ParamEnv<'tcx>,
332+
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
333+
) -> bool {
334+
if targets.iter().len() < 2 || targets.iter().len() > 64 {
335+
return false;
336+
}
337+
// We require that the possible target blocks all be distinct.
338+
if !targets.is_distinct() {
339+
return false;
340+
}
341+
if !bbs[targets.otherwise()].is_empty_unreachable() {
342+
return false;
343+
}
344+
let mut iter = targets.iter();
345+
let (first_val, first_target) = iter.next().unwrap();
346+
let first_terminator_kind = &bbs[first_target].terminator().kind;
347+
// Check that destinations are identical, and if not, then don't optimize this block
348+
if !targets
349+
.iter()
350+
.all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
351+
{
352+
return false;
353+
}
354+
355+
let first_stmts = &bbs[first_target].statements;
356+
let (second_val, second_target) = iter.next().unwrap();
357+
let second_stmts = &bbs[second_target].statements;
358+
359+
let mut compare_types = Vec::new();
360+
for (f, s) in iter::zip(first_stmts, second_stmts) {
361+
let compare_type = match (&f.kind, &s.kind) {
362+
// If two statements are exactly the same, we can optimize.
363+
(f_s, s_s) if f_s == s_s => CompareType::Same(f_s),
364+
365+
// If two statements are assignments with the match values to the same place, we can optimize.
366+
(
367+
StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
368+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
369+
) if lhs_f == lhs_s
370+
&& f_c.const_.ty() == s_c.const_.ty()
371+
&& f_c.const_.ty().is_integral() =>
372+
{
373+
match (
374+
f_c.const_.try_eval_scalar_int(tcx, param_env),
375+
s_c.const_.try_eval_scalar_int(tcx, param_env),
376+
) {
377+
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
378+
(Some(f), Some(s))
379+
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
380+
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
381+
{
382+
CompareType::Discr(lhs_f, f_c.const_.ty())
383+
}
384+
_ => return false,
385+
}
386+
}
387+
388+
// Otherwise we cannot optimize. Try another block.
389+
_ => return false,
390+
};
391+
compare_types.push(compare_type);
392+
}
393+
394+
for (other_val, other_target) in iter {
395+
let other_stmts = &bbs[other_target].statements;
396+
if compare_types.len() != other_stmts.len() {
397+
return false;
398+
}
399+
for (f, s) in iter::zip(&compare_types, other_stmts) {
400+
match (*f, &s.kind) {
401+
(CompareType::Same(f_s), s_s) if f_s == s_s => {}
402+
(
403+
CompareType::Eq(lhs_f, f_ty, val),
404+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
405+
) if lhs_f == lhs_s
406+
&& s_c.const_.ty() == f_ty
407+
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
408+
(
409+
CompareType::Discr(lhs_f, f_ty),
410+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
411+
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
412+
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
413+
return false;
414+
};
415+
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
416+
return false;
417+
}
418+
}
419+
_ => return false,
420+
}
421+
}
422+
}
423+
self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect();
424+
true
425+
}
426+
427+
fn new_stmts(
428+
&self,
429+
_tcx: TyCtxt<'tcx>,
430+
targets: &SwitchTargets,
431+
_param_env: ParamEnv<'tcx>,
432+
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
433+
discr_local: Local,
434+
discr_ty: Ty<'tcx>,
435+
) -> Vec<Statement<'tcx>> {
436+
let (_, first) = targets.iter().next().unwrap();
437+
let first = &bbs[first];
438+
439+
let new_stmts =
440+
iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) {
441+
(TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(),
442+
(
443+
TransfromType::Discr,
444+
StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
445+
) => {
446+
let operand = Operand::Copy(Place::from(discr_local));
447+
let r_val = if f_c.const_.ty() == discr_ty {
448+
Rvalue::Use(operand)
449+
} else {
450+
Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
451+
};
452+
Statement {
453+
source_info: s.source_info,
454+
kind: StatementKind::Assign(Box::new((*lhs, r_val))),
455+
}
456+
}
457+
_ => unreachable!(),
458+
});
459+
new_stmts.collect()
460+
}
461+
}

tests/codegen/match-optimized.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ pub fn exhaustive_match(e: E) -> u8 {
2626
// CHECK-NEXT: store i8 1, ptr %_0, align 1
2727
// CHECK-NEXT: br label %[[EXIT]]
2828
// CHECK: [[C]]:
29-
// CHECK-NEXT: store i8 2, ptr %_0, align 1
29+
// CHECK-NEXT: store i8 3, ptr %_0, align 1
3030
// CHECK-NEXT: br label %[[EXIT]]
3131
match e {
3232
E::A => 0,
3333
E::B => 1,
34-
E::C => 2,
34+
E::C => 3,
3535
}
3636
}
3737

tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff

+33-28
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,42 @@
55
debug i => _1;
66
let mut _0: u128;
77
let mut _2: i128;
8+
+ let mut _3: i128;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb1, otherwise: bb2];
12-
}
13-
14-
bb1: {
15-
_0 = const _;
16-
goto -> bb6;
17-
}
18-
19-
bb2: {
20-
unreachable;
21-
}
22-
23-
bb3: {
24-
_0 = const 1_u128;
25-
goto -> bb6;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_u128;
30-
goto -> bb6;
31-
}
32-
33-
bb5: {
34-
_0 = const 3_u128;
35-
goto -> bb6;
36-
}
37-
38-
bb6: {
12+
- switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb1, otherwise: bb2];
13+
- }
14+
-
15+
- bb1: {
16+
- _0 = const _;
17+
- goto -> bb6;
18+
- }
19+
-
20+
- bb2: {
21+
- unreachable;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const 1_u128;
26+
- goto -> bb6;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_u128;
31+
- goto -> bb6;
32+
- }
33+
-
34+
- bb5: {
35+
- _0 = const 3_u128;
36+
- goto -> bb6;
37+
- }
38+
-
39+
- bb6: {
40+
+ StorageLive(_3);
41+
+ _3 = move _2;
42+
+ _0 = _3 as u128 (IntToInt);
43+
+ StorageDead(_3);
3944
return;
4045
}
4146
}

0 commit comments

Comments
 (0)