Skip to content

Commit 738ed9b

Browse files
committed
Fix #76803
Check that the variant index matches the target value from the SwitchInt we came from
1 parent cad050b commit 738ed9b

File tree

3 files changed

+55
-44
lines changed

3 files changed

+55
-44
lines changed

compiler/rustc_mir/src/transform/simplify_try.rs

+33-23
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
1616
use rustc_middle::mir::*;
1717
use rustc_middle::ty::{self, List, Ty, TyCtxt};
1818
use rustc_target::abi::VariantIdx;
19-
use std::iter::{Enumerate, Peekable};
19+
use std::iter::{once, Enumerate, Peekable};
2020
use std::slice::Iter;
2121

2222
/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
@@ -551,6 +551,12 @@ struct SimplifyBranchSameOptimization {
551551
bb_to_opt_terminator: BasicBlock,
552552
}
553553

554+
struct SwitchTargetAndValue {
555+
target: BasicBlock,
556+
// None in case of the `otherwise` case
557+
value: Option<u128>,
558+
}
559+
554560
struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
555561
body: &'a Body<'tcx>,
556562
tcx: TyCtxt<'tcx>,
@@ -562,8 +568,15 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
562568
.basic_blocks()
563569
.iter_enumerated()
564570
.filter_map(|(bb_idx, bb)| {
565-
let (discr_switched_on, targets) = match &bb.terminator().kind {
566-
TerminatorKind::SwitchInt { targets, discr, .. } => (discr, targets),
571+
let (discr_switched_on, targets_and_values):(_, Vec<_>) = match &bb.terminator().kind {
572+
TerminatorKind::SwitchInt { targets, discr, values, .. } => {
573+
// if values.len() == targets.len() - 1, we need to include None where no value is present
574+
// such that the zip does not throw away targets. If no `otherwise` case is in targets, the zip will simply throw away the added None
575+
let values_extended = values.iter().map(|x|Some(*x)).chain(once(None));
576+
let targets_and_values = targets.iter().zip(values_extended)
577+
.map(|(target, value)| SwitchTargetAndValue{target:*target, value:value})
578+
.collect();
579+
(discr, targets_and_values)},
567580
_ => return None,
568581
};
569582

@@ -587,9 +600,9 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
587600
},
588601
};
589602

590-
let mut iter_bbs_reachable = targets
603+
let mut iter_bbs_reachable = targets_and_values
591604
.iter()
592-
.map(|idx| (*idx, &self.body.basic_blocks()[*idx]))
605+
.map(|target_and_value| (target_and_value, &self.body.basic_blocks()[target_and_value.target]))
593606
.filter(|(_, bb)| {
594607
// Reaching `unreachable` is UB so assume it doesn't happen.
595608
bb.terminator().kind != TerminatorKind::Unreachable
@@ -603,16 +616,16 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
603616
})
604617
.peekable();
605618

606-
let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(targets[0]);
619+
let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(&targets_and_values[0]);
607620
let mut all_successors_equivalent = StatementEquality::TrivialEqual;
608621

609622
// All successor basic blocks must be equal or contain statements that are pairwise considered equal.
610-
for ((bb_l_idx,bb_l), (bb_r_idx,bb_r)) in iter_bbs_reachable.tuple_windows() {
623+
for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() {
611624
let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup
612625
&& bb_l.terminator().kind == bb_r.terminator().kind;
613626
let statement_check = || {
614627
bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| {
615-
let stmt_equality = self.statement_equality(*adt_matched_on, &l, bb_l_idx, &r, bb_r_idx, self.tcx.sess.opts.debugging_opts.mir_opt_level);
628+
let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r);
616629
if matches!(stmt_equality, StatementEquality::NotEqual) {
617630
// short circuit
618631
None
@@ -634,7 +647,7 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
634647
// statements are trivially equal, so just take first
635648
trace!("Statements are trivially equal");
636649
Some(SimplifyBranchSameOptimization {
637-
bb_to_goto: bb_first,
650+
bb_to_goto: bb_first.target,
638651
bb_to_opt_terminator: bb_idx,
639652
})
640653
}
@@ -669,10 +682,9 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
669682
&self,
670683
adt_matched_on: Place<'tcx>,
671684
x: &Statement<'tcx>,
672-
x_bb_idx: BasicBlock,
685+
x_target_and_value: &SwitchTargetAndValue,
673686
y: &Statement<'tcx>,
674-
y_bb_idx: BasicBlock,
675-
mir_opt_level: usize,
687+
y_target_and_value: &SwitchTargetAndValue,
676688
) -> StatementEquality {
677689
let helper = |rhs: &Rvalue<'tcx>,
678690
place: &Place<'tcx>,
@@ -691,13 +703,7 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
691703

692704
match rhs {
693705
Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => {
694-
// FIXME(76803): This logic is currently broken because it does not take into
695-
// account the current discriminant value.
696-
if mir_opt_level > 2 {
697-
StatementEquality::ConsideredEqual(side_to_choose)
698-
} else {
699-
StatementEquality::NotEqual
700-
}
706+
StatementEquality::ConsideredEqual(side_to_choose)
701707
}
702708
_ => {
703709
trace!(
@@ -717,16 +723,20 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
717723
(
718724
StatementKind::Assign(box (_, rhs)),
719725
StatementKind::SetDiscriminant { place, variant_index },
720-
) => {
726+
)
727+
// we need to make sure that the switch value that targets the bb with SetDiscriminant (y), is the same as the variant index
728+
if Some(variant_index.index() as u128) == y_target_and_value.value => {
721729
// choose basic block of x, as that has the assign
722-
helper(rhs, place, variant_index, x_bb_idx)
730+
helper(rhs, place, variant_index, x_target_and_value.target)
723731
}
724732
(
725733
StatementKind::SetDiscriminant { place, variant_index },
726734
StatementKind::Assign(box (_, rhs)),
727-
) => {
735+
)
736+
// we need to make sure that the switch value that targets the bb with SetDiscriminant (x), is the same as the variant index
737+
if Some(variant_index.index() as u128) == x_target_and_value.value => {
728738
// choose basic block of y, as that has the assign
729-
helper(rhs, place, variant_index, y_bb_idx)
739+
helper(rhs, place, variant_index, y_target_and_value.target)
730740
}
731741
_ => {
732742
trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y);

src/test/mir-opt/76803_regression.encode.SimplifyBranchSame.diff

+7-9
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,20 @@
88

99
bb0: {
1010
_2 = discriminant(_1); // scope 0 at $DIR/76803_regression.rs:12:9: 12:16
11-
- switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/76803_regression.rs:12:9: 12:16
12-
+ goto -> bb1; // scope 0 at $DIR/76803_regression.rs:12:9: 12:16
11+
switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/76803_regression.rs:12:9: 12:16
1312
}
1413

1514
bb1: {
1615
_0 = move _1; // scope 0 at $DIR/76803_regression.rs:13:14: 13:15
17-
- goto -> bb3; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6
18-
+ goto -> bb2; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6
16+
goto -> bb3; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6
1917
}
2018

2119
bb2: {
22-
- discriminant(_0) = 1; // scope 0 at $DIR/76803_regression.rs:12:20: 12:27
23-
- goto -> bb3; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6
24-
- }
25-
-
26-
- bb3: {
20+
discriminant(_0) = 1; // scope 0 at $DIR/76803_regression.rs:12:20: 12:27
21+
goto -> bb3; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6
22+
}
23+
24+
bb3: {
2725
return; // scope 0 at $DIR/76803_regression.rs:15:2: 15:2
2826
}
2927
}

src/test/mir-opt/simplify_arm.id.SimplifyBranchSame.diff

+15-12
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,27 @@
1313

1414
bb0: {
1515
_2 = discriminant(_1); // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16
16-
switchInt(move _2) -> [0_isize: bb1, 1_isize: bb3, otherwise: bb2]; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16
16+
- switchInt(move _2) -> [0_isize: bb1, 1_isize: bb3, otherwise: bb2]; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16
17+
+ goto -> bb1; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16
1718
}
1819

1920
bb1: {
20-
discriminant(_0) = 0; // scope 0 at $DIR/simplify-arm.rs:12:17: 12:21
21-
goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
22-
}
23-
24-
bb2: {
25-
unreachable; // scope 0 at $DIR/simplify-arm.rs:10:11: 10:12
26-
}
27-
28-
bb3: {
21+
- discriminant(_0) = 0; // scope 0 at $DIR/simplify-arm.rs:12:17: 12:21
22+
- goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
23+
- }
24+
-
25+
- bb2: {
26+
- unreachable; // scope 0 at $DIR/simplify-arm.rs:10:11: 10:12
27+
- }
28+
-
29+
- bb3: {
2930
_0 = move _1; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
30-
goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
31+
- goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
32+
+ goto -> bb2; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
3133
}
3234

33-
bb4: {
35+
- bb4: {
36+
+ bb2: {
3437
return; // scope 0 at $DIR/simplify-arm.rs:14:2: 14:2
3538
}
3639
}

0 commit comments

Comments
 (0)