Skip to content

EXPERIMENT [MIR-OPT] Add a pass that replicates simple branches into predecessors #111574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ mod remove_storage_markers;
mod remove_uninit_drops;
mod remove_unneeded_drops;
mod remove_zsts;
mod replicate_branches;
mod required_consts;
mod reveal_all;
mod separate_const_switch;
Expand Down Expand Up @@ -554,7 +555,8 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&remove_unneeded_drops::RemoveUnneededDrops,
&sroa::ScalarReplacementOfAggregates,
&match_branches::MatchBranchSimplification,
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)
// inst simplify is after MatchBranchSimplification to clean up Ne(_1, false)
&replicate_branches::ReplicateBranches,
&multiple_return_terminators::MultipleReturnTerminators,
&instsimplify::InstSimplify,
&separate_const_switch::SeparateConstSwitch,
Expand Down
104 changes: 104 additions & 0 deletions compiler/rustc_mir_transform/src/replicate_branches.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//! Looks for basic blocks that are just a simple 2-way branch,
//! and replicates them into their predecessors.
//!
//! That sounds wasteful, but it's very common that the predecessor just set
//! whatever we're about to branch on, and this makes that easier to collapse
//! later. And if not, well, an extra branch is no big deal.

use crate::MirPass;

use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

use super::simplify::simplify_cfg;

pub struct ReplicateBranches;

impl<'tcx> MirPass<'tcx> for ReplicateBranches {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
sess.mir_opt_level() >= 2
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!("Running ReplicateBranches on `{:?}`", body.source);

let mut source_and_target_blocks = vec![];

for (bb, data) in body.basic_blocks.iter_enumerated() {
if data.is_cleanup {
continue;
}

let TerminatorKind::SwitchInt { discr, targets } = &data.terminator().kind
else { continue };

let Some(place) = discr.place() else { continue };

// Only replicate simple branches. That means either:
// - A specific target plus a potentially-reachable otherwise, or
// - Two specific targets and the otherwise is unreachable.
if !(targets.iter().len() < 2
|| (targets.iter().len() <= 2
&& body.basic_blocks[targets.otherwise()].is_empty_unreachable()))
{
continue;
}

// Only replicate blocks that branch on a discriminant.
let mut found_discriminant = false;
for stmt in &data.statements {
match &stmt.kind {
StatementKind::StorageDead { .. } | StatementKind::StorageLive { .. } => {}
StatementKind::Assign(place_and_rvalue)
if place_and_rvalue.0 == place
&& matches!(place_and_rvalue.1, Rvalue::Discriminant(..)) =>
{
found_discriminant = true;
}
_ => continue,
}
}

// This currently doesn't duplicate ordinary boolean checks as that regresses
// various tests that seem to depend on the existing structure.
if !found_discriminant {
continue;
}

// Only replicate to a small number of `goto` predecessors.
let preds = &body.basic_blocks.predecessors()[bb];
if preds.len() > 2
|| !preds.iter().copied().all(|p| {
let pdata = &body.basic_blocks[p];
matches!(pdata.terminator().kind, TerminatorKind::Goto { .. })
&& !pdata.is_cleanup
})
{
continue;
}

for p in preds {
source_and_target_blocks.push((*p, bb));
}
}

if source_and_target_blocks.is_empty() {
return;
}

let basic_blocks = body.basic_blocks.as_mut();
for (source, target) in source_and_target_blocks {
debug_assert_eq!(
basic_blocks[source].terminator().kind,
TerminatorKind::Goto { target },
);

let (source, target) = basic_blocks.pick2_mut(source, target);
source.statements.extend(target.statements.iter().cloned());
source.terminator = target.terminator.clone();
}

// The target blocks should be unused now, so cleanup after ourselves.
simplify_cfg(tcx, body);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,51 @@
+ scope 1 (inlined core::num::<impl u16>::unchecked_shl) { // at $DIR/unchecked_shifts.rs:11:7: 11:23
+ debug self => _3; // in scope 1 at $SRC_DIR/core/src/num/uint_macros.rs:LL:COL
+ debug rhs => _4; // in scope 1 at $SRC_DIR/core/src/num/uint_macros.rs:LL:COL
+ let mut _5: u16; // in scope 1 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ let mut _6: (u32,); // in scope 1 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ let mut _5: (u32,); // in scope 1 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ let mut _6: u32; // in scope 1 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ scope 2 {
+ scope 3 (inlined core::num::<impl u16>::unchecked_shl::conv) { // at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ debug x => _6; // in scope 3 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ let mut _7: std::option::Option<u16>; // in scope 3 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ let mut _8: std::result::Result<u16, std::num::TryFromIntError>; // in scope 3 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ scope 4 {
+ scope 5 (inlined <u32 as TryInto<u16>>::try_into) { // at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ debug self => _6; // in scope 5 at $SRC_DIR/core/src/convert/mod.rs:LL:COL
+ scope 6 (inlined convert::num::<impl TryFrom<u32> for u16>::try_from) { // at $SRC_DIR/core/src/convert/mod.rs:LL:COL
+ debug u => _6; // in scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ let mut _9: bool; // in scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ let mut _10: u32; // in scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ let mut _11: u16; // in scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ }
+ }
+ scope 7 (inlined Result::<u16, TryFromIntError>::ok) { // at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ debug self => _8; // in scope 7 at $SRC_DIR/core/src/result.rs:LL:COL
+ let _12: u16; // in scope 7 at $SRC_DIR/core/src/result.rs:LL:COL
+ scope 8 {
+ debug x => _12; // in scope 8 at $SRC_DIR/core/src/result.rs:LL:COL
+ }
+ }
+ scope 9 (inlined #[track_caller] Option::<u16>::unwrap_unchecked) { // at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ debug self => _7; // in scope 9 at $SRC_DIR/core/src/option.rs:LL:COL
+ let mut _13: &std::option::Option<u16>; // in scope 9 at $SRC_DIR/core/src/option.rs:LL:COL
+ let _14: u16; // in scope 9 at $SRC_DIR/core/src/option.rs:LL:COL
+ scope 10 {
+ debug val => _14; // in scope 10 at $SRC_DIR/core/src/option.rs:LL:COL
+ }
+ scope 11 {
+ scope 13 (inlined unreachable_unchecked) { // at $SRC_DIR/core/src/option.rs:LL:COL
+ scope 14 {
+ scope 15 (inlined unreachable_unchecked::runtime) { // at $SRC_DIR/core/src/intrinsics.rs:LL:COL
+ }
+ }
+ }
+ }
+ scope 12 (inlined Option::<u16>::is_some) { // at $SRC_DIR/core/src/option.rs:LL:COL
+ debug self => _13; // in scope 12 at $SRC_DIR/core/src/option.rs:LL:COL
+ }
+ }
+ }
+ }
+ }
+ }

Expand All @@ -22,30 +64,61 @@
StorageLive(_4); // scope 0 at $DIR/unchecked_shifts.rs:+1:21: +1:22
_4 = _2; // scope 0 at $DIR/unchecked_shifts.rs:+1:21: +1:22
- _0 = core::num::<impl u16>::unchecked_shl(move _3, move _4) -> bb1; // scope 0 at $DIR/unchecked_shifts.rs:+1:5: +1:23
+ StorageLive(_5); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageLive(_6); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ _6 = (_4,); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ _5 = core::num::<impl u16>::unchecked_shl::conv(move (_6.0: u32)) -> bb1; // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
// mir::Constant
- // mir::Constant
- // + span: $DIR/unchecked_shifts.rs:11:7: 11:20
- // + literal: Const { ty: unsafe fn(u16, u32) -> u16 {core::num::<impl u16>::unchecked_shl}, val: Value(<ZST>) }
+ // + span: $SRC_DIR/core/src/num/mod.rs:LL:COL
+ // + literal: Const { ty: fn(u32) -> u16 {core::num::<impl u16>::unchecked_shl::conv}, val: Value(<ZST>) }
+ StorageLive(_14); // scope 0 at $DIR/unchecked_shifts.rs:+1:7: +1:23
+ StorageLive(_5); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ _5 = (_4,); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageLive(_6); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ _6 = move (_5.0: u32); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageLive(_7); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageLive(_8); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageLive(_9); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ StorageLive(_10); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ _10 = const 65535_u32; // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ _9 = Gt(_6, move _10); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ StorageDead(_10); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ switchInt(move _9) -> [0: bb3, otherwise: bb2]; // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
}

bb1: {
+ StorageDead(_14); // scope 0 at $DIR/unchecked_shifts.rs:+1:7: +1:23
StorageDead(_4); // scope 0 at $DIR/unchecked_shifts.rs:+1:22: +1:23
StorageDead(_3); // scope 0 at $DIR/unchecked_shifts.rs:+1:22: +1:23
return; // scope 0 at $DIR/unchecked_shifts.rs:+2:2: +2:2
+ }
+
+ bb2: {
+ StorageDead(_9); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ StorageLive(_12); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageDead(_12); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageDead(_8); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageLive(_13); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ unreachable; // scope 7 at $SRC_DIR/core/src/result.rs:LL:COL
+ }
+
+ bb3: {
+ StorageLive(_11); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ _11 = _6 as u16 (IntToInt); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ _8 = Result::<u16, TryFromIntError>::Ok(move _11); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ StorageDead(_11); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ StorageDead(_9); // scope 6 at $SRC_DIR/core/src/convert/num.rs:LL:COL
+ StorageLive(_12); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ _12 = move ((_8 as Ok).0: u16); // scope 7 at $SRC_DIR/core/src/result.rs:LL:COL
+ _7 = Option::<u16>::Some(move _12); // scope 8 at $SRC_DIR/core/src/result.rs:LL:COL
+ StorageDead(_12); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageDead(_8); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageLive(_13); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ _14 = move ((_7 as Some).0: u16); // scope 9 at $SRC_DIR/core/src/option.rs:LL:COL
+ StorageDead(_13); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageDead(_7); // scope 4 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ StorageDead(_6); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ _0 = unchecked_shl::<u16>(_3, move _5) -> [return: bb2, unwind unreachable]; // scope 2 at $SRC_DIR/core/src/num/uint_macros.rs:LL:COL
+ StorageDead(_5); // scope 2 at $SRC_DIR/core/src/num/mod.rs:LL:COL
+ _0 = unchecked_shl::<u16>(_3, move _14) -> [return: bb1, unwind unreachable]; // scope 2 at $SRC_DIR/core/src/num/uint_macros.rs:LL:COL
+ // mir::Constant
+ // + span: $SRC_DIR/core/src/num/uint_macros.rs:LL:COL
+ // + literal: Const { ty: unsafe extern "rust-intrinsic" fn(u16, u16) -> u16 {unchecked_shl::<u16>}, val: Value(<ZST>) }
+ }
+
+ bb2: {
+ StorageDead(_5); // scope 2 at $SRC_DIR/core/src/num/uint_macros.rs:LL:COL
StorageDead(_4); // scope 0 at $DIR/unchecked_shifts.rs:+1:22: +1:23
StorageDead(_3); // scope 0 at $DIR/unchecked_shifts.rs:+1:22: +1:23
return; // scope 0 at $DIR/unchecked_shifts.rs:+2:2: +2:2
}
}

Loading