@@ -4670,6 +4670,66 @@ InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
4670
4670
return LegalizationCost * LT.first ;
4671
4671
}
4672
4672
4673
+ InstructionCost AArch64TTIImpl::getPartialReductionCost (
4674
+ unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
4675
+ ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
4676
+ TTI::PartialReductionExtendKind OpBExtend,
4677
+ std::optional<unsigned > BinOp) const {
4678
+ InstructionCost Invalid = InstructionCost::getInvalid ();
4679
+ InstructionCost Cost (TTI::TCC_Basic);
4680
+
4681
+ if (Opcode != Instruction::Add)
4682
+ return Invalid;
4683
+
4684
+ if (InputTypeA != InputTypeB)
4685
+ return Invalid;
4686
+
4687
+ EVT InputEVT = EVT::getEVT (InputTypeA);
4688
+ EVT AccumEVT = EVT::getEVT (AccumType);
4689
+
4690
+ if (VF.isScalable () && !ST->isSVEorStreamingSVEAvailable ())
4691
+ return Invalid;
4692
+ if (VF.isFixed () && (!ST->isNeonAvailable () || !ST->hasDotProd ()))
4693
+ return Invalid;
4694
+
4695
+ if (InputEVT == MVT::i8) {
4696
+ switch (VF.getKnownMinValue ()) {
4697
+ default :
4698
+ return Invalid;
4699
+ case 8 :
4700
+ if (AccumEVT == MVT::i32)
4701
+ Cost *= 2 ;
4702
+ else if (AccumEVT != MVT::i64)
4703
+ return Invalid;
4704
+ break ;
4705
+ case 16 :
4706
+ if (AccumEVT == MVT::i64)
4707
+ Cost *= 2 ;
4708
+ else if (AccumEVT != MVT::i32)
4709
+ return Invalid;
4710
+ break ;
4711
+ }
4712
+ } else if (InputEVT == MVT::i16) {
4713
+ // FIXME: Allow i32 accumulator but increase cost, as we would extend
4714
+ // it to i64.
4715
+ if (VF.getKnownMinValue () != 8 || AccumEVT != MVT::i64)
4716
+ return Invalid;
4717
+ } else
4718
+ return Invalid;
4719
+
4720
+ // AArch64 supports lowering mixed extensions to a usdot but only if the
4721
+ // i8mm or sve/streaming features are available.
4722
+ if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
4723
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8 () &&
4724
+ !ST->isSVEorStreamingSVEAvailable ()))
4725
+ return Invalid;
4726
+
4727
+ if (!BinOp || *BinOp != Instruction::Mul)
4728
+ return Invalid;
4729
+
4730
+ return Cost;
4731
+ }
4732
+
4673
4733
InstructionCost AArch64TTIImpl::getShuffleCost (
4674
4734
TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef<int > Mask,
4675
4735
TTI::TargetCostKind CostKind, int Index, VectorType *SubTp,
0 commit comments