Skip to content

Commit 278cd32

Browse files
committed
[InstCombine] Offset both sides of an equality icmp
1 parent 0d8c53a commit 278cd32

File tree

4 files changed

+147
-29
lines changed

4 files changed

+147
-29
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

+129
Original file line numberDiff line numberDiff line change
@@ -5808,6 +5808,131 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
58085808
return nullptr;
58095809
}
58105810

5811+
/// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
5812+
using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
5813+
static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
5814+
bool AllowRecursion) {
5815+
Instruction *Inst = dyn_cast<Instruction>(V);
5816+
if (!Inst)
5817+
return;
5818+
Constant *C;
5819+
5820+
switch (Inst->getOpcode()) {
5821+
case Instruction::Add:
5822+
if (match(Inst->getOperand(1), m_ImmConstant(C)))
5823+
if (Constant *NegC = ConstantExpr::getNeg(C))
5824+
Offsets.emplace_back(Instruction::Add, NegC);
5825+
break;
5826+
case Instruction::Xor:
5827+
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
5828+
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
5829+
break;
5830+
case Instruction::Select:
5831+
if (AllowRecursion) {
5832+
Value *TrueV = Inst->getOperand(1);
5833+
if (TrueV->hasOneUse())
5834+
collectOffsetOp(TrueV, Offsets, /*AllowRecursion=*/false);
5835+
Value *FalseV = Inst->getOperand(2);
5836+
if (FalseV->hasOneUse())
5837+
collectOffsetOp(FalseV, Offsets, /*AllowRecursion=*/false);
5838+
}
5839+
break;
5840+
default:
5841+
break;
5842+
}
5843+
}
5844+
5845+
enum class OffsetKind { Invalid, Value, Select };
5846+
5847+
struct OffsetResult {
5848+
OffsetKind Kind;
5849+
Value *V0, *V1, *V2;
5850+
5851+
static OffsetResult invalid() {
5852+
return {OffsetKind::Invalid, nullptr, nullptr, nullptr};
5853+
}
5854+
static OffsetResult value(Value *V) {
5855+
return {OffsetKind::Value, V, nullptr, nullptr};
5856+
}
5857+
static OffsetResult select(Value *Cond, Value *TrueV, Value *FalseV) {
5858+
return {OffsetKind::Select, Cond, TrueV, FalseV};
5859+
}
5860+
bool isValid() const { return Kind != OffsetKind::Invalid; }
5861+
Value *materialize(InstCombiner::BuilderTy &Builder) const {
5862+
switch (Kind) {
5863+
case OffsetKind::Invalid:
5864+
llvm_unreachable("Invalid offset result");
5865+
case OffsetKind::Value:
5866+
return V0;
5867+
case OffsetKind::Select:
5868+
return Builder.CreateSelect(V0, V1, V2);
5869+
default:
5870+
llvm_unreachable("Unknown offset result kind");
5871+
}
5872+
}
5873+
};
5874+
5875+
/// Offset both sides of an equality icmp to see if we can save some
5876+
/// instructions: icmp eq/ne X, Y -> icmp eq/ne X op Z, Y op Z.
5877+
/// Note: This operation should not introduce poison.
5878+
static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
5879+
InstCombiner::BuilderTy &Builder,
5880+
const SimplifyQuery &SQ) {
5881+
assert(I.isEquality() && "Expected an equality icmp");
5882+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
5883+
if (!Op0->getType()->isIntOrIntVectorTy())
5884+
return nullptr;
5885+
5886+
SmallVector<OffsetOp, 4> OffsetOps;
5887+
if (Op0->hasOneUse())
5888+
collectOffsetOp(Op0, OffsetOps, /*AllowRecursion=*/true);
5889+
if (Op1->hasOneUse())
5890+
collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);
5891+
5892+
auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
5893+
Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
5894+
// Avoid infinite loops by checking if RHS is an identity for the BinOp.
5895+
if (!Simplified || Simplified == V)
5896+
return nullptr;
5897+
return Simplified;
5898+
};
5899+
5900+
auto ApplyOffset = [&](Value *V, unsigned BinOpc,
5901+
Value *RHS) -> OffsetResult {
5902+
if (auto *Sel = dyn_cast<SelectInst>(V)) {
5903+
if (!Sel->hasOneUse())
5904+
return OffsetResult::invalid();
5905+
Value *TrueVal = ApplyOffsetImpl(Sel->getTrueValue(), BinOpc, RHS);
5906+
if (!TrueVal)
5907+
return OffsetResult::invalid();
5908+
Value *FalseVal = ApplyOffsetImpl(Sel->getFalseValue(), BinOpc, RHS);
5909+
if (!FalseVal)
5910+
return OffsetResult::invalid();
5911+
return OffsetResult::select(Sel->getCondition(), TrueVal, FalseVal);
5912+
}
5913+
if (Value *Simplified = ApplyOffsetImpl(V, BinOpc, RHS))
5914+
return OffsetResult::value(Simplified);
5915+
return OffsetResult::invalid();
5916+
};
5917+
5918+
for (auto [BinOp, RHS] : OffsetOps) {
5919+
auto BinOpc = static_cast<unsigned>(BinOp);
5920+
5921+
auto Op0Result = ApplyOffset(Op0, BinOpc, RHS);
5922+
if (!Op0Result.isValid())
5923+
continue;
5924+
auto Op1Result = ApplyOffset(Op1, BinOpc, RHS);
5925+
if (!Op1Result.isValid())
5926+
continue;
5927+
5928+
Value *NewLHS = Op0Result.materialize(Builder);
5929+
Value *NewRHS = Op1Result.materialize(Builder);
5930+
return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
5931+
}
5932+
5933+
return nullptr;
5934+
}
5935+
58115936
Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
58125937
if (!I.isEquality())
58135938
return nullptr;
@@ -6054,6 +6179,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
60546179
: ConstantInt::getNullValue(A->getType()));
60556180
}
60566181

6182+
if (auto *Res = foldICmpEqualityWithOffset(
6183+
I, Builder, getSimplifyQuery().getWithInstruction(&I)))
6184+
return Res;
6185+
60576186
return nullptr;
60586187
}
60596188

llvm/test/Transforms/InstCombine/icmp-add.ll

+2-4
Original file line numberDiff line numberDiff line change
@@ -2380,8 +2380,7 @@ define <2 x i1> @icmp_eq_add_non_splat(<2 x i32> %a) {
23802380

23812381
define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
23822382
; CHECK-LABEL: @icmp_eq_add_undef2(
2383-
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
2384-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 undef>
2383+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 undef>
23852384
; CHECK-NEXT: ret <2 x i1> [[CMP]]
23862385
;
23872386
%add = add <2 x i32> %a, <i32 5, i32 5>
@@ -2391,8 +2390,7 @@ define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
23912390

23922391
define <2 x i1> @icmp_eq_add_non_splat2(<2 x i32> %a) {
23932392
; CHECK-LABEL: @icmp_eq_add_non_splat2(
2394-
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
2395-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 11>
2393+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 6>
23962394
; CHECK-NEXT: ret <2 x i1> [[CMP]]
23972395
;
23982396
%add = add <2 x i32> %a, <i32 5, i32 5>

llvm/test/Transforms/InstCombine/icmp-equality-xor.ll

+1-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,7 @@ define i1 @foo2(i32 %x, i32 %y) {
136136
define <2 x i1> @foo3(<2 x i8> %x) {
137137
; CHECK-LABEL: @foo3(
138138
; CHECK-NEXT: entry:
139-
; CHECK-NEXT: [[XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -2, i8 -1>
140-
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[XOR]], <i8 9, i8 79>
139+
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[X:%.*]], <i8 -9, i8 -80>
141140
; CHECK-NEXT: ret <2 x i1> [[CMP]]
142141
;
143142
entry:

llvm/test/Transforms/InstCombine/icmp-select.ll

+15-23
Original file line numberDiff line numberDiff line change
@@ -632,12 +632,10 @@ define i1 @icmp_slt_select(i1 %cond, i32 %a, i32 %b) {
632632
define i1 @discr_eq(i8 %a, i8 %b) {
633633
; CHECK-LABEL: @discr_eq(
634634
; CHECK-NEXT: entry:
635-
; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -2
636-
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
637-
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
638-
; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2
639-
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
640-
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1
635+
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
636+
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
637+
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
638+
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3
641639
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
642640
; CHECK-NEXT: ret i1 [[RES]]
643641
;
@@ -655,12 +653,10 @@ entry:
655653
define i1 @discr_ne(i8 %a, i8 %b) {
656654
; CHECK-LABEL: @discr_ne(
657655
; CHECK-NEXT: entry:
658-
; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -2
659-
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
660-
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
661-
; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2
662-
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
663-
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1
656+
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
657+
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
658+
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
659+
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3
664660
; CHECK-NEXT: [[RES:%.*]] = icmp ne i8 [[SEL1]], [[SEL2]]
665661
; CHECK-NEXT: ret i1 [[RES]]
666662
;
@@ -678,12 +674,10 @@ entry:
678674
define i1 @discr_xor_eq(i8 %a, i8 %b) {
679675
; CHECK-LABEL: @discr_xor_eq(
680676
; CHECK-NEXT: entry:
681-
; CHECK-NEXT: [[XOR1:%.*]] = xor i8 [[A:%.*]], -3
682-
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
683-
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[XOR1]], i8 1
684-
; CHECK-NEXT: [[XOR2:%.*]] = xor i8 [[B:%.*]], -3
685-
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
686-
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[XOR2]], i8 1
677+
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
678+
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
679+
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 -4
680+
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 -4
687681
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
688682
; CHECK-NEXT: ret i1 [[RES]]
689683
;
@@ -701,11 +695,9 @@ entry:
701695
define i1 @discr_eq_simple(i8 %a, i8 %b) {
702696
; CHECK-LABEL: @discr_eq_simple(
703697
; CHECK-NEXT: entry:
704-
; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -2
705-
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
706-
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
707-
; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2
708-
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[ADD2]]
698+
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
699+
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
700+
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[ADD2:%.*]]
709701
; CHECK-NEXT: ret i1 [[RES]]
710702
;
711703
entry:

0 commit comments

Comments
 (0)