diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 55afe1258159a..72f35f4e8e982 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5808,6 +5808,155 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
   return nullptr;
 }
 
+/// A candidate offset: the binary opcode to apply and its right-hand operand.
+using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
+
+/// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
+///
+/// \param V operand to inspect; ignored unless it is a one-use instruction.
+/// \param Offsets receives the candidates (appended, never cleared).
+/// \param AllowRecursion if true, look through one level of select by scanning
+///        its operands 1 and 2 with recursion disabled.
+static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
+                            bool AllowRecursion) {
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst || !Inst->hasOneUse())
+    return;
+
+  switch (Inst->getOpcode()) {
+  case Instruction::Add:
+    // X + Y: subtracting either operand is a candidate.
+    if (isGuaranteedNotToBeUndefOrPoison(Inst->getOperand(1)))
+      Offsets.emplace_back(Instruction::Sub, Inst->getOperand(1));
+    if (isGuaranteedNotToBeUndefOrPoison(Inst->getOperand(0)))
+      Offsets.emplace_back(Instruction::Sub, Inst->getOperand(0));
+    break;
+  case Instruction::Sub:
+    // X - Y: only adding Y back is a candidate.
+    if (isGuaranteedNotToBeUndefOrPoison(Inst->getOperand(1)))
+      Offsets.emplace_back(Instruction::Add, Inst->getOperand(1));
+    break;
+  case Instruction::Xor:
+    // X ^ Y: xor-ing with either operand is a candidate.
+    if (isGuaranteedNotToBeUndefOrPoison(Inst->getOperand(1)))
+      Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
+    if (isGuaranteedNotToBeUndefOrPoison(Inst->getOperand(0)))
+      Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
+    break;
+  case Instruction::Select:
+    if (AllowRecursion) {
+      collectOffsetOp(Inst->getOperand(1), Offsets, /*AllowRecursion=*/false);
+      collectOffsetOp(Inst->getOperand(2), Offsets, /*AllowRecursion=*/false);
+    }
+    break;
+  default:
+    break;
+  }
+}
+
+/// Shape of the value produced by offsetting one icmp operand.
+enum class OffsetKind { Invalid, Value, Select };
+
+/// Result of offsetting one icmp operand. A select is only described here and
+/// is created by materialize(), which the caller invokes once both operands
+/// have produced a valid result.
+struct OffsetResult {
+  OffsetKind Kind;
+  // Value: V0 is the result. Select: V0/V1/V2 are condition/true/false values.
+  Value *V0, *V1, *V2;
+
+  static OffsetResult invalid() {
+    return {OffsetKind::Invalid, nullptr, nullptr, nullptr};
+  }
+  static OffsetResult value(Value *V) {
+    return {OffsetKind::Value, V, nullptr, nullptr};
+  }
+  static OffsetResult select(Value *Cond, Value *TrueV, Value *FalseV) {
+    return {OffsetKind::Select, Cond, TrueV, FalseV};
+  }
+  bool isValid() const { return Kind != OffsetKind::Invalid; }
+
+  /// Return the offset value, creating the select with \p Builder if needed.
+  /// Must not be called on an invalid result.
+  Value *materialize(InstCombiner::BuilderTy &Builder) const {
+    switch (Kind) {
+    case OffsetKind::Invalid:
+      llvm_unreachable("Invalid offset result");
+    case OffsetKind::Value:
+      return V0;
+    case OffsetKind::Select:
+      return Builder.CreateSelect(V0, V1, V2);
+    }
+    // Every enumerator returns above; without this, compilers that do not
+    // treat the switch as exhaustive warn that control reaches the end.
+    llvm_unreachable("Unknown OffsetKind enum");
+  }
+};
+
+/// Offset both sides of an equality icmp to see if we can save some
+/// instructions: icmp eq/ne X, Y -> icmp eq/ne X op Z, Y op Z.
+/// Note: This operation should not introduce poison.
+///
+/// \returns a new icmp on the offset operands, or nullptr if no candidate
+/// offset simplifies both operands.
+static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
+                                               InstCombiner::BuilderTy &Builder,
+                                               const SimplifyQuery &SQ) {
+  assert(I.isEquality() && "Expected an equality icmp");
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  if (!Op0->getType()->isIntOrIntVectorTy())
+    return nullptr;
+
+  SmallVector<OffsetOp> OffsetOps;
+  collectOffsetOp(Op0, OffsetOps, /*AllowRecursion=*/true);
+  collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);
+
+  auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
+    Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
+    // Avoid infinite loops by checking if RHS is an identity for the BinOp.
+    if (!Simplified || Simplified == V)
+      return nullptr;
+    return Simplified;
+  };
+
+  // A one-use select is offset arm by arm; both arms must simplify.
+  auto ApplyOffset = [&](Value *V, unsigned BinOpc,
+                         Value *RHS) -> OffsetResult {
+    if (auto *Sel = dyn_cast<SelectInst>(V)) {
+      if (!Sel->hasOneUse())
+        return OffsetResult::invalid();
+      Value *TrueVal = ApplyOffsetImpl(Sel->getTrueValue(), BinOpc, RHS);
+      if (!TrueVal)
+        return OffsetResult::invalid();
+      Value *FalseVal = ApplyOffsetImpl(Sel->getFalseValue(), BinOpc, RHS);
+      if (!FalseVal)
+        return OffsetResult::invalid();
+      return OffsetResult::select(Sel->getCondition(), TrueVal, FalseVal);
+    }
+    if (Value *Simplified = ApplyOffsetImpl(V, BinOpc, RHS))
+      return OffsetResult::value(Simplified);
+    return OffsetResult::invalid();
+  };
+
+  // Use the first candidate that simplifies both operands.
+  for (auto [BinOp, RHS] : OffsetOps) {
+    auto BinOpc = static_cast<unsigned>(BinOp);
+
+    auto Op0Result = ApplyOffset(Op0, BinOpc, RHS);
+    if (!Op0Result.isValid())
+      continue;
+    auto Op1Result = ApplyOffset(Op1, BinOpc, RHS);
+    if (!Op1Result.isValid())
+      continue;
+
+    Value *NewLHS = Op0Result.materialize(Builder);
+    Value *NewRHS = Op1Result.materialize(Builder);
+    return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
   if (!I.isEquality())
     return nullptr;
@@ -6054,6 +6203,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
                                : ConstantInt::getNullValue(A->getType()));
   }
 
+  if (auto *Res = foldICmpEqualityWithOffset(
+          I, Builder, getSimplifyQuery().getWithInstruction(&I)))
+    return Res;
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/icmp-add.ll b/llvm/test/Transforms/InstCombine/icmp-add.ll index a8cdf80948a84..1a41c1f3e1045 100644 --- a/llvm/test/Transforms/InstCombine/icmp-add.ll +++ b/llvm/test/Transforms/InstCombine/icmp-add.ll @@ -2380,8 +2380,7 @@ define <2 x i1> @icmp_eq_add_non_splat(<2 x i32> %a) { define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) { ; CHECK-LABEL: @icmp_eq_add_undef2( -; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> 
[[A:%.*]], splat (i32 5) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %add = add <2 x i32> %a, @@ -2391,8 +2390,7 @@ define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) { define <2 x i1> @icmp_eq_add_non_splat2(<2 x i32> %a) { ; CHECK-LABEL: @icmp_eq_add_non_splat2( -; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; %add = add <2 x i32> %a, diff --git a/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll b/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll index b8e8ed0eaf1da..b0b633fba06be 100644 --- a/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll +++ b/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll @@ -136,8 +136,7 @@ define i1 @foo2(i32 %x, i32 %y) { define <2 x i1> @foo3(<2 x i8> %x) { ; CHECK-LABEL: @foo3( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XOR:%.*]] = xor <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[XOR]], +; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i1> [[CMP]] ; entry: diff --git a/llvm/test/Transforms/InstCombine/icmp-select.ll b/llvm/test/Transforms/InstCombine/icmp-select.ll index 0bdbc88ba67c6..ef6c149fd6558 100644 --- a/llvm/test/Transforms/InstCombine/icmp-select.ll +++ b/llvm/test/Transforms/InstCombine/icmp-select.ll @@ -628,3 +628,206 @@ define i1 @icmp_slt_select(i1 %cond, i32 %a, i32 %b) { %res = icmp slt i32 %lhs, %rhs ret i1 %res } + +define i1 @discr_eq(i8 %a, i8 %b) { +; CHECK-LABEL: @discr_eq( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1 +; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 
[[TMP0]], [[TMP1]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %add1 = add i8 %a, -2 + %cmp1 = icmp ugt i8 %a, 1 + %sel1 = select i1 %cmp1, i8 %add1, i8 1 + %add2 = add i8 %b, -2 + %cmp2 = icmp ugt i8 %b, 1 + %sel2 = select i1 %cmp2, i8 %add2, i8 1 + %res = icmp eq i8 %sel1, %sel2 + ret i1 %res +} + +define i1 @discr_ne(i8 %a, i8 %b) { +; CHECK-LABEL: @discr_ne( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1 +; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3 +; CHECK-NEXT: [[RES:%.*]] = icmp ne i8 [[TMP0]], [[TMP1]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %add1 = add i8 %a, -2 + %cmp1 = icmp ugt i8 %a, 1 + %sel1 = select i1 %cmp1, i8 %add1, i8 1 + %add2 = add i8 %b, -2 + %cmp2 = icmp ugt i8 %b, 1 + %sel2 = select i1 %cmp2, i8 %add2, i8 1 + %res = icmp ne i8 %sel1, %sel2 + ret i1 %res +} + +define i1 @discr_xor_eq(i8 %a, i8 %b) { +; CHECK-LABEL: @discr_xor_eq( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1 +; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 -4 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 -4 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %xor1 = xor i8 %a, -3 + %cmp1 = icmp ugt i8 %a, 1 + %sel1 = select i1 %cmp1, i8 %xor1, i8 1 + %xor2 = xor i8 %b, -3 + %cmp2 = icmp ugt i8 %b, 1 + %sel2 = select i1 %cmp2, i8 %xor2, i8 1 + %res = icmp eq i8 %sel1, %sel2 + ret i1 %res +} + +define i1 @discr_eq_simple(i8 %a, i8 %b) { +; CHECK-LABEL: @discr_eq_simple( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1 +; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[B:%.*]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %add1 = add i8 %a, -2 + %cmp1 = 
icmp ugt i8 %a, 1 + %sel1 = select i1 %cmp1, i8 %add1, i8 1 + %add2 = add i8 %b, -2 + %res = icmp eq i8 %sel1, %add2 + ret i1 %res +} + +define i1 @discr_eq_add_commuted(i8 noundef %a, i8 %b, i8 %c, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @discr_eq_add_commuted( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[COND1:%.*]], i8 [[B:%.*]], i8 0 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND2:%.*]], i8 [[C:%.*]], i8 [[B]] +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %add1 = add i8 %a, %b + %sel1 = select i1 %cond1, i8 %add1, i8 %a + %add2 = add i8 %c, %a + %sel2 = select i1 %cond2, i8 %add2, i8 %add1 + %res = icmp eq i8 %sel1, %sel2 + ret i1 %res +} + +define i1 @discr_eq_sub(i8 noundef %a, i8 %b, i8 %c, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @discr_eq_sub( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[COND1:%.*]], i8 [[B:%.*]], i8 0 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND2:%.*]], i8 [[C:%.*]], i8 0 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %neg = sub i8 0, %a + %sub1 = sub i8 %b, %a + %sel1 = select i1 %cond1, i8 %sub1, i8 %neg + %sub2 = sub i8 %c, %a + %sel2 = select i1 %cond2, i8 %sub2, i8 %neg + %res = icmp eq i8 %sel1, %sel2 + ret i1 %res +} + +; Negative tests + +define i1 @discr_eq_multi_use(i8 %a, i8 %b) { +; CHECK-LABEL: @discr_eq_multi_use( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -2 +; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1 +; CHECK-NEXT: call void @use(i8 [[SEL1]]) +; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %add1 = add i8 %a, -2 + %cmp1 = icmp ugt i8 %a, 1 + %sel1 = select 
i1 %cmp1, i8 %add1, i8 1 + call void @use(i8 %sel1) + %add2 = add i8 %b, -2 + %cmp2 = icmp ugt i8 %b, 1 + %sel2 = select i1 %cmp2, i8 %add2, i8 1 + %res = icmp eq i8 %sel1, %sel2 + ret i1 %res +} + +define i1 @discr_eq_failed_to_simplify(i8 %a, i8 %b) { +; CHECK-LABEL: @discr_eq_failed_to_simplify( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -3 +; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1 +; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %add1 = add i8 %a, -3 + %cmp1 = icmp ugt i8 %a, 1 + %sel1 = select i1 %cmp1, i8 %add1, i8 1 + %add2 = add i8 %b, -2 + %cmp2 = icmp ugt i8 %b, 1 + %sel2 = select i1 %cmp2, i8 %add2, i8 1 + %res = icmp eq i8 %sel1, %sel2 + ret i1 %res +} + +define <2 x i1> @discr_eq_simple_vec(<2 x i8> %a, <2 x i8> %b, i1 %cond) { +; CHECK-LABEL: @discr_eq_simple_vec( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[ADD1:%.*]] = add <2 x i8> [[A:%.*]], +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND:%.*]], <2 x i8> [[ADD1]], <2 x i8> splat (i8 1) +; CHECK-NEXT: [[ADD2:%.*]] = add <2 x i8> [[B:%.*]], +; CHECK-NEXT: [[RES:%.*]] = icmp eq <2 x i8> [[SEL1]], [[ADD2]] +; CHECK-NEXT: ret <2 x i1> [[RES]] +; +entry: + %add1 = add <2 x i8> %a, + %sel1 = select i1 %cond, <2 x i8> %add1, <2 x i8> splat(i8 1) + %add2 = add <2 x i8> %b, + %res = icmp eq <2 x i8> %sel1, %add2 + ret <2 x i1> %res +} + +define i1 @discr_eq_sub_commuted(i8 noundef %a, i8 %b, i8 %c, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @discr_eq_sub_commuted( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[NEG:%.*]] = sub i8 0, [[A:%.*]] +; CHECK-NEXT: [[SUB1:%.*]] = sub i8 [[A]], [[B:%.*]] +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1:%.*]], i8 [[SUB1]], i8 [[NEG]] +; CHECK-NEXT: [[SUB2:%.*]] = sub 
i8 [[A]], [[C:%.*]] +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[COND2:%.*]], i8 [[SUB2]], i8 [[NEG]] +; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]] +; CHECK-NEXT: ret i1 [[RES]] +; +entry: + %neg = sub i8 0, %a + %sub1 = sub i8 %a, %b + %sel1 = select i1 %cond1, i8 %sub1, i8 %neg + %sub2 = sub i8 %a, %c + %sel2 = select i1 %cond2, i8 %sub2, i8 %neg + %res = icmp eq i8 %sel1, %sel2 + ret i1 %res +}