Skip to content

Commit f79fd40

Browse files
committed
Transforms a match containing negative numbers into an assignment statement as well
1 parent 70ca429 commit f79fd40

4 files changed

+106
-59
lines changed

compiler/rustc_mir_transform/src/match_branches.rs

+44-11
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@ trait SimplifyMatch<'tcx> {
6767
_ => unreachable!(),
6868
};
6969

70-
if !self.can_simplify(tcx, targets, param_env, bbs) {
70+
let discr_ty = discr.ty(local_decls, tcx);
71+
if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
7172
return false;
7273
}
7374

7475
// Take ownership of items now that we know we can optimize.
7576
let discr = discr.clone();
76-
let discr_ty = discr.ty(local_decls, tcx);
7777

7878
// Introduce a temporary for the discriminant value.
7979
let source_info = bbs[switch_bb_idx].terminator().source_info;
@@ -103,6 +103,7 @@ trait SimplifyMatch<'tcx> {
103103
targets: &SwitchTargets,
104104
param_env: ParamEnv<'tcx>,
105105
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
106+
discr_ty: Ty<'tcx>,
106107
) -> bool;
107108

108109
fn new_stmts(
@@ -156,6 +157,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
156157
targets: &SwitchTargets,
157158
param_env: ParamEnv<'tcx>,
158159
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
160+
_discr_ty: Ty<'tcx>,
159161
) -> bool {
160162
if targets.iter().len() != 1 {
161163
return false;
@@ -267,7 +269,7 @@ struct SimplifyToExp {
267269
enum CompareType<'tcx, 'a> {
268270
Same(&'a StatementKind<'tcx>),
269271
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
270-
Discr(&'a Place<'tcx>, Ty<'tcx>),
272+
Discr(&'a Place<'tcx>, Ty<'tcx>, bool),
271273
}
272274

273275
enum TransfromType {
@@ -281,7 +283,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
281283
match compare_type {
282284
CompareType::Same(_) => TransfromType::Same,
283285
CompareType::Eq(_, _, _) => TransfromType::Eq,
284-
CompareType::Discr(_, _) => TransfromType::Discr,
286+
CompareType::Discr(_, _, _) => TransfromType::Discr,
285287
}
286288
}
287289
}
@@ -332,6 +334,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
332334
targets: &SwitchTargets,
333335
param_env: ParamEnv<'tcx>,
334336
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
337+
discr_ty: Ty<'tcx>,
335338
) -> bool {
336339
if targets.iter().len() < 2 || targets.iter().len() > 64 {
337340
return false;
@@ -354,6 +357,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
354357
return false;
355358
}
356359

360+
let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
357361
let first_stmts = &bbs[first_target].statements;
358362
let (second_val, second_target) = iter.next().unwrap();
359363
let second_stmts = &bbs[second_target].statements;
@@ -381,12 +385,30 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
381385
) {
382386
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
383387
(Some(f), Some(s))
384-
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
385-
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
388+
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
389+
&& f.try_to_int(f.size()).unwrap()
390+
== ScalarInt::try_from_uint(first_val, discr_size)
391+
.unwrap()
392+
.try_to_int(discr_size)
393+
.unwrap()
394+
&& s.try_to_int(s.size()).unwrap()
395+
== ScalarInt::try_from_uint(second_val, discr_size)
396+
.unwrap()
397+
.try_to_int(discr_size)
398+
.unwrap())
399+
|| (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
400+
&& Some(s)
401+
== ScalarInt::try_from_uint(second_val, s.size())) =>
386402
{
387-
CompareType::Discr(lhs_f, f_c.const_.ty())
403+
CompareType::Discr(
404+
lhs_f,
405+
f_c.const_.ty(),
406+
f_c.const_.ty().is_signed() || discr_ty.is_signed(),
407+
)
408+
}
409+
_ => {
410+
return false;
388411
}
389-
_ => return false,
390412
}
391413
}
392414

@@ -411,15 +433,26 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
411433
&& s_c.const_.ty() == f_ty
412434
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
413435
(
414-
CompareType::Discr(lhs_f, f_ty),
436+
CompareType::Discr(lhs_f, f_ty, is_signed),
415437
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
416438
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
417439
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
418440
return false;
419441
};
420-
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
421-
return false;
442+
if is_signed
443+
&& s_c.const_.ty().is_signed()
444+
&& f.try_to_int(f.size()).unwrap()
445+
== ScalarInt::try_from_uint(other_val, discr_size)
446+
.unwrap()
447+
.try_to_int(discr_size)
448+
.unwrap()
449+
{
450+
continue;
451+
}
452+
if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
453+
continue;
422454
}
455+
return false;
423456
}
424457
_ => return false,
425458
}

tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff

+28-23
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@
55
debug i => _1;
66
let mut _0: i8;
77
let mut _2: i16;
8+
+ let mut _3: i16;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1];
12-
}
13-
14-
bb1: {
15-
unreachable;
16-
}
17-
18-
bb2: {
19-
_0 = const -3_i8;
20-
goto -> bb5;
21-
}
22-
23-
bb3: {
24-
_0 = const -1_i8;
25-
goto -> bb5;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_i8;
30-
goto -> bb5;
31-
}
32-
33-
bb5: {
12+
- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1];
13+
- }
14+
-
15+
- bb1: {
16+
- unreachable;
17+
- }
18+
-
19+
- bb2: {
20+
- _0 = const -3_i8;
21+
- goto -> bb5;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const -1_i8;
26+
- goto -> bb5;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_i8;
31+
- goto -> bb5;
32+
- }
33+
-
34+
- bb5: {
35+
+ StorageLive(_3);
36+
+ _3 = move _2;
37+
+ _0 = _3 as i8 (IntToInt);
38+
+ StorageDead(_3);
3439
return;
3540
}
3641
}

tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff

+28-23
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@
55
debug i => _1;
66
let mut _0: i16;
77
let mut _2: i8;
8+
+ let mut _3: i8;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
12-
}
13-
14-
bb1: {
15-
unreachable;
16-
}
17-
18-
bb2: {
19-
_0 = const -3_i16;
20-
goto -> bb5;
21-
}
22-
23-
bb3: {
24-
_0 = const -1_i16;
25-
goto -> bb5;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_i16;
30-
goto -> bb5;
31-
}
32-
33-
bb5: {
12+
- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
13+
- }
14+
-
15+
- bb1: {
16+
- unreachable;
17+
- }
18+
-
19+
- bb2: {
20+
- _0 = const -3_i16;
21+
- goto -> bb5;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const -1_i16;
26+
- goto -> bb5;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_i16;
31+
- goto -> bb5;
32+
- }
33+
-
34+
- bb5: {
35+
+ StorageLive(_3);
36+
+ _3 = move _2;
37+
+ _0 = _3 as i16 (IntToInt);
38+
+ StorageDead(_3);
3439
return;
3540
}
3641
}

tests/mir-opt/matches_reduce_branches.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ enum EnumAi8 {
204204
// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
205205
fn match_i8_i16(i: EnumAi8) -> i16 {
206206
// CHECK-LABEL: fn match_i8_i16(
207-
// CHECK: switchInt
207+
// CHECK-NOT: switchInt
208+
// CHECK: _0 = _3 as i16 (IntToInt);
209+
// CHECH: return
208210
match i {
209211
EnumAi8::A => -1,
210212
EnumAi8::B => 2,
@@ -233,7 +235,9 @@ enum EnumAi16 {
233235
// EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
234236
fn match_i16_i8(i: EnumAi16) -> i8 {
235237
// CHECK-LABEL: fn match_i16_i8(
236-
// CHECK: switchInt
238+
// CHECK-NOT: switchInt
239+
// CHECK: _0 = _3 as i8 (IntToInt);
240+
// CHECH: return
237241
match i {
238242
EnumAi16::A => -1,
239243
EnumAi16::B => 2,

0 commit comments

Comments
 (0)