Skip to content

When HIR auto-refs a comparison operator, clean it up in MIR #109292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1931,9 +1931,13 @@ impl<'tcx> Operand<'tcx> {
///
/// While this is unlikely in general, it's the normal case of what you'll
/// find as the `func` in a [`TerminatorKind::Call`].
pub fn const_fn_def(&self) -> Option<(DefId, SubstsRef<'tcx>)> {
let const_ty = self.constant()?.literal.ty();
if let ty::FnDef(def_id, substs) = *const_ty.kind() { Some((def_id, substs)) } else { None }
pub fn const_fn_def(&self) -> Option<(DefId, SubstsRef<'tcx>, Span)> {
let constant = self.constant()?;
if let ty::FnDef(def_id, substs) = *constant.literal.ty().kind() {
Some((def_id, substs, constant.span))
} else {
None
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/instcombine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl<'tcx> InstCombineContext<'tcx, '_> {
else { return };

// Only bother looking more if it's easy to know what we're calling
let Some((fn_def_id, fn_substs)) = func.const_fn_def()
let Some((fn_def_id, fn_substs, _span)) = func.const_fn_def()
else { return };

// Clone needs one subst, so we can cheaply rule out other stuff
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ mod ssa;
pub mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
mod simplify_ref_comparisons;
mod sroa;
mod uninhabited_enum_branching;
mod unreachable_prop;
Expand Down Expand Up @@ -497,6 +498,8 @@ fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&cleanup_post_borrowck::CleanupPostBorrowck,
&remove_noop_landing_pads::RemoveNoopLandingPads,
&simplify::SimplifyCfg::new("early-opt"),
// Adds more `Deref`s, so needs to be before `Derefer`.
&simplify_ref_comparisons::SimplifyRefComparisons,
&deref_separator::Derefer,
];

Expand Down
86 changes: 86 additions & 0 deletions compiler/rustc_mir_transform/src/simplify_ref_comparisons.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate::MirPass;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};

/// This pass replaces `x OP y` with `*x OP *y` when `OP` is a comparison operator.
///
/// The goal is to make is so that it's never better for the user to write
/// `***x == ***y` than to write the obvious `x == y` (when `x` and `y` are
/// references and thus those do the same thing). This is particularly
/// important because the type-checker will auto-ref any comparison that's not
/// done directly on a primitive. That means that `a_ref == b_ref` doesn't
/// become `PartialEq::eq(a_ref, b_ref)`, even though that would work, but rather
/// ```no_run
/// # fn foo(a_ref: &i32, b_ref: &i32) -> bool {
/// let temp1 = &a_ref;
/// let temp2 = &b_ref;
/// PartialEq::eq(temp1, temp2)
/// # }
/// ```
/// Thus this pass means it directly calls the *interesting* `impl` directly,
/// rather than needing to monomorphize and/or inline it later. (And when this
/// comment was written in March 2023, the MIR inliner seemed to only inline
/// one level of `==`, so if the comparison is on something like `&&i32` the
/// extra forwarding impls needed to be monomorphized even in an optimized build.)
///
/// Make sure this runs before the `Derefer`, since it might add multiple levels
/// of dereferences in the `Operand`s that are arguments to the `Call`.
pub struct SimplifyRefComparisons;

impl<'tcx> MirPass<'tcx> for SimplifyRefComparisons {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// Despite the method name, this is `PartialEq`, not `Eq`.
let Some(partial_eq) = tcx.lang_items().eq_trait() else { return };
let Some(partial_ord) = tcx.lang_items().partial_ord_trait() else { return };

for block in body.basic_blocks.as_mut() {
let terminator = block.terminator.as_mut().unwrap();
let TerminatorKind::Call { func, args, from_hir_call: false, .. } =
&mut terminator.kind
else { continue };

// Quickly skip unary operators
if args.len() != 2 {
continue;
}
let (Some(left_place), Some(right_place)) = (args[0].place(), args[1].place())
else { continue };

let (fn_def, fn_substs, fn_span) =
func.const_fn_def().expect("HIR operators to always call the traits directly");
let substs =
fn_substs.try_as_type_list().expect("HIR operators only have type parameters");
let [left_ty, right_ty] = *substs.as_slice() else { continue };
let (depth, new_left_ty, new_right_ty) = find_ref_depth(left_ty, right_ty);
if depth == 0 {
// Already dereffed as far as possible.
continue;
}

// Check it's a comparison, not `+`/`&`/etc.
let trait_def = tcx.trait_of_item(fn_def);
if trait_def != Some(partial_eq) && trait_def != Some(partial_ord) {
continue;
}

let derefs = vec![ProjectionElem::Deref; depth];
let new_substs = [new_left_ty.into(), new_right_ty.into()];

*func = Operand::function_handle(tcx, fn_def, new_substs, fn_span);
args[0] = Operand::Copy(left_place.project_deeper(&derefs, tcx));
args[1] = Operand::Copy(right_place.project_deeper(&derefs, tcx));
}
}
}

