Skip to content

Commit 7a47635

Browse files
committed
Transforms a match containing negative numbers into an assignment statement as well
1 parent eccc782 commit 7a47635

4 files changed

+106
-59
lines changed

compiler/rustc_mir_transform/src/match_branches.rs

+44-11
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ trait SimplifyMatch<'tcx> {
6565
_ => unreachable!(),
6666
};
6767

68-
if !self.can_simplify(tcx, targets, param_env, bbs) {
68+
let discr_ty = discr.ty(local_decls, tcx);
69+
if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
6970
return false;
7071
}
7172

7273
// Take ownership of items now that we know we can optimize.
7374
let discr = discr.clone();
74-
let discr_ty = discr.ty(local_decls, tcx);
7575

7676
// Introduce a temporary for the discriminant value.
7777
let source_info = bbs[switch_bb_idx].terminator().source_info;
@@ -101,6 +101,7 @@ trait SimplifyMatch<'tcx> {
101101
targets: &SwitchTargets,
102102
param_env: ParamEnv<'tcx>,
103103
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
104+
discr_ty: Ty<'tcx>,
104105
) -> bool;
105106

106107
fn new_stmts(
@@ -154,6 +155,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
154155
targets: &SwitchTargets,
155156
param_env: ParamEnv<'tcx>,
156157
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
158+
_discr_ty: Ty<'tcx>,
157159
) -> bool {
158160
if targets.iter().len() != 1 {
159161
return false;
@@ -265,7 +267,7 @@ struct SimplifyToExp {
265267
enum CompareType<'tcx, 'a> {
266268
Same(&'a StatementKind<'tcx>),
267269
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
268-
Discr(&'a Place<'tcx>, Ty<'tcx>),
270+
Discr(&'a Place<'tcx>, Ty<'tcx>, bool),
269271
}
270272

271273
enum TransfromType {
@@ -279,7 +281,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
279281
match compare_type {
280282
CompareType::Same(_) => TransfromType::Same,
281283
CompareType::Eq(_, _, _) => TransfromType::Eq,
282-
CompareType::Discr(_, _) => TransfromType::Discr,
284+
CompareType::Discr(_, _, _) => TransfromType::Discr,
283285
}
284286
}
285287
}
@@ -330,6 +332,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
330332
targets: &SwitchTargets,
331333
param_env: ParamEnv<'tcx>,
332334
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
335+
discr_ty: Ty<'tcx>,
333336
) -> bool {
334337
if targets.iter().len() < 2 || targets.iter().len() > 64 {
335338
return false;
@@ -352,6 +355,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
352355
return false;
353356
}
354357

358+
let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
355359
let first_stmts = &bbs[first_target].statements;
356360
let (second_val, second_target) = iter.next().unwrap();
357361
let second_stmts = &bbs[second_target].statements;
@@ -379,12 +383,30 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
379383
) {
380384
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
381385
(Some(f), Some(s))
382-
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
383-
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
386+
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
387+
&& f.try_to_int(f.size()).unwrap()
388+
== ScalarInt::try_from_uint(first_val, discr_size)
389+
.unwrap()
390+
.try_to_int(discr_size)
391+
.unwrap()
392+
&& s.try_to_int(s.size()).unwrap()
393+
== ScalarInt::try_from_uint(second_val, discr_size)
394+
.unwrap()
395+
.try_to_int(discr_size)
396+
.unwrap())
397+
|| (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
398+
&& Some(s)
399+
== ScalarInt::try_from_uint(second_val, s.size())) =>
384400
{
385-
CompareType::Discr(lhs_f, f_c.const_.ty())
401+
CompareType::Discr(
402+
lhs_f,
403+
f_c.const_.ty(),
404+
f_c.const_.ty().is_signed() || discr_ty.is_signed(),
405+
)
406+
}
407+
_ => {
408+
return false;
386409
}
387-
_ => return false,
388410
}
389411
}
390412

@@ -409,15 +431,26 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
409431
&& s_c.const_.ty() == f_ty
410432
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
411433
(
412-
CompareType::Discr(lhs_f, f_ty),
434+
CompareType::Discr(lhs_f, f_ty, is_signed),
413435
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
414436
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
415437
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
416438
return false;
417439
};
418-
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
419-
return false;
440+
if is_signed
441+
&& s_c.const_.ty().is_signed()
442+
&& f.try_to_int(f.size()).unwrap()
443+
== ScalarInt::try_from_uint(other_val, discr_size)
444+
.unwrap()
445+
.try_to_int(discr_size)
446+
.unwrap()
447+
{
448+
continue;
449+
}
450+
if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
451+
continue;
420452
}
453+
return false;
421454
}
422455
_ => return false,
423456
}

tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff

+28-23
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@
55
debug i => _1;
66
let mut _0: i8;
77
let mut _2: i16;
8+
+ let mut _3: i16;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2];
12-
}
13-
14-
bb1: {
15-
_0 = const -3_i8;
16-
goto -> bb5;
17-
}
18-
19-
bb2: {
20-
unreachable;
21-
}
22-
23-
bb3: {
24-
_0 = const -1_i8;
25-
goto -> bb5;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_i8;
30-
goto -> bb5;
31-
}
32-
33-
bb5: {
12+
- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb1, otherwise: bb2];
13+
- }
14+
-
15+
- bb1: {
16+
- _0 = const -3_i8;
17+
- goto -> bb5;
18+
- }
19+
-
20+
- bb2: {
21+
- unreachable;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const -1_i8;
26+
- goto -> bb5;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_i8;
31+
- goto -> bb5;
32+
- }
33+
-
34+
- bb5: {
35+
+ StorageLive(_3);
36+
+ _3 = move _2;
37+
+ _0 = _3 as i8 (IntToInt);
38+
+ StorageDead(_3);
3439
return;
3540
}
3641
}

tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff

+28-23
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@
55
debug i => _1;
66
let mut _0: i16;
77
let mut _2: i8;
8+
+ let mut _3: i8;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2];
12-
}
13-
14-
bb1: {
15-
_0 = const -3_i16;
16-
goto -> bb5;
17-
}
18-
19-
bb2: {
20-
unreachable;
21-
}
22-
23-
bb3: {
24-
_0 = const -1_i16;
25-
goto -> bb5;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_i16;
30-
goto -> bb5;
31-
}
32-
33-
bb5: {
12+
- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb1, otherwise: bb2];
13+
- }
14+
-
15+
- bb1: {
16+
- _0 = const -3_i16;
17+
- goto -> bb5;
18+
- }
19+
-
20+
- bb2: {
21+
- unreachable;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const -1_i16;
26+
- goto -> bb5;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_i16;
31+
- goto -> bb5;
32+
- }
33+
-
34+
- bb5: {
35+
+ StorageLive(_3);
36+
+ _3 = move _2;
37+
+ _0 = _3 as i16 (IntToInt);
38+
+ StorageDead(_3);
3439
return;
3540
}
3641
}

tests/mir-opt/matches_reduce_branches.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ enum EnumAi8 {
144144
// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
145145
fn match_i8_i16(i: EnumAi8) -> i16 {
146146
// CHECK-LABEL: fn match_i8_i16(
147-
// CHECK: switchInt
147+
// CHECK-NOT: switchInt
148+
// CHECK: _0 = _3 as i16 (IntToInt);
149+
// CHECH: return
148150
match i {
149151
EnumAi8::A => -1,
150152
EnumAi8::B => 2,
@@ -173,7 +175,9 @@ enum EnumAi16 {
173175
// EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
174176
fn match_i16_i8(i: EnumAi16) -> i8 {
175177
// CHECK-LABEL: fn match_i16_i8(
176-
// CHECK: switchInt
178+
// CHECK-NOT: switchInt
179+
// CHECK: _0 = _3 as i8 (IntToInt);
180+
// CHECH: return
177181
match i {
178182
EnumAi16::A => -1,
179183
EnumAi16::B => 2,

0 commit comments

Comments
 (0)