-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[InstCombine] Offset both sides of an equality icmp #134086
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Yingwei Zheng (dtcxzyw) Changes: Closes #134024. Full diff: https://github.com/llvm/llvm-project/pull/134086.diff — 4 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 55afe1258159a..a2bc855e155cd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5808,6 +5808,131 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
return nullptr;
}
+/// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
+using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
+static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
+ bool AllowRecursion) {
+ Instruction *Inst = dyn_cast<Instruction>(V);
+ if (!Inst)
+ return;
+ Constant *C;
+
+ switch (Inst->getOpcode()) {
+ case Instruction::Add:
+ if (match(Inst->getOperand(1), m_ImmConstant(C)))
+ if (Constant *NegC = ConstantExpr::getNeg(C))
+ Offsets.emplace_back(Instruction::Add, NegC);
+ break;
+ case Instruction::Xor:
+ Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
+ Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
+ break;
+ case Instruction::Select:
+ if (AllowRecursion) {
+ Value *TrueV = Inst->getOperand(1);
+ if (TrueV->hasOneUse())
+ collectOffsetOp(TrueV, Offsets, /*AllowRecursion=*/false);
+ Value *FalseV = Inst->getOperand(2);
+ if (FalseV->hasOneUse())
+ collectOffsetOp(FalseV, Offsets, /*AllowRecursion=*/false);
+ }
+ break;
+ default:
+ break;
+ }
+}
+
+enum class OffsetKind { Invalid, Value, Select };
+
+struct OffsetResult {
+ OffsetKind Kind;
+ Value *V0, *V1, *V2;
+
+ static OffsetResult invalid() {
+ return {OffsetKind::Invalid, nullptr, nullptr, nullptr};
+ }
+ static OffsetResult value(Value *V) {
+ return {OffsetKind::Value, V, nullptr, nullptr};
+ }
+ static OffsetResult select(Value *Cond, Value *TrueV, Value *FalseV) {
+ return {OffsetKind::Select, Cond, TrueV, FalseV};
+ }
+ bool isValid() const { return Kind != OffsetKind::Invalid; }
+ Value *materialize(InstCombiner::BuilderTy &Builder) const {
+ switch (Kind) {
+ case OffsetKind::Invalid:
+ llvm_unreachable("Invalid offset result");
+ case OffsetKind::Value:
+ return V0;
+ case OffsetKind::Select:
+ return Builder.CreateSelect(V0, V1, V2);
+ default:
+ llvm_unreachable("Unknown offset result kind");
+ }
+ }
+};
+
+/// Offset both sides of an equality icmp to see if we can save some
+/// instructions: icmp eq/ne X, Y -> icmp eq/ne X op Z, Y op Z.
+/// Note: This operation should not introduce poison.
+static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
+ InstCombiner::BuilderTy &Builder,
+ const SimplifyQuery &SQ) {
+ assert(I.isEquality() && "Expected an equality icmp");
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (!Op0->getType()->isIntOrIntVectorTy())
+ return nullptr;
+
+ SmallVector<OffsetOp, 4> OffsetOps;
+ if (Op0->hasOneUse())
+ collectOffsetOp(Op0, OffsetOps, /*AllowRecursion=*/true);
+ if (Op1->hasOneUse())
+ collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);
+
+ auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
+ Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
+ // Avoid infinite loops by checking if RHS is an identity for the BinOp.
+ if (!Simplified || Simplified == V)
+ return nullptr;
+ return Simplified;
+ };
+
+ auto ApplyOffset = [&](Value *V, unsigned BinOpc,
+ Value *RHS) -> OffsetResult {
+ if (auto *Sel = dyn_cast<SelectInst>(V)) {
+ if (!Sel->hasOneUse())
+ return OffsetResult::invalid();
+ Value *TrueVal = ApplyOffsetImpl(Sel->getTrueValue(), BinOpc, RHS);
+ if (!TrueVal)
+ return OffsetResult::invalid();
+ Value *FalseVal = ApplyOffsetImpl(Sel->getFalseValue(), BinOpc, RHS);
+ if (!FalseVal)
+ return OffsetResult::invalid();
+ return OffsetResult::select(Sel->getCondition(), TrueVal, FalseVal);
+ }
+ if (Value *Simplified = ApplyOffsetImpl(V, BinOpc, RHS))
+ return OffsetResult::value(Simplified);
+ return OffsetResult::invalid();
+ };
+
+ for (auto [BinOp, RHS] : OffsetOps) {
+ auto BinOpc = static_cast<unsigned>(BinOp);
+
+ auto Op0Result = ApplyOffset(Op0, BinOpc, RHS);
+ if (!Op0Result.isValid())
+ continue;
+ auto Op1Result = ApplyOffset(Op1, BinOpc, RHS);
+ if (!Op1Result.isValid())
+ continue;
+
+ Value *NewLHS = Op0Result.materialize(Builder);
+ Value *NewRHS = Op1Result.materialize(Builder);
+ return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
if (!I.isEquality())
return nullptr;
@@ -6054,6 +6179,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
: ConstantInt::getNullValue(A->getType()));
}
+ if (auto *Res = foldICmpEqualityWithOffset(
+ I, Builder, getSimplifyQuery().getWithInstruction(&I)))
+ return Res;
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/icmp-add.ll b/llvm/test/Transforms/InstCombine/icmp-add.ll
index a8cdf80948a84..1a41c1f3e1045 100644
--- a/llvm/test/Transforms/InstCombine/icmp-add.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-add.ll
@@ -2380,8 +2380,7 @@ define <2 x i1> @icmp_eq_add_non_splat(<2 x i32> %a) {
define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
; CHECK-LABEL: @icmp_eq_add_undef2(
-; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 undef>
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 undef>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%add = add <2 x i32> %a, <i32 5, i32 5>
@@ -2391,8 +2390,7 @@ define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
define <2 x i1> @icmp_eq_add_non_splat2(<2 x i32> %a) {
; CHECK-LABEL: @icmp_eq_add_non_splat2(
-; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 11>
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 6>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%add = add <2 x i32> %a, <i32 5, i32 5>
diff --git a/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll b/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
index b8e8ed0eaf1da..b0b633fba06be 100644
--- a/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
@@ -136,8 +136,7 @@ define i1 @foo2(i32 %x, i32 %y) {
define <2 x i1> @foo3(<2 x i8> %x) {
; CHECK-LABEL: @foo3(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -2, i8 -1>
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[XOR]], <i8 9, i8 79>
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[X:%.*]], <i8 -9, i8 -80>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
entry:
diff --git a/llvm/test/Transforms/InstCombine/icmp-select.ll b/llvm/test/Transforms/InstCombine/icmp-select.ll
index 0bdbc88ba67c6..c909673481bb4 100644
--- a/llvm/test/Transforms/InstCombine/icmp-select.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-select.ll
@@ -628,3 +628,133 @@ define i1 @icmp_slt_select(i1 %cond, i32 %a, i32 %b) {
%res = icmp slt i32 %lhs, %rhs
ret i1 %res
}
+
+define i1 @discr_eq(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_eq(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
+; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
+; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3
+; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
+; CHECK-NEXT: ret i1 [[RES]]
+;
+entry:
+ %add1 = add i8 %a, -2
+ %cmp1 = icmp ugt i8 %a, 1
+ %sel1 = select i1 %cmp1, i8 %add1, i8 1
+ %add2 = add i8 %b, -2
+ %cmp2 = icmp ugt i8 %b, 1
+ %sel2 = select i1 %cmp2, i8 %add2, i8 1
+ %res = icmp eq i8 %sel1, %sel2
+ ret i1 %res
+}
+
+define i1 @discr_ne(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_ne(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
+; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
+; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3
+; CHECK-NEXT: [[RES:%.*]] = icmp ne i8 [[SEL1]], [[SEL2]]
+; CHECK-NEXT: ret i1 [[RES]]
+;
+entry:
+ %add1 = add i8 %a, -2
+ %cmp1 = icmp ugt i8 %a, 1
+ %sel1 = select i1 %cmp1, i8 %add1, i8 1
+ %add2 = add i8 %b, -2
+ %cmp2 = icmp ugt i8 %b, 1
+ %sel2 = select i1 %cmp2, i8 %add2, i8 1
+ %res = icmp ne i8 %sel1, %sel2
+ ret i1 %res
+}
+
+define i1 @discr_xor_eq(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_xor_eq(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
+; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
+; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 -4
+; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 -4
+; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret i1 [[RES]]
+;
+entry:
+ %xor1 = xor i8 %a, -3
+ %cmp1 = icmp ugt i8 %a, 1
+ %sel1 = select i1 %cmp1, i8 %xor1, i8 1
+ %xor2 = xor i8 %b, -3
+ %cmp2 = icmp ugt i8 %b, 1
+ %sel2 = select i1 %cmp2, i8 %xor2, i8 1
+ %res = icmp eq i8 %sel1, %sel2
+ ret i1 %res
+}
+
+define i1 @discr_eq_simple(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_eq_simple(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
+; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[ADD2:%.*]]
+; CHECK-NEXT: ret i1 [[RES]]
+;
+entry:
+ %add1 = add i8 %a, -2
+ %cmp1 = icmp ugt i8 %a, 1
+ %sel1 = select i1 %cmp1, i8 %add1, i8 1
+ %add2 = add i8 %b, -2
+ %res = icmp eq i8 %sel1, %add2
+ ret i1 %res
+}
+
+; Negative tests
+
+define i1 @discr_eq_multi_use(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_eq_multi_use(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -2
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
+; CHECK-NEXT: call void @use(i8 [[SEL1]])
+; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2
+; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
+; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1
+; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
+; CHECK-NEXT: ret i1 [[RES]]
+;
+entry:
+ %add1 = add i8 %a, -2
+ %cmp1 = icmp ugt i8 %a, 1
+ %sel1 = select i1 %cmp1, i8 %add1, i8 1
+ call void @use(i8 %sel1)
+ %add2 = add i8 %b, -2
+ %cmp2 = icmp ugt i8 %b, 1
+ %sel2 = select i1 %cmp2, i8 %add2, i8 1
+ %res = icmp eq i8 %sel1, %sel2
+ ret i1 %res
+}
+
+define i1 @discr_eq_failed_to_simplify(i8 %a, i8 %b) {
+; CHECK-LABEL: @discr_eq_failed_to_simplify(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -3
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
+; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2
+; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
+; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1
+; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
+; CHECK-NEXT: ret i1 [[RES]]
+;
+entry:
+ %add1 = add i8 %a, -3
+ %cmp1 = icmp ugt i8 %a, 1
+ %sel1 = select i1 %cmp1, i8 %add1, i8 1
+ %add2 = add i8 %b, -2
+ %cmp2 = icmp ugt i8 %b, 1
+ %sel2 = select i1 %cmp2, i8 %add2, i8 1
+ %res = icmp eq i8 %sel1, %sel2
+ ret i1 %res
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not familiar with InstCombine, but it looks good to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this subsume any existing transforms? (I guess the poison check means it doesn't?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of the common diffs on llvm-opt-benchmark is something like this:
@g = external global i8
define i1 @src(ptr %p) {
%i = ptrtoint ptr %p to i64
%sub = sub i64 %i, ptrtoint (ptr @g to i64)
%cmp = icmp eq i64 %sub, -1
ret i1 %cmp
}
define i1 @tgt(ptr %p) {
%cmp = icmp eq ptr %p, inttoptr (i64 add (i64 ptrtoint (ptr @g to i64), i64 -1) to ptr)
ret i1 %cmp
}
I think we probably don't want to transform that case. It doesn't really simplify anything, just shift something from an instruction into a constant expression. (And may break pointer subtraction optimizations...)
|
>     switch (Inst->getOpcode()) {
>     case Instruction::Add:
>       if (isGuaranteedNotToBeUndefOrPoison(Inst->getOperand(1)))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could generalize this by deferring this check to `ApplyOffset`, in which case we could also use implied-poison reasoning instead. I think if we did that, we could handle cases like `a + x == b + x` regardless of whether `x` is poison, and that would maybe allow dropping some existing folds. Possibly better left for later though.
Proof: https://alive2.llvm.org/ce/z/zQ2UW4
Closes #134024