Skip to content

Commit a733c1f

Browse files
authored
[AArch64][NFC] Move getPartialReductionCost into cpp file (#123370)
The function getPartialReductionCost is already quite large and is likely to grow in size as we add support for more cases in future. Therefore, I think it's best to move this into the cpp file.
1 parent 9b853f6 commit a733c1f

File tree

2 files changed

+61
-56
lines changed

2 files changed

+61
-56
lines changed

Diff for: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+60
Original file line numberDiff line numberDiff line change
@@ -4670,6 +4670,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
46704670
return LegalizationCost * LT.first;
46714671
}
46724672

4673+
InstructionCost AArch64TTIImpl::getPartialReductionCost(
4674+
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
4675+
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
4676+
TTI::PartialReductionExtendKind OpBExtend,
4677+
std::optional<unsigned> BinOp) const {
4678+
InstructionCost Invalid = InstructionCost::getInvalid();
4679+
InstructionCost Cost(TTI::TCC_Basic);
4680+
4681+
if (Opcode != Instruction::Add)
4682+
return Invalid;
4683+
4684+
if (InputTypeA != InputTypeB)
4685+
return Invalid;
4686+
4687+
EVT InputEVT = EVT::getEVT(InputTypeA);
4688+
EVT AccumEVT = EVT::getEVT(AccumType);
4689+
4690+
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
4691+
return Invalid;
4692+
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
4693+
return Invalid;
4694+
4695+
if (InputEVT == MVT::i8) {
4696+
switch (VF.getKnownMinValue()) {
4697+
default:
4698+
return Invalid;
4699+
case 8:
4700+
if (AccumEVT == MVT::i32)
4701+
Cost *= 2;
4702+
else if (AccumEVT != MVT::i64)
4703+
return Invalid;
4704+
break;
4705+
case 16:
4706+
if (AccumEVT == MVT::i64)
4707+
Cost *= 2;
4708+
else if (AccumEVT != MVT::i32)
4709+
return Invalid;
4710+
break;
4711+
}
4712+
} else if (InputEVT == MVT::i16) {
4713+
// FIXME: Allow i32 accumulator but increase cost, as we would extend
4714+
// it to i64.
4715+
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
4716+
return Invalid;
4717+
} else
4718+
return Invalid;
4719+
4720+
// AArch64 supports lowering mixed extensions to a usdot but only if the
4721+
// i8mm or sve/streaming features are available.
4722+
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
4723+
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
4724+
!ST->isSVEorStreamingSVEAvailable()))
4725+
return Invalid;
4726+
4727+
if (!BinOp || *BinOp != Instruction::Mul)
4728+
return Invalid;
4729+
4730+
return Cost;
4731+
}
4732+
46734733
InstructionCost AArch64TTIImpl::getShuffleCost(
46744734
TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int> Mask,
46754735
TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,

Diff for: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+1-56
Original file line numberDiff line numberDiff line change
@@ -367,62 +367,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
367367
Type *AccumType, ElementCount VF,
368368
TTI::PartialReductionExtendKind OpAExtend,
369369
TTI::PartialReductionExtendKind OpBExtend,
370-
std::optional<unsigned> BinOp) const {
371-
372-
InstructionCost Invalid = InstructionCost::getInvalid();
373-
InstructionCost Cost(TTI::TCC_Basic);
374-
375-
if (Opcode != Instruction::Add)
376-
return Invalid;
377-
378-
if (InputTypeA != InputTypeB)
379-
return Invalid;
380-
381-
EVT InputEVT = EVT::getEVT(InputTypeA);
382-
EVT AccumEVT = EVT::getEVT(AccumType);
383-
384-
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
385-
return Invalid;
386-
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
387-
return Invalid;
388-
389-
if (InputEVT == MVT::i8) {
390-
switch (VF.getKnownMinValue()) {
391-
default:
392-
return Invalid;
393-
case 8:
394-
if (AccumEVT == MVT::i32)
395-
Cost *= 2;
396-
else if (AccumEVT != MVT::i64)
397-
return Invalid;
398-
break;
399-
case 16:
400-
if (AccumEVT == MVT::i64)
401-
Cost *= 2;
402-
else if (AccumEVT != MVT::i32)
403-
return Invalid;
404-
break;
405-
}
406-
} else if (InputEVT == MVT::i16) {
407-
// FIXME: Allow i32 accumulator but increase cost, as we would extend
408-
// it to i64.
409-
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
410-
return Invalid;
411-
} else
412-
return Invalid;
413-
414-
// AArch64 supports lowering mixed extensions to a usdot but only if the
415-
// i8mm or sve/streaming features are available.
416-
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
417-
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
418-
!ST->isSVEorStreamingSVEAvailable()))
419-
return Invalid;
420-
421-
if (!BinOp || *BinOp != Instruction::Mul)
422-
return Invalid;
423-
424-
return Cost;
425-
}
370+
std::optional<unsigned> BinOp) const;
426371

427372
bool enableOrderedReductions() const { return true; }
428373

0 commit comments

Comments
 (0)