Skip to content

Commit f21eda9

Browse files
committed
Transforms a match containing negative numbers into an assignment statement as well
1 parent 23c586d commit f21eda9

4 files changed

+106
-59
lines changed

compiler/rustc_mir_transform/src/match_branches.rs

+44-11
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ trait SimplifyMatch<'tcx> {
6565
_ => unreachable!(),
6666
};
6767

68-
if !self.can_simplify(tcx, targets, param_env, bbs) {
68+
let discr_ty = discr.ty(local_decls, tcx);
69+
if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
6970
return false;
7071
}
7172

7273
// Take ownership of items now that we know we can optimize.
7374
let discr = discr.clone();
74-
let discr_ty = discr.ty(local_decls, tcx);
7575

7676
// Introduce a temporary for the discriminant value.
7777
let source_info = bbs[switch_bb_idx].terminator().source_info;
@@ -101,6 +101,7 @@ trait SimplifyMatch<'tcx> {
101101
targets: &SwitchTargets,
102102
param_env: ParamEnv<'tcx>,
103103
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
104+
discr_ty: Ty<'tcx>,
104105
) -> bool;
105106

106107
fn new_stmts(
@@ -154,6 +155,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
154155
targets: &SwitchTargets,
155156
param_env: ParamEnv<'tcx>,
156157
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
158+
_discr_ty: Ty<'tcx>,
157159
) -> bool {
158160
if targets.iter().len() != 1 {
159161
return false;
@@ -265,7 +267,7 @@ struct SimplifyToExp {
265267
enum CompareType<'tcx, 'a> {
266268
Same(&'a StatementKind<'tcx>),
267269
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
268-
Discr(&'a Place<'tcx>, Ty<'tcx>),
270+
Discr(&'a Place<'tcx>, Ty<'tcx>, bool),
269271
}
270272

271273
enum TransfromType {
@@ -279,7 +281,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
279281
match compare_type {
280282
CompareType::Same(_) => TransfromType::Same,
281283
CompareType::Eq(_, _, _) => TransfromType::Eq,
282-
CompareType::Discr(_, _) => TransfromType::Discr,
284+
CompareType::Discr(_, _, _) => TransfromType::Discr,
283285
}
284286
}
285287
}
@@ -330,6 +332,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
330332
targets: &SwitchTargets,
331333
param_env: ParamEnv<'tcx>,
332334
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
335+
discr_ty: Ty<'tcx>,
333336
) -> bool {
334337
if targets.iter().len() < 2 || targets.iter().len() > 64 {
335338
return false;
@@ -352,6 +355,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
352355
return false;
353356
}
354357

358+
let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
355359
let first_stmts = &bbs[first_target].statements;
356360
let (second_val, second_target) = iter.next().unwrap();
357361
let second_stmts = &bbs[second_target].statements;
@@ -376,12 +380,30 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
376380
) {
377381
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
378382
(Some(f), Some(s))
379-
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
380-
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
383+
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
384+
&& f.try_to_int(f.size()).unwrap()
385+
== ScalarInt::try_from_uint(first_val, discr_size)
386+
.unwrap()
387+
.try_to_int(discr_size)
388+
.unwrap()
389+
&& s.try_to_int(s.size()).unwrap()
390+
== ScalarInt::try_from_uint(second_val, discr_size)
391+
.unwrap()
392+
.try_to_int(discr_size)
393+
.unwrap())
394+
|| (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
395+
&& Some(s)
396+
== ScalarInt::try_from_uint(second_val, s.size())) =>
381397
{
382-
CompareType::Discr(lhs_f, f_c.const_.ty())
398+
CompareType::Discr(
399+
lhs_f,
400+
f_c.const_.ty(),
401+
f_c.const_.ty().is_signed() || discr_ty.is_signed(),
402+
)
403+
}
404+
_ => {
405+
return false;
383406
}
384-
_ => return false,
385407
}
386408
}
387409

@@ -406,15 +428,26 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
406428
&& s_c.const_.ty() == f_ty
407429
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
408430
(
409-
CompareType::Discr(lhs_f, f_ty),
431+
CompareType::Discr(lhs_f, f_ty, is_signed),
410432
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
411433
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
412434
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
413435
return false;
414436
};
415-
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
416-
return false;
437+
if is_signed
438+
&& s_c.const_.ty().is_signed()
439+
&& f.try_to_int(f.size()).unwrap()
440+
== ScalarInt::try_from_uint(other_val, discr_size)
441+
.unwrap()
442+
.try_to_int(discr_size)
443+
.unwrap()
444+
{
445+
continue;
446+
}
447+
if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
448+
continue;
417449
}
450+
return false;
418451
}
419452
_ => return false,
420453
}

tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff

+28-23
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@
55
debug i => _1;
66
let mut _0: i8;
77
let mut _2: i16;
8+
+ let mut _3: i16;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2];
12-
}
13-
14-
bb1: {
15-
_0 = const -3_i8;
16-
goto -> bb5;
17-
}
18-
19-
bb2: {
20-
unreachable;
21-
}
22-
23-
bb3: {
24-
_0 = const -1_i8;
25-
goto -> bb5;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_i8;
30-
goto -> bb5;
31-
}
32-
33-
bb5: {
12+
- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2];
13+
- }
14+
-
15+
- bb1: {
16+
- _0 = const -3_i8;
17+
- goto -> bb5;
18+
- }
19+
-
20+
- bb2: {
21+
- unreachable;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const -1_i8;
26+
- goto -> bb5;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_i8;
31+
- goto -> bb5;
32+
- }
33+
-
34+
- bb5: {
35+
+ StorageLive(_3);
36+
+ _3 = move _2;
37+
+ _0 = _3 as i8 (IntToInt);
38+
+ StorageDead(_3);
3439
return;
3540
}
3641
}

tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff

+28-23
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@
55
debug i => _1;
66
let mut _0: i16;
77
let mut _2: i8;
8+
+ let mut _3: i8;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2];
12-
}
13-
14-
bb1: {
15-
_0 = const -3_i16;
16-
goto -> bb5;
17-
}
18-
19-
bb2: {
20-
unreachable;
21-
}
22-
23-
bb3: {
24-
_0 = const -1_i16;
25-
goto -> bb5;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_i16;
30-
goto -> bb5;
31-
}
32-
33-
bb5: {
12+
- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2];
13+
- }
14+
-
15+
- bb1: {
16+
- _0 = const -3_i16;
17+
- goto -> bb5;
18+
- }
19+
-
20+
- bb2: {
21+
- unreachable;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const -1_i16;
26+
- goto -> bb5;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_i16;
31+
- goto -> bb5;
32+
- }
33+
-
34+
- bb5: {
35+
+ StorageLive(_3);
36+
+ _3 = move _2;
37+
+ _0 = _3 as i16 (IntToInt);
38+
+ StorageDead(_3);
3439
return;
3540
}
3641
}

tests/mir-opt/matches_reduce_branches.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ enum EnumAi8 {
131131
// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
132132
fn match_i8_i16(i: EnumAi8) -> i16 {
133133
// CHECK-LABEL: fn match_i8_i16(
134-
// CHECK: switchInt
134+
// CHECK-NOT: switchInt
135+
// CHECK: _0 = _3 as i16 (IntToInt);
136+
// CHECH: return
135137
match i {
136138
EnumAi8::A => -1,
137139
EnumAi8::B => 2,
@@ -160,7 +162,9 @@ enum EnumAi16 {
160162
// EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
161163
fn match_i16_i8(i: EnumAi16) -> i8 {
162164
// CHECK-LABEL: fn match_i16_i8(
163-
// CHECK: switchInt
165+
// CHECK-NOT: switchInt
166+
// CHECK: _0 = _3 as i8 (IntToInt);
167+
// CHECH: return
164168
match i {
165169
EnumAi16::A => -1,
166170
EnumAi16::B => 2,

0 commit comments

Comments
 (0)