Skip to content

[InstCombine] Offset both sides of an equality icmp #134086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Apr 2, 2025

@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

Closes #134024


Full diff: https://github.com/llvm/llvm-project/pull/134086.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+129)
  • (modified) llvm/test/Transforms/InstCombine/icmp-add.ll (+2-4)
  • (modified) llvm/test/Transforms/InstCombine/icmp-equality-xor.ll (+1-2)
  • (modified) llvm/test/Transforms/InstCombine/icmp-select.ll (+130)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 55afe1258159a..a2bc855e155cd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5808,6 +5808,131 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
   return nullptr;
 }
 
+/// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
+using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
+static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
+                            bool AllowRecursion) {
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return;
+  Constant *C;
+
+  switch (Inst->getOpcode()) {
+  case Instruction::Add:
+    if (match(Inst->getOperand(1), m_ImmConstant(C)))
+      if (Constant *NegC = ConstantExpr::getNeg(C))
+        Offsets.emplace_back(Instruction::Add, NegC);
+    break;
+  case Instruction::Xor:
+    Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
+    Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
+    break;
+  case Instruction::Select:
+    if (AllowRecursion) {
+      Value *TrueV = Inst->getOperand(1);
+      if (TrueV->hasOneUse())
+        collectOffsetOp(TrueV, Offsets, /*AllowRecursion=*/false);
+      Value *FalseV = Inst->getOperand(2);
+      if (FalseV->hasOneUse())
+        collectOffsetOp(FalseV, Offsets, /*AllowRecursion=*/false);
+    }
+    break;
+  default:
+    break;
+  }
+}
+
+enum class OffsetKind { Invalid, Value, Select };
+
+struct OffsetResult {
+  OffsetKind Kind;
+  Value *V0, *V1, *V2;
+
+  static OffsetResult invalid() {
+    return {OffsetKind::Invalid, nullptr, nullptr, nullptr};
+  }
+  static OffsetResult value(Value *V) {
+    return {OffsetKind::Value, V, nullptr, nullptr};
+  }
+  static OffsetResult select(Value *Cond, Value *TrueV, Value *FalseV) {
+    return {OffsetKind::Select, Cond, TrueV, FalseV};
+  }
+  bool isValid() const { return Kind != OffsetKind::Invalid; }
+  Value *materialize(InstCombiner::BuilderTy &Builder) const {
+    switch (Kind) {
+    case OffsetKind::Invalid:
+      llvm_unreachable("Invalid offset result");
+    case OffsetKind::Value:
+      return V0;
+    case OffsetKind::Select:
+      return Builder.CreateSelect(V0, V1, V2);
+    default:
+      llvm_unreachable("Unknown offset result kind");
+    }
+  }
+};
+
+/// Offset both sides of an equality icmp to see if we can save some
+/// instructions: icmp eq/ne X, Y -> icmp eq/ne X op Z, Y op Z.
+/// Note: This operation should not introduce poison.
+static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
+                                               InstCombiner::BuilderTy &Builder,
+                                               const SimplifyQuery &SQ) {
+  assert(I.isEquality() && "Expected an equality icmp");
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  if (!Op0->getType()->isIntOrIntVectorTy())
+    return nullptr;
+
+  SmallVector<OffsetOp, 4> OffsetOps;
+  if (Op0->hasOneUse())
+    collectOffsetOp(Op0, OffsetOps, /*AllowRecursion=*/true);
+  if (Op1->hasOneUse())
+    collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);
+
+  auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
+    Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
+    // Avoid infinite loops by checking if RHS is an identity for the BinOp.
+    if (!Simplified || Simplified == V)
+      return nullptr;
+    return Simplified;
+  };
+
+  auto ApplyOffset = [&](Value *V, unsigned BinOpc,
+                         Value *RHS) -> OffsetResult {
+    if (auto *Sel = dyn_cast<SelectInst>(V)) {
+      if (!Sel->hasOneUse())
+        return OffsetResult::invalid();
+      Value *TrueVal = ApplyOffsetImpl(Sel->getTrueValue(), BinOpc, RHS);
+      if (!TrueVal)
+        return OffsetResult::invalid();
+      Value *FalseVal = ApplyOffsetImpl(Sel->getFalseValue(), BinOpc, RHS);
+      if (!FalseVal)
+        return OffsetResult::invalid();
+      return OffsetResult::select(Sel->getCondition(), TrueVal, FalseVal);
+    }
+    if (Value *Simplified = ApplyOffsetImpl(V, BinOpc, RHS))
+      return OffsetResult::value(Simplified);
+    return OffsetResult::invalid();
+  };
+
+  for (auto [BinOp, RHS] : OffsetOps) {
+    auto BinOpc = static_cast<unsigned>(BinOp);
+
+    auto Op0Result = ApplyOffset(Op0, BinOpc, RHS);
+    if (!Op0Result.isValid())
+      continue;
+    auto Op1Result = ApplyOffset(Op1, BinOpc, RHS);
+    if (!Op1Result.isValid())
+      continue;
+
+    Value *NewLHS = Op0Result.materialize(Builder);
+    Value *NewRHS = Op1Result.materialize(Builder);
+    return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
   if (!I.isEquality())
     return nullptr;
@@ -6054,6 +6179,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
                                   : ConstantInt::getNullValue(A->getType()));
   }
 
+  if (auto *Res = foldICmpEqualityWithOffset(
+          I, Builder, getSimplifyQuery().getWithInstruction(&I)))
+    return Res;
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/icmp-add.ll b/llvm/test/Transforms/InstCombine/icmp-add.ll
index a8cdf80948a84..1a41c1f3e1045 100644
--- a/llvm/test/Transforms/InstCombine/icmp-add.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-add.ll
@@ -2380,8 +2380,7 @@ define <2 x i1> @icmp_eq_add_non_splat(<2 x i32> %a) {
 
 define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
 ; CHECK-LABEL: @icmp_eq_add_undef2(
-; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 undef>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 undef>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %add = add <2 x i32> %a, <i32 5, i32 5>
@@ -2391,8 +2390,7 @@ define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
 
 define <2 x i1> @icmp_eq_add_non_splat2(<2 x i32> %a) {
 ; CHECK-LABEL: @icmp_eq_add_non_splat2(
-; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 11>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 6>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %add = add <2 x i32> %a, <i32 5, i32 5>
diff --git a/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll b/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
index b8e8ed0eaf1da..b0b633fba06be 100644
--- a/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
@@ -136,8 +136,7 @@ define i1 @foo2(i32 %x, i32 %y) {
 define <2 x i1> @foo3(<2 x i8> %x) {
 ; CHECK-LABEL: @foo3(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -2, i8 -1>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i8> [[XOR]], <i8 9, i8 79>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i8> [[X:%.*]], <i8 -9, i8 -80>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
 entry:
diff --git a/llvm/test/Transforms/InstCombine/icmp-select.ll b/llvm/test/Transforms/InstCombine/icmp-select.ll
index 0bdbc88ba67c6..c909673481bb4 100644
--- a/llvm/test/Transforms/InstCombine/icmp-select.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-select.ll
@@ -628,3 +628,133 @@ define i1 @icmp_slt_select(i1 %cond, i32 %a, i32 %b) {
   %res = icmp slt i32 %lhs, %rhs
   ret i1 %res
 }
+
+define i1 @discr_eq(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_eq(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %add1 = add i8 %a, -2
+  %cmp1 = icmp ugt i8 %a, 1
+  %sel1 = select i1 %cmp1, i8 %add1, i8 1
+  %add2 = add i8 %b, -2
+  %cmp2 = icmp ugt i8 %b, 1
+  %sel2 = select i1 %cmp2, i8 %add2, i8 1
+  %res = icmp eq i8 %sel1, %sel2
+  ret i1 %res
+}
+
+define i1 @discr_ne(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_ne(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3
+; CHECK-NEXT:    [[RES:%.*]] = icmp ne i8 [[SEL1]], [[SEL2]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %add1 = add i8 %a, -2
+  %cmp1 = icmp ugt i8 %a, 1
+  %sel1 = select i1 %cmp1, i8 %add1, i8 1
+  %add2 = add i8 %b, -2
+  %cmp2 = icmp ugt i8 %b, 1
+  %sel2 = select i1 %cmp2, i8 %add2, i8 1
+  %res = icmp ne i8 %sel1, %sel2
+  ret i1 %res
+}
+
+define i1 @discr_xor_eq(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_xor_eq(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
+; CHECK-NEXT:    [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 -4
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 -4
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %xor1 = xor i8 %a, -3
+  %cmp1 = icmp ugt i8 %a, 1
+  %sel1 = select i1 %cmp1, i8 %xor1, i8 1
+  %xor2 = xor i8 %b, -3
+  %cmp2 = icmp ugt i8 %b, 1
+  %sel2 = select i1 %cmp2, i8 %xor2, i8 1
+  %res = icmp eq i8 %sel1, %sel2
+  ret i1 %res
+}
+
+define i1 @discr_eq_simple(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_eq_simple(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[SEL1]], [[ADD2:%.*]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %add1 = add i8 %a, -2
+  %cmp1 = icmp ugt i8 %a, 1
+  %sel1 = select i1 %cmp1, i8 %add1, i8 1
+  %add2 = add i8 %b, -2
+  %res = icmp eq i8 %sel1, %add2
+  ret i1 %res
+}
+
+; Negative tests
+
+define i1 @discr_eq_multi_use(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_eq_multi_use(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD1:%.*]] = add i8 [[A:%.*]], -2
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
+; CHECK-NEXT:    call void @use(i8 [[SEL1]])
+; CHECK-NEXT:    [[ADD2:%.*]] = add i8 [[B:%.*]], -2
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %add1 = add i8 %a, -2
+  %cmp1 = icmp ugt i8 %a, 1
+  %sel1 = select i1 %cmp1, i8 %add1, i8 1
+  call void @use(i8 %sel1)
+  %add2 = add i8 %b, -2
+  %cmp2 = icmp ugt i8 %b, 1
+  %sel2 = select i1 %cmp2, i8 %add2, i8 1
+  %res = icmp eq i8 %sel1, %sel2
+  ret i1 %res
+}
+
+define i1 @discr_eq_failed_to_simplify(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_eq_failed_to_simplify(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD1:%.*]] = add i8 [[A:%.*]], -3
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
+; CHECK-NEXT:    [[ADD2:%.*]] = add i8 [[B:%.*]], -2
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %add1 = add i8 %a, -3
+  %cmp1 = icmp ugt i8 %a, 1
+  %sel1 = select i1 %cmp1, i8 %add1, i8 1
+  %add2 = add i8 %b, -2
+  %cmp2 = icmp ugt i8 %b, 1
+  %sel2 = select i1 %cmp2, i8 %add2, i8 1
+  %res = icmp eq i8 %sel1, %sel2
+  ret i1 %res
+}

@dtcxzyw dtcxzyw marked this pull request as draft April 18, 2025 09:33
@dtcxzyw dtcxzyw marked this pull request as ready for review April 18, 2025 10:25
Copy link
Member

@dianqk dianqk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not familiar with InstCombine, but it looks good to me.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this subsume any existing transforms? (I guess the poison check means it doesn't?)

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the common diffs on llvm-opt-benchmark is something like this:

@g = external global i8

define i1 @src(ptr %p) {
  %i = ptrtoint ptr %p to i64
  %sub = sub i64 %i, ptrtoint (ptr @g to i64)
  %cmp = icmp eq i64 %sub, -1
  ret i1 %cmp
}

define i1 @tgt(ptr %p) {
  %cmp = icmp eq ptr %p, inttoptr (i64 add (i64 ptrtoint (ptr @g to i64), i64 -1) to ptr)
  ret i1 %cmp
}

I think we probably don't want to transform that case. It doesn't really simplify anything, just shift something from an instruction into a constant expression. (And may break pointer subtraction optimizations...)


switch (Inst->getOpcode()) {
case Instruction::Add:
if (isGuaranteedNotToBeUndefOrPoison(Inst->getOperand(1)))
Copy link
Contributor

@nikic nikic Apr 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could generalize this by deferring this check to ApplyOffset, in which case we could also use implied poison reasoning instead. I think we did that, we could handle cases like a + x == b + x regardless of whether x is poison, and that would maybe allow dropping some existing folds. Possibly better left for later though.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Remove common adds when icmp eqing two selects [InstCombine]
4 participants