diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index 026ee035fcf54..ca2c904c4e3a7 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -1299,6 +1299,32 @@ inline BinaryOpc_match m_Not(const ValTy &V) { return m_Xor(V, m_AllOnes()); } +struct SpecificNeg_match { + SDValue V; + + explicit SpecificNeg_match(SDValue V) : V(V) {} + + template + bool match(const MatchContext &Ctx, SDValue N) { + if (sd_context_match(N, Ctx, m_Neg(m_Specific(V)))) + return true; + + return ISD::matchBinaryPredicate( + V, N, + [](ConstantSDNode *LHS, ConstantSDNode *RHS) { + return APInt::isSameValue(-LHS->getAPIntValue(), + RHS->getAPIntValue()); + }, + /*AllowUndefs=*/false, /*AllowTypeMismatch=*/false); + } +}; + +/// Match a negation of a specific value V, either as sub(0, V) or as +/// constant(s) that are the negation of V's constant(s). +inline SpecificNeg_match m_SpecificNeg(SDValue V) { + return SpecificNeg_match(V); +} + template struct ReassociatableOpc_match { unsigned Opcode; std::tuple Patterns; diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 1afc034dd7b9e..67295d02929a2 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -1062,3 +1062,52 @@ TEST_F(SelectionDAGPatternMatchTest, MatchSelectCC) { m_Specific(TVal), m_Specific(FVal), m_CondCode(CC)))); } + +TEST_F(SelectionDAGPatternMatchTest, MatchSpecificNeg) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + auto VecVT = EVT::getVectorVT(Context, Int32VT, 4); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + + using namespace SDPatternMatch; + + SDValue Neg = DAG->getNegative(Op0, DL, Int32VT); + EXPECT_TRUE(sd_match(Neg, m_SpecificNeg(Op0))); + EXPECT_TRUE(sd_match(Neg, m_Neg(m_Specific(Op0)))); + + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); + EXPECT_FALSE(sd_match(Neg, m_SpecificNeg(Op1))); + + SDValue Const5 = DAG->getConstant(5, DL, Int32VT); + SDValue ConstNeg5 = DAG->getConstant(APInt(32, -5, true), DL, Int32VT); + EXPECT_TRUE(sd_match(ConstNeg5, m_SpecificNeg(Const5))); + EXPECT_TRUE(sd_match(Const5, m_SpecificNeg(ConstNeg5))); + + SDValue Const3 = DAG->getConstant(3, DL, Int32VT); + EXPECT_FALSE(sd_match(ConstNeg5, m_SpecificNeg(Const3))); + + SDValue VecConst5 = DAG->getSplatBuildVector(VecVT, DL, Const5); + SDValue VecConstNeg5 = DAG->getSplatBuildVector(VecVT, DL, ConstNeg5); + EXPECT_TRUE(sd_match(VecConstNeg5, m_SpecificNeg(VecConst5))); + EXPECT_TRUE(sd_match(VecConst5, m_SpecificNeg(VecConstNeg5))); + + SDValue Const1 = DAG->getConstant(1, DL, Int32VT); + SDValue Const2 = DAG->getConstant(2, DL, Int32VT); + SDValue ConstNeg1 = DAG->getConstant(APInt(32, -1, true), DL, Int32VT); + SDValue ConstNeg2 = DAG->getConstant(APInt(32, -2, true), DL, Int32VT); + SDValue ConstNeg3 = DAG->getConstant(APInt(32, -3, true), DL, Int32VT); + SmallVector PosOps = {Const1, Const2, Const5, Const3}; + SmallVector NegOps = {ConstNeg1, ConstNeg2, ConstNeg5, ConstNeg3}; + SDValue VecPos = DAG->getBuildVector(VecVT, DL, PosOps); + SDValue VecNeg = DAG->getBuildVector(VecVT, DL, NegOps); + EXPECT_TRUE(sd_match(VecNeg, m_SpecificNeg(VecPos))); + EXPECT_TRUE(sd_match(VecPos, m_SpecificNeg(VecNeg))); + + SmallVector WrongOps = {ConstNeg1, ConstNeg2, Const5, ConstNeg5}; + SDValue VecWrong = DAG->getBuildVector(VecVT, DL, WrongOps); + EXPECT_FALSE(sd_match(VecWrong, m_SpecificNeg(VecPos))); + + SDValue Zero = DAG->getConstant(0, DL, Int32VT); + EXPECT_TRUE(sd_match(Zero, m_SpecificNeg(Zero))); +}