fn find_ref_depth<'tcx>(mut left: Ty<'tcx>, mut right: Ty<'tcx>) -> (usize, Ty<'tcx>, Ty<'tcx>) {
let mut depth = 0;
while let (ty::Ref(_, new_left, Mutability::Not), ty::Ref(_, new_right, Mutability::Not)) =
(left.kind(), right.kind())
{
depth += 1;
(left, right) = (*new_left, *new_right);
}

(depth, left, right)
}
168 changes: 168 additions & 0 deletions tests/mir-opt/simplify_cmp.multi_ref_prim.SimplifyRefComparisons.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
- // MIR for `multi_ref_prim` before SimplifyRefComparisons
+ // MIR for `multi_ref_prim` after SimplifyRefComparisons

fn multi_ref_prim(_1: &&&i32, _2: &&&i32) -> () {
debug x => _1; // in scope 0 at $DIR/simplify_cmp.rs:+0:23: +0:24
debug y => _2; // in scope 0 at $DIR/simplify_cmp.rs:+0:34: +0:35
let mut _0: (); // return place in scope 0 at $DIR/simplify_cmp.rs:+0:45: +0:45
let _3: bool; // in scope 0 at $DIR/simplify_cmp.rs:+1:9: +1:11
let mut _4: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
let mut _5: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
let mut _7: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+2:14: +2:15
let mut _8: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+2:19: +2:20
let mut _10: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:14: +3:15
let mut _11: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:18: +3:19
let _12: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+3:18: +3:19
let mut _14: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:14: +4:15
let mut _15: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:19: +4:20
let _16: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+4:19: +4:20
let mut _18: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:14: +5:15
let mut _19: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:18: +5:19
let _20: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+5:18: +5:19
let mut _22: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:14: +6:15
let mut _23: &&&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:19: +6:20
let _24: &&&i32; // in scope 0 at $DIR/simplify_cmp.rs:+6:19: +6:20
scope 1 {
debug _a => _3; // in scope 1 at $DIR/simplify_cmp.rs:+1:9: +1:11
let _6: bool; // in scope 1 at $DIR/simplify_cmp.rs:+2:9: +2:11
scope 2 {
debug _b => _6; // in scope 2 at $DIR/simplify_cmp.rs:+2:9: +2:11
let _9: bool; // in scope 2 at $DIR/simplify_cmp.rs:+3:9: +3:11
scope 3 {
debug _c => _9; // in scope 3 at $DIR/simplify_cmp.rs:+3:9: +3:11
let _13: bool; // in scope 3 at $DIR/simplify_cmp.rs:+4:9: +4:11
scope 4 {
debug _d => _13; // in scope 4 at $DIR/simplify_cmp.rs:+4:9: +4:11
let _17: bool; // in scope 4 at $DIR/simplify_cmp.rs:+5:9: +5:11
scope 5 {
debug _e => _17; // in scope 5 at $DIR/simplify_cmp.rs:+5:9: +5:11
let _21: bool; // in scope 5 at $DIR/simplify_cmp.rs:+6:9: +6:11
scope 6 {
debug _f => _21; // in scope 6 at $DIR/simplify_cmp.rs:+6:9: +6:11
}
}
}
}
}
}

bb0: {
StorageLive(_3); // scope 0 at $DIR/simplify_cmp.rs:+1:9: +1:11
StorageLive(_4); // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
_4 = &_1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:15
StorageLive(_5); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
_5 = &_2; // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
- _3 = <&&&i32 as PartialEq>::eq(move _4, move _5) -> bb1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:20
+ _3 = <i32 as PartialEq>::eq((*(*(*_4))), (*(*(*_5)))) -> bb1; // scope 0 at $DIR/simplify_cmp.rs:+1:14: +1:20
// mir::Constant
// + span: $DIR/simplify_cmp.rs:18:14: 18:20
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialEq>::eq}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialEq>::eq}, val: Value(<ZST>) }
}

bb1: {
StorageDead(_5); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
StorageDead(_4); // scope 0 at $DIR/simplify_cmp.rs:+1:19: +1:20
StorageLive(_6); // scope 1 at $DIR/simplify_cmp.rs:+2:9: +2:11
StorageLive(_7); // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:15
_7 = &_1; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:15
StorageLive(_8); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
_8 = &_2; // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
- _6 = <&&&i32 as PartialEq>::ne(move _7, move _8) -> bb2; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:20
+ _6 = <i32 as PartialEq>::ne((*(*(*_7))), (*(*(*_8)))) -> bb2; // scope 1 at $DIR/simplify_cmp.rs:+2:14: +2:20
// mir::Constant
// + span: $DIR/simplify_cmp.rs:19:14: 19:20
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialEq>::ne}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialEq>::ne}, val: Value(<ZST>) }
}

bb2: {
StorageDead(_8); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
StorageDead(_7); // scope 1 at $DIR/simplify_cmp.rs:+2:19: +2:20
StorageLive(_9); // scope 2 at $DIR/simplify_cmp.rs:+3:9: +3:11
StorageLive(_10); // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:15
_10 = &_1; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:15
StorageLive(_11); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
StorageLive(_12); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
_12 = &(*_2); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
_11 = &_12; // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
- _9 = <&&&i32 as PartialOrd>::lt(move _10, move _11) -> bb3; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:19
+ _9 = <i32 as PartialOrd>::lt((*(*(*_10))), (*(*(*_11)))) -> bb3; // scope 2 at $DIR/simplify_cmp.rs:+3:14: +3:19
// mir::Constant
// + span: $DIR/simplify_cmp.rs:20:14: 20:19
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::lt}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::lt}, val: Value(<ZST>) }
}

bb3: {
StorageDead(_11); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
StorageDead(_10); // scope 2 at $DIR/simplify_cmp.rs:+3:18: +3:19
StorageDead(_12); // scope 2 at $DIR/simplify_cmp.rs:+3:19: +3:20
StorageLive(_13); // scope 3 at $DIR/simplify_cmp.rs:+4:9: +4:11
StorageLive(_14); // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:15
_14 = &_1; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:15
StorageLive(_15); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
StorageLive(_16); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
_16 = &(*_2); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
_15 = &_16; // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
- _13 = <&&&i32 as PartialOrd>::le(move _14, move _15) -> bb4; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:20
+ _13 = <i32 as PartialOrd>::le((*(*(*_14))), (*(*(*_15)))) -> bb4; // scope 3 at $DIR/simplify_cmp.rs:+4:14: +4:20
// mir::Constant
// + span: $DIR/simplify_cmp.rs:21:14: 21:20
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::le}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::le}, val: Value(<ZST>) }
}

bb4: {
StorageDead(_15); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
StorageDead(_14); // scope 3 at $DIR/simplify_cmp.rs:+4:19: +4:20
StorageDead(_16); // scope 3 at $DIR/simplify_cmp.rs:+4:20: +4:21
StorageLive(_17); // scope 4 at $DIR/simplify_cmp.rs:+5:9: +5:11
StorageLive(_18); // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:15
_18 = &_1; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:15
StorageLive(_19); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
StorageLive(_20); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
_20 = &(*_2); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
_19 = &_20; // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
- _17 = <&&&i32 as PartialOrd>::gt(move _18, move _19) -> bb5; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:19
+ _17 = <i32 as PartialOrd>::gt((*(*(*_18))), (*(*(*_19)))) -> bb5; // scope 4 at $DIR/simplify_cmp.rs:+5:14: +5:19
// mir::Constant
// + span: $DIR/simplify_cmp.rs:22:14: 22:19
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::gt}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::gt}, val: Value(<ZST>) }
}

bb5: {
StorageDead(_19); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
StorageDead(_18); // scope 4 at $DIR/simplify_cmp.rs:+5:18: +5:19
StorageDead(_20); // scope 4 at $DIR/simplify_cmp.rs:+5:19: +5:20
StorageLive(_21); // scope 5 at $DIR/simplify_cmp.rs:+6:9: +6:11
StorageLive(_22); // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:15
_22 = &_1; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:15
StorageLive(_23); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
StorageLive(_24); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
_24 = &(*_2); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
_23 = &_24; // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
- _21 = <&&&i32 as PartialOrd>::ge(move _22, move _23) -> bb6; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:20
+ _21 = <i32 as PartialOrd>::ge((*(*(*_22))), (*(*(*_23)))) -> bb6; // scope 5 at $DIR/simplify_cmp.rs:+6:14: +6:20
// mir::Constant
// + span: $DIR/simplify_cmp.rs:23:14: 23:20
- // + literal: Const { ty: for<'a, 'b> fn(&'a &&&i32, &'b &&&i32) -> bool {<&&&i32 as PartialOrd>::ge}, val: Value(<ZST>) }
+ // + literal: Const { ty: for<'a, 'b> fn(&'a i32, &'b i32) -> bool {<i32 as PartialOrd>::ge}, val: Value(<ZST>) }
}

bb6: {
StorageDead(_23); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
StorageDead(_22); // scope 5 at $DIR/simplify_cmp.rs:+6:19: +6:20
StorageDead(_24); // scope 5 at $DIR/simplify_cmp.rs:+6:20: +6:21
_0 = const (); // scope 0 at $DIR/simplify_cmp.rs:+0:45: +7:2
StorageDead(_21); // scope 5 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_17); // scope 4 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_13); // scope 3 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_9); // scope 2 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_6); // scope 1 at $DIR/simplify_cmp.rs:+7:1: +7:2
StorageDead(_3); // scope 0 at $DIR/simplify_cmp.rs:+7:1: +7:2
return; // scope 0 at $DIR/simplify_cmp.rs:+7:2: +7:2
}
}

Loading