@@ -5764,6 +5764,35 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
57645764 : nhwcToNchw4DTransposeDims);
57655765 }
57665766
5767+ void
5768+ unsqueezeInputOutputFor2dPool (RankedTensorType inputTy, Value &input,
5769+ Type &outputTy, Location loc,
5770+ ConversionPatternRewriter &rewriter) const {
5771+ // 1d pool AtenOps mapped to TosaOp will already have the data in 4D format,
5772+ // here we can have 3D data only if the AtenOp itself is a 2d pool op with
5773+ // data in HWC format.
5774+
5775+ // Unsqueeze input tensor in HWC format to NHWC format to be
5776+ // compatible with tosa::AvgPool2dOp, batch is made explicitly 1.
5777+ SmallVector<int64_t > rank4Shape (inputTy.getShape ());
5778+ assert (inputTy.getRank () == 3 &&
5779+ " Expected input to be atleast 3 dimensional." );
5780+ rank4Shape.insert (rank4Shape.begin (), 1 );
5781+ input = rewriter.create <tosa::ReshapeOp>(
5782+ loc,
5783+ RankedTensorType::get (makeShapeTorchCompatible (rank4Shape),
5784+ inputTy.getElementType ()),
5785+ input, tosa::getTosaConstShape (rewriter, loc, rank4Shape));
5786+
5787+ // Unsqueeze output type
5788+ auto outRankedTy = cast<RankedTensorType>(outputTy);
5789+ assert (outRankedTy.getRank () == 3 &&
5790+ " Expected output rank to be same as input." );
5791+ SmallVector<int64_t > rank4ShapeOut (outRankedTy.getShape ());
5792+ rank4ShapeOut.insert (rank4ShapeOut.begin (), 1 );
5793+ outputTy = outRankedTy.clone (rank4ShapeOut);
5794+ }
5795+
57675796 LogicalResult
57685797 matchAndRewrite (AtenOpT op, OpAdaptor adaptor,
57695798 ConversionPatternRewriter &rewriter) const override {
@@ -5778,6 +5807,13 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
57785807 return rewriter.notifyMatchFailure (
57795808 op, " Failed to process inputs for pooling" );
57805809
5810+ // input has already been verified to be RankedTensorType
5811+ auto inputTy = cast<RankedTensorType>(input.getType ());
5812+ if (inputTy.getRank () != 4 ) {
5813+ unsqueezeInputOutputFor2dPool (inputTy, input, outputTy, op->getLoc (),
5814+ rewriter);
5815+ }
5816+
57815817 Value pooledOutput;
57825818 static_assert (std::is_same<TosaOpT, tosa::MaxPool2dOp>::value ||
57835819 std::is_same<TosaOpT, tosa::AvgPool2dOp>::value,
@@ -5805,14 +5841,14 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
58055841 op, rewriter, pooledOutput);
58065842
58075843 Value result = transposedOutput;
5808- auto resultTy = dyn_cast<TensorType>(
5844+ auto resultTy = cast<TensorType>(result.getType ());
5845+ auto expectedResultTy = dyn_cast<TensorType>(
58095846 OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
58105847 op.getType ()));
58115848
5812- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
5813- std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
5814- auto resultShape = resultTy.getShape ();
5815- auto resultElemTy = resultTy.getElementType ();
5849+ if (resultTy.getRank () != expectedResultTy.getRank ()) {
5850+ auto resultShape = expectedResultTy.getShape ();
5851+ auto resultElemTy = expectedResultTy.getElementType ();
58165852
58175853 result = rewriter.create <tosa::ReshapeOp>(
58185854 op->getLoc (),
@@ -5823,7 +5859,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
58235859 makeShapeTorchCompatible (resultShape)));
58245860 }
58255861
5826- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultTy , result);
5862+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, expectedResultTy , result);
58275863
58285864 return success ();
58295865 }
@@ -5851,7 +5887,7 @@ class ConvertAtenAdaptivePoolingOp
58515887 auto inputElemTy = inputTy.getElementType ();
58525888
58535889 // Rank sanity check.
5854- if (inputTy. getRank () != 4 && inputRank != 3 )
5890+ if (inputRank != 4 && inputRank != 3 )
58555891 return rewriter.notifyMatchFailure (
58565892 op, " NCHW->NHWC transpose requires 3D or 4D tensor" );
58575893
@@ -5944,6 +5980,22 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
59445980 inputElemTy);
59455981}
59465982
5983+ template <typename AtenOpT>
5984+ void expandPoolParams (AtenOpT op, SmallVectorImpl<int64_t > ¶ms,
5985+ int64_t val) {
5986+ // Expand pooling parameter (kernel, stride) to size 2 to be compatible with
5987+ // tosa::MaxPool2dOp or tosa::AvgPool2dOp
5988+ if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
5989+ std::is_same<AtenOpT, AtenAvgPool1dOp>())
5990+ params.push_back (val);
5991+
5992+ if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
5993+ std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
5994+ if (params.size () == 1 )
5995+ params.push_back (params[0 ]);
5996+ }
5997+ }
5998+
59475999// Checks the validity of pooling parameters and stores them in the respective
59486000// vector. Also, gets the output type for the pooling op.
59496001template <typename AtenOpT, typename tosaOp>
@@ -5969,12 +6021,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
59696021 m_TorchListOfConstantInts (kernelSizeInts)))
59706022 return rewriter.notifyMatchFailure (
59716023 op, " Non-const kernel_size for pooling op unsupported" );
5972-
5973- // Expand kernel size parameter to size 2 to be compatible with
5974- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
5975- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
5976- std::is_same<AtenOpT, AtenAvgPool1dOp>())
5977- kernelSizeInts.push_back (1 );
6024+ expandPoolParams (op, kernelSizeInts, 1 );
59786025
59796026 if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts)))
59806027 return rewriter.notifyMatchFailure (
@@ -5986,22 +6033,13 @@ static LogicalResult getOutputTypeAndPoolingParameters(
59866033 if (strideInts.empty ()) {
59876034 strideInts.assign (kernelSizeInts);
59886035 } else {
5989- // Expand stride parameter to size 2 to be compatible with
5990- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
5991- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
5992- std::is_same<AtenOpT, AtenAvgPool1dOp>())
5993- strideInts.push_back (1 );
6036+ expandPoolParams (op, strideInts, 1 );
59946037 }
59956038
59966039 if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (paddingInts)))
59976040 return rewriter.notifyMatchFailure (
59986041 op, " Non-const padding factor for pooling op unsupported" );
5999-
6000- // Expand padding parameter to size 2 to be compatible with
6001- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
6002- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
6003- std::is_same<AtenOpT, AtenAvgPool1dOp>())
6004- paddingInts.push_back (0 );
6042+ expandPoolParams (op, paddingInts, 0 );
60056043
60066044 if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
60076045 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
@@ -6033,6 +6071,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
60336071 return rewriter.notifyMatchFailure (
60346072 op, " only support constant bool ceil_mode for pooling op" );
60356073
6074+ expandPoolParams (op, dilationArray, 1 );
60366075 outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
60376076 inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
60386077 ceilMode);
0 commit comments