Skip to content

[InstCombine] Offset both sides of an equality icmp #134086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 29, 2025
126 changes: 126 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5808,6 +5808,128 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
return nullptr;
}

/// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
bool AllowRecursion) {
Instruction *Inst = dyn_cast<Instruction>(V);
if (!Inst || !Inst->hasOneUse())
return;

switch (Inst->getOpcode()) {
case Instruction::Add:
Offsets.emplace_back(Instruction::Sub, Inst->getOperand(1));
Offsets.emplace_back(Instruction::Sub, Inst->getOperand(0));
break;
case Instruction::Sub:
Offsets.emplace_back(Instruction::Add, Inst->getOperand(1));
break;
case Instruction::Xor:
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
break;
case Instruction::Select:
if (AllowRecursion) {
collectOffsetOp(Inst->getOperand(1), Offsets, /*AllowRecursion=*/false);
collectOffsetOp(Inst->getOperand(2), Offsets, /*AllowRecursion=*/false);
}
break;
default:
break;
}
}

enum class OffsetKind { Invalid, Value, Select };

struct OffsetResult {
OffsetKind Kind;
Value *V0, *V1, *V2;

static OffsetResult invalid() {
return {OffsetKind::Invalid, nullptr, nullptr, nullptr};
}
static OffsetResult value(Value *V) {
return {OffsetKind::Value, V, nullptr, nullptr};
}
static OffsetResult select(Value *Cond, Value *TrueV, Value *FalseV) {
return {OffsetKind::Select, Cond, TrueV, FalseV};
}
bool isValid() const { return Kind != OffsetKind::Invalid; }
Value *materialize(InstCombiner::BuilderTy &Builder) const {
switch (Kind) {
case OffsetKind::Invalid:
llvm_unreachable("Invalid offset result");
case OffsetKind::Value:
return V0;
case OffsetKind::Select:
return Builder.CreateSelect(V0, V1, V2);
}
}
};

/// Offset both sides of an equality icmp to see if we can save some
/// instructions: icmp eq/ne X, Y -> icmp eq/ne X op Z, Y op Z.
/// Note: This operation should not introduce poison.
static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
InstCombiner::BuilderTy &Builder,
const SimplifyQuery &SQ) {
assert(I.isEquality() && "Expected an equality icmp");
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (!Op0->getType()->isIntOrIntVectorTy())
return nullptr;

SmallVector<OffsetOp, 4> OffsetOps;
collectOffsetOp(Op0, OffsetOps, /*AllowRecursion=*/true);
collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);

auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
// Avoid infinite loops by checking if RHS is an identity for the BinOp.
if (!Simplified || Simplified == V)
return nullptr;
// Reject constant expressions as they don't simplify things.
if (isa<Constant>(Simplified) && !match(Simplified, m_ImmConstant()))
return nullptr;
// Check if the transformation introduces poison.
return impliesPoison(RHS, V) ? Simplified : nullptr;
};

auto ApplyOffset = [&](Value *V, unsigned BinOpc,
Value *RHS) -> OffsetResult {
if (auto *Sel = dyn_cast<SelectInst>(V)) {
if (!Sel->hasOneUse())
return OffsetResult::invalid();
Value *TrueVal = ApplyOffsetImpl(Sel->getTrueValue(), BinOpc, RHS);
if (!TrueVal)
return OffsetResult::invalid();
Value *FalseVal = ApplyOffsetImpl(Sel->getFalseValue(), BinOpc, RHS);
if (!FalseVal)
return OffsetResult::invalid();
return OffsetResult::select(Sel->getCondition(), TrueVal, FalseVal);
}
if (Value *Simplified = ApplyOffsetImpl(V, BinOpc, RHS))
return OffsetResult::value(Simplified);
return OffsetResult::invalid();
};

for (auto [BinOp, RHS] : OffsetOps) {
auto BinOpc = static_cast<unsigned>(BinOp);

auto Op0Result = ApplyOffset(Op0, BinOpc, RHS);
if (!Op0Result.isValid())
continue;
auto Op1Result = ApplyOffset(Op1, BinOpc, RHS);
if (!Op1Result.isValid())
continue;

Value *NewLHS = Op0Result.materialize(Builder);
Value *NewRHS = Op1Result.materialize(Builder);
return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
}

return nullptr;
}

Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
if (!I.isEquality())
return nullptr;
Expand Down Expand Up @@ -6054,6 +6176,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
: ConstantInt::getNullValue(A->getType()));
}

if (auto *Res = foldICmpEqualityWithOffset(
I, Builder, getSimplifyQuery().getWithInstruction(&I)))
return Res;

return nullptr;
}

Expand Down
6 changes: 2 additions & 4 deletions llvm/test/Transforms/InstCombine/icmp-add.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2380,8 +2380,7 @@ define <2 x i1> @icmp_eq_add_non_splat(<2 x i32> %a) {

define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
; CHECK-LABEL: @icmp_eq_add_undef2(
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 undef>
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 undef>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%add = add <2 x i32> %a, <i32 5, i32 5>
Expand All @@ -2391,8 +2390,7 @@ define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {

define <2 x i1> @icmp_eq_add_non_splat2(<2 x i32> %a) {
; CHECK-LABEL: @icmp_eq_add_non_splat2(
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 11>
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 6>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%add = add <2 x i32> %a, <i32 5, i32 5>
Expand Down
3 changes: 1 addition & 2 deletions llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ define i1 @foo2(i32 %x, i32 %y) {
define <2 x i1> @foo3(<2 x i8> %x) {
; CHECK-LABEL: @foo3(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -2, i8 -1>
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[XOR]], <i8 9, i8 79>
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[X:%.*]], <i8 -9, i8 -80>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
entry:
Expand Down
Loading