Skip to content

Commit e752af7

Browse files
committed
Transforms a match containing negative numbers into an assignment statement as well
1 parent 1f061f4 commit e752af7

4 files changed

+100
-59
lines changed

compiler/rustc_mir_transform/src/match_branches.rs

+38-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use rustc_index::IndexVec;
22
use rustc_middle::mir::*;
33
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
4+
use rustc_target::abi::Size;
45
use std::iter;
56

67
use super::simplify::simplify_cfg;
@@ -67,13 +68,13 @@ trait SimplifyMatch<'tcx> {
6768
_ => unreachable!(),
6869
};
6970

70-
if !self.can_simplify(tcx, targets, param_env, bbs) {
71+
let discr_ty = discr.ty(local_decls, tcx);
72+
if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
7173
return false;
7274
}
7375

7476
// Take ownership of items now that we know we can optimize.
7577
let discr = discr.clone();
76-
let discr_ty = discr.ty(local_decls, tcx);
7778

7879
// Introduce a temporary for the discriminant value.
7980
let source_info = bbs[switch_bb_idx].terminator().source_info;
@@ -104,6 +105,7 @@ trait SimplifyMatch<'tcx> {
104105
targets: &SwitchTargets,
105106
param_env: ParamEnv<'tcx>,
106107
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
108+
discr_ty: Ty<'tcx>,
107109
) -> bool;
108110

109111
fn new_stmts(
@@ -157,6 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
157159
targets: &SwitchTargets,
158160
param_env: ParamEnv<'tcx>,
159161
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
162+
_discr_ty: Ty<'tcx>,
160163
) -> bool {
161164
if targets.iter().len() != 1 {
162165
return false;
@@ -268,7 +271,7 @@ struct SimplifyToExp {
268271
enum CompareType<'tcx, 'a> {
269272
Same(&'a StatementKind<'tcx>),
270273
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
271-
Discr(&'a Place<'tcx>, Ty<'tcx>),
274+
Discr(&'a Place<'tcx>, Ty<'tcx>, bool),
272275
}
273276

274277
enum TransfromType {
@@ -282,7 +285,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
282285
match compare_type {
283286
CompareType::Same(_) => TransfromType::Same,
284287
CompareType::Eq(_, _, _) => TransfromType::Eq,
285-
CompareType::Discr(_, _) => TransfromType::Discr,
288+
CompareType::Discr(_, _, _) => TransfromType::Discr,
286289
}
287290
}
288291
}
@@ -333,6 +336,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
333336
targets: &SwitchTargets,
334337
param_env: ParamEnv<'tcx>,
335338
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
339+
discr_ty: Ty<'tcx>,
336340
) -> bool {
337341
if targets.iter().len() < 2 || targets.iter().len() > 64 {
338342
return false;
@@ -355,13 +359,19 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
355359
return false;
356360
}
357361

362+
let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
358363
let first_stmts = &bbs[first_target].statements;
359364
let (second_val, second_target) = target_iter.next().unwrap();
360365
let second_stmts = &bbs[second_target].statements;
361366
if first_stmts.len() != second_stmts.len() {
362367
return false;
363368
}
364369

370+
fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool {
371+
l.try_to_int(l.size()).unwrap()
372+
== ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap()
373+
}
374+
365375
let mut compare_types = Vec::new();
366376
for (f, s) in iter::zip(first_stmts, second_stmts) {
367377
let compare_type = match (&f.kind, &s.kind) {
@@ -382,12 +392,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
382392
) {
383393
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
384394
(Some(f), Some(s))
385-
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
386-
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
395+
if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
396+
&& int_equal(f, first_val, discr_size)
397+
&& int_equal(s, second_val, discr_size))
398+
|| (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
399+
&& Some(s)
400+
== ScalarInt::try_from_uint(second_val, s.size())) =>
387401
{
388-
CompareType::Discr(lhs_f, f_c.const_.ty())
402+
CompareType::Discr(
403+
lhs_f,
404+
f_c.const_.ty(),
405+
f_c.const_.ty().is_signed() || discr_ty.is_signed(),
406+
)
407+
}
408+
_ => {
409+
return false;
389410
}
390-
_ => return false,
391411
}
392412
}
393413

@@ -413,15 +433,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
413433
&& s_c.const_.ty() == f_ty
414434
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
415435
(
416-
CompareType::Discr(lhs_f, f_ty),
436+
CompareType::Discr(lhs_f, f_ty, is_signed),
417437
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
418438
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
419439
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
420440
return false;
421441
};
422-
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
423-
return false;
442+
if is_signed
443+
&& s_c.const_.ty().is_signed()
444+
&& int_equal(f, other_val, discr_size)
445+
{
446+
continue;
447+
}
448+
if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
449+
continue;
424450
}
451+
return false;
425452
}
426453
_ => return false,
427454
}

tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff

+28-23
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@
55
debug i => _1;
66
let mut _0: i8;
77
let mut _2: i16;
8+
+ let mut _3: i16;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1];
12-
}
13-
14-
bb1: {
15-
unreachable;
16-
}
17-
18-
bb2: {
19-
_0 = const -3_i8;
20-
goto -> bb5;
21-
}
22-
23-
bb3: {
24-
_0 = const -1_i8;
25-
goto -> bb5;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_i8;
30-
goto -> bb5;
31-
}
32-
33-
bb5: {
12+
- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1];
13+
- }
14+
-
15+
- bb1: {
16+
- unreachable;
17+
- }
18+
-
19+
- bb2: {
20+
- _0 = const -3_i8;
21+
- goto -> bb5;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const -1_i8;
26+
- goto -> bb5;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_i8;
31+
- goto -> bb5;
32+
- }
33+
-
34+
- bb5: {
35+
+ StorageLive(_3);
36+
+ _3 = move _2;
37+
+ _0 = _3 as i8 (IntToInt);
38+
+ StorageDead(_3);
3439
return;
3540
}
3641
}

tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff

+28-23
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,37 @@
55
debug i => _1;
66
let mut _0: i16;
77
let mut _2: i8;
8+
+ let mut _3: i8;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
12-
}
13-
14-
bb1: {
15-
unreachable;
16-
}
17-
18-
bb2: {
19-
_0 = const -3_i16;
20-
goto -> bb5;
21-
}
22-
23-
bb3: {
24-
_0 = const -1_i16;
25-
goto -> bb5;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_i16;
30-
goto -> bb5;
31-
}
32-
33-
bb5: {
12+
- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
13+
- }
14+
-
15+
- bb1: {
16+
- unreachable;
17+
- }
18+
-
19+
- bb2: {
20+
- _0 = const -3_i16;
21+
- goto -> bb5;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const -1_i16;
26+
- goto -> bb5;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_i16;
31+
- goto -> bb5;
32+
- }
33+
-
34+
- bb5: {
35+
+ StorageLive(_3);
36+
+ _3 = move _2;
37+
+ _0 = _3 as i16 (IntToInt);
38+
+ StorageDead(_3);
3439
return;
3540
}
3641
}

tests/mir-opt/matches_reduce_branches.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ enum EnumAi8 {
204204
// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
205205
fn match_i8_i16(i: EnumAi8) -> i16 {
206206
// CHECK-LABEL: fn match_i8_i16(
207-
// CHECK: switchInt
207+
// CHECK-NOT: switchInt
208+
// CHECK: _0 = _3 as i16 (IntToInt);
209+
// CHECH: return
208210
match i {
209211
EnumAi8::A => -1,
210212
EnumAi8::B => 2,
@@ -233,7 +235,9 @@ enum EnumAi16 {
233235
// EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
234236
fn match_i16_i8(i: EnumAi16) -> i8 {
235237
// CHECK-LABEL: fn match_i16_i8(
236-
// CHECK: switchInt
238+
// CHECK-NOT: switchInt
239+
// CHECK: _0 = _3 as i8 (IntToInt);
240+
// CHECH: return
237241
match i {
238242
EnumAi16::A => -1,
239243
EnumAi16::B => 2,

0 commit comments

Comments
 (0)