Skip to content

Commit

Permalink
Reland "[LoopVectorizer] Add support for partial reductions" (llvm#120721)
Browse files Browse the repository at this point in the history

This re-lands the reverted llvm#92418 

When the VF is small enough so that dividing the VF by the scaling
factor results in 1, the reduction phi execution thinks the VF is scalar
and sets the reduction's output as a scalar value, tripping assertions
expecting a vector value. The latest commit in this PR fixes that by
using `State.VF` in the scalar check, rather than the divided VF.

---------

Co-authored-by: Nicholas Guy <[email protected]>
  • Loading branch information
SamTebbs33 and NickGuy-Arm authored Dec 24, 2024
1 parent b2073fb commit c858bf6
Show file tree
Hide file tree
Showing 16 changed files with 3,927 additions and 30 deletions.
39 changes: 39 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ typedef TargetTransformInfo TTI;
/// for IR-level transformations.
class TargetTransformInfo {
public:
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };

/// Get the kind of extension that an instruction represents.
static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction *I);

/// Construct a TTI object using a type implementing the \c Concept
/// API below.
///
Expand Down Expand Up @@ -1280,6 +1286,18 @@ class TargetTransformInfo {
/// \return if target want to issue a prefetch in address space \p AS.
bool shouldPrefetchAddressSpace(unsigned AS) const;

/// \return The cost of a partial reduction, which is a reduction from a
/// vector to another vector with fewer elements of larger size. They are
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
/// takes an accumulator and a binary operation operand that itself is fed by
/// two extends. An example of an operation that uses a partial reduction is a
/// dot product, which reduces a vector to another of 4 times fewer elements.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp = std::nullopt) const;

/// \return The maximum interleave factor that any transform should try to
/// perform for this target. This number depends on the level of parallelism
/// and the number of execution units in the CPU.
Expand Down Expand Up @@ -2107,6 +2125,18 @@ class TargetTransformInfo::Concept {
/// \return if target want to issue a prefetch in address space \p AS.
virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;

/// \return The cost of a partial reduction, which is a reduction from a
/// vector to another vector with fewer elements of larger size. They are
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
/// takes an accumulator and a binary operation operand that itself is fed by
/// two extends. An example of an operation that uses a partial reduction is a
/// dot product, which reduces a vector to another of 4 times fewer elements.
virtual InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp) const = 0;

virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
virtual InstructionCost getArithmeticInstrCost(
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
Expand Down Expand Up @@ -2786,6 +2816,15 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.shouldPrefetchAddressSpace(AS);
}

// Model adaptor: forwards the partial-reduction cost query to the wrapped
// implementation (Impl). The std::nullopt default mirrors the public
// TargetTransformInfo entry point so both can be called identically.
InstructionCost getPartialReductionCost(
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp = std::nullopt) const override {
return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF,
OpAExtend, OpBExtend, BinOp);
}

unsigned getMaxInterleaveFactor(ElementCount VF) override {
return Impl.getMaxInterleaveFactor(VF);
}
Expand Down
9 changes: 9 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,15 @@ class TargetTransformInfoImplBase {
bool enableWritePrefetching() const { return false; }
bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }

// Default implementation: a target that does not override this hook reports
// an invalid cost, i.e. it has no profitable lowering for partial reductions,
// so the vectorizer will not form them. All parameters are intentionally
// unused here.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp = std::nullopt) const {
return InstructionCost::getInvalid();
}

unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }

InstructionCost getArithmeticInstrCost(
Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,14 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
return TTIImpl->shouldPrefetchAddressSpace(AS);
}

// Public entry point: dispatches through the type-erased TTIImpl concept.
// Returns InstructionCost::getInvalid() unless the target overrides the hook
// (see TargetTransformInfoImplBase).
InstructionCost TargetTransformInfo::getPartialReductionCost(
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
PartialReductionExtendKind OpAExtend, PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp) const {
return TTIImpl->getPartialReductionCost(Opcode, InputType, AccumType, VF,
OpAExtend, OpBExtend, BinOp);
}

unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
return TTIImpl->getMaxInterleaveFactor(VF);
}
Expand Down Expand Up @@ -974,6 +982,15 @@ InstructionCost TargetTransformInfo::getShuffleCost(
return Cost;
}

// Classify the extension performed by \p I for partial-reduction costing:
// sext -> PR_SignExtend, zext -> PR_ZeroExtend, anything else -> PR_None.
TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
  if (isa<ZExtInst>(I))
    return PR_ZeroExtend;
  return isa<SExtInst>(I) ? PR_SignExtend : PR_None;
}

TTI::CastContextHint
TargetTransformInfo::getCastContextHint(const Instruction *I) {
if (!I)
Expand Down
56 changes: 56 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/InstructionCost.h"
#include <cstdint>
#include <optional>

Expand Down Expand Up @@ -357,6 +358,61 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return BaseT::isLegalNTLoad(DataType, Alignment);
}

/// Return the cost of lowering a partial reduction on AArch64, or an invalid
/// cost when no dot-product style lowering exists for the given combination
/// of opcode, element types, VF, extends and inner binary op.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
                        ElementCount VF,
                        TTI::PartialReductionExtendKind OpAExtend,
                        TTI::PartialReductionExtendKind OpBExtend,
                        std::optional<unsigned> BinOp) const {

  InstructionCost Invalid = InstructionCost::getInvalid();

  // Only add-reductions whose inner operation is a multiply of two extended
  // values are supported.
  if (Opcode != Instruction::Add)
    return Invalid;
  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
    return Invalid;
  if (!BinOp || *BinOp != Instruction::Mul)
    return Invalid;

  // The target must actually provide the relevant instructions for the
  // requested vectorization style (SVE for scalable VFs, NEON dot-product
  // for fixed VFs).
  if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
    return Invalid;
  if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
    return Invalid;

  EVT InputEVT = EVT::getEVT(InputType);
  EVT AccumEVT = EVT::getEVT(AccumType);
  unsigned MinElts = VF.getKnownMinValue();
  InstructionCost Cost(TTI::TCC_Basic);

  if (InputEVT == MVT::i8) {
    if (MinElts == 8) {
      if (AccumEVT == MVT::i32)
        Cost *= 2; // More expensive input/accumulator pairing.
      else if (AccumEVT != MVT::i64)
        return Invalid;
    } else if (MinElts == 16) {
      if (AccumEVT == MVT::i64)
        Cost *= 2; // More expensive input/accumulator pairing.
      else if (AccumEVT != MVT::i32)
        return Invalid;
    } else {
      return Invalid;
    }
    return Cost;
  }

  if (InputEVT == MVT::i16) {
    // FIXME: Allow i32 accumulator but increase cost, as we would extend
    // it to i64.
    if (MinElts != 8 || AccumEVT != MVT::i64)
      return Invalid;
    return Cost;
  }

  return Invalid;
}

bool enableOrderedReductions() const { return true; }

InstructionCost getInterleavedMemoryOpCost(
Expand Down
136 changes: 132 additions & 4 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7605,6 +7605,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
}
continue;
}
// The VPlan-based cost model is more accurate for partial reduction and
// comparing against the legacy cost isn't desirable.
if (isa<VPPartialReductionRecipe>(&R))
return true;
if (Instruction *UI = GetInstructionForCost(&R))
SeenInstrs.insert(UI);
}
Expand Down Expand Up @@ -8827,6 +8831,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
return Recipe;
}

/// Find every reduction in the loop that can be computed as a partial
/// (scaled) reduction and record the valid ones so recipes can be formed
/// later.
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
  // Gather all candidate chains the target reports a valid cost for.
  SmallVector<std::pair<PartialReductionChain, unsigned>, 1> Chains;
  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
    if (auto Pair = getScaledReduction(Phi, RdxDesc, Range))
      Chains.push_back(*Pair);

  // A candidate is invalid if either of its extends is used by something
  // that is not another partial reduction, because the extends are lowered
  // together with the reduction itself. Collect the chains' bin-ops up front
  // so the user checks below are cheap.
  SmallSet<User *, 4> ChainBinOps;
  for (const auto &[Chain, Scale] : Chains)
    ChainBinOps.insert(Chain.BinOp);

  auto OnlyFeedsPartialReductions = [&ChainBinOps](Instruction *Ext) {
    return all_of(Ext->users(),
                  [&](const User *U) { return ChainBinOps.contains(U); });
  };

  // Keep only the chains whose two extends feed nothing but partial
  // reductions.
  for (const auto &Pair : Chains) {
    const PartialReductionChain &Chain = Pair.first;
    if (OnlyFeedsPartialReductions(Chain.ExtendA) &&
        OnlyFeedsPartialReductions(Chain.ExtendB))
      ScaledReductionExitInstrs.insert({Chain.Reduction, Pair});
  }
}

/// Determine whether the reduction rooted at \p PHI matches the pattern
/// add(phi, mul(ext(A), ext(B))) and whether the target reports a valid
/// partial-reduction cost for at least one VF in \p Range. On success,
/// returns the matched chain together with the element-count scale factor
/// (accumulator width / input element width); otherwise std::nullopt.
std::optional<std::pair<PartialReductionChain, unsigned>>
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
const RecurrenceDescriptor &Rdx,
VFRange &Range) {
// TODO: Allow scaling reductions when predicating. The select at
// the end of the loop chooses between the phi value and most recent
// reduction result, both of which have different VFs to the active lane
// mask when scaling.
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
return std::nullopt;

// The reduction update must be a binary operator (the add).
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
if (!Update)
return std::nullopt;

// Pick the operand of the update that is not the reduction phi.
// NOTE(review): if PHI were neither operand, operand 0 would be used as-is;
// presumably the recurrence descriptor guarantees PHI is an operand of its
// own loop-exit instruction — confirm.
Value *Op = Update->getOperand(0);
if (Op == PHI)
Op = Update->getOperand(1);

// The non-phi operand must be a single-use binary op, so the whole chain
// can be absorbed into one partial-reduction recipe.
auto *BinOp = dyn_cast<BinaryOperator>(Op);
if (!BinOp || !BinOp->hasOneUse())
return std::nullopt;

// Both of the bin-op's operands must be zext or sext of some value.
using namespace llvm::PatternMatch;
Value *A, *B;
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
return std::nullopt;

Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));

// Check that the extends extend from the same type.
if (A->getType() != B->getType())
return std::nullopt;

TTI::PartialReductionExtendKind OpAExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
TTI::PartialReductionExtendKind OpBExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtB);

PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);

// Scale factor: how many input elements are folded into one accumulator
// element, derived from the bit-width ratio of the phi and input types.
unsigned TargetScaleFactor =
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
A->getType()->getPrimitiveSizeInBits());

// Accept the chain only for the VF sub-range in which the target reports a
// valid cost; getDecisionAndClampRange narrows Range accordingly.
if (LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
InstructionCost Cost = TTI->getPartialReductionCost(
Update->getOpcode(), A->getType(), PHI->getType(), VF,
OpAExtend, OpBExtend, std::make_optional(BinOp->getOpcode()));
return Cost.isValid();
},
Range))
return std::make_pair(Chain, TargetScaleFactor);

return std::nullopt;
}

VPRecipeBase *
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
ArrayRef<VPValue *> Operands,
Expand All @@ -8851,9 +8952,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
Legal->getReductionVars().find(Phi)->second;
assert(RdxDesc.getRecurrenceStartValue() ==
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
CM.isInLoopReduction(Phi),
CM.useOrderedReductions(RdxDesc));

// If the PHI is used by a partial reduction, set the scale factor.
std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
unsigned ScaleFactor = Pair ? Pair->second : 1;
PhiRecipe = new VPReductionPHIRecipe(
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
CM.useOrderedReductions(RdxDesc), ScaleFactor);
} else {
// TODO: Currently fixed-order recurrences are modeled as chains of
// first-order recurrences. If there are no users of the intermediate
Expand Down Expand Up @@ -8885,6 +8991,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
return tryToWidenMemory(Instr, Operands, Range);

if (getScaledReductionForInstr(Instr))
return tryToCreatePartialReduction(Instr, Operands);

if (!shouldWiden(Instr, Range))
return nullptr;

Expand All @@ -8905,6 +9014,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
return tryToWiden(Instr, Operands, VPBB);
}

/// Create a VPPartialReductionRecipe for \p Reduction. \p Operands holds the
/// (widened) inner bin-op and the reduction phi, in either order; they are
/// canonicalized so the phi is passed second to the recipe.
VPRecipeBase *
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
                                             ArrayRef<VPValue *> Operands) {
  assert(Operands.size() == 2 &&
         "Unexpected number of operands for partial reduction");

  VPValue *BinOp = Operands[0];
  VPValue *Phi = Operands[1];
  // getDefiningRecipe() is null for live-in VPValues; use isa_and_nonnull so
  // a live-in first operand doesn't trip isa<> on a null pointer.
  if (isa_and_nonnull<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
    std::swap(BinOp, Phi);

  return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
                                      Reduction);
}

void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
ElementCount MaxVF) {
assert(OrigLoop->isInnermost() && "Inner loop expected.");
Expand Down Expand Up @@ -9222,7 +9346,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);

VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
Builder);

// ---------------------------------------------------------------------------
// Pre-construction: record ingredients whose recipes we'll need to further
Expand Down Expand Up @@ -9268,6 +9393,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
return Legal->blockNeedsPredication(BB) || NeedsBlends;
});

RecipeBuilder.collectScaledReductions(Range);

auto *MiddleVPBB = Plan->getMiddleBlock();
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
Expand Down
Loading

0 comments on commit c858bf6

Please sign in to comment.