@@ -5764,6 +5764,35 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
5764
5764
: nhwcToNchw4DTransposeDims);
5765
5765
}
5766
5766
5767
+ void
5768
+ unsqueezeInputOutputFor2dPool (RankedTensorType inputTy, Value &input,
5769
+ Type &outputTy, Location loc,
5770
+ ConversionPatternRewriter &rewriter) const {
5771
+ // 1d pool AtenOps mapped to TosaOp will already have the data in 4D format,
5772
+ // here we can have 3D data only if the AtenOp itself is a 2d pool op with
5773
+ // data in HWC format.
5774
+
5775
+ // Unsqueeze input tensor in HWC format to NHWC format to be
5776
+ // compatible with tosa::AvgPool2dOp, batch is made explicitly 1.
5777
+ SmallVector<int64_t > rank4Shape (inputTy.getShape ());
5778
+ assert (inputTy.getRank () == 3 &&
5779
+ " Expected input to be atleast 3 dimensional." );
5780
+ rank4Shape.insert (rank4Shape.begin (), 1 );
5781
+ input = rewriter.create <tosa::ReshapeOp>(
5782
+ loc,
5783
+ RankedTensorType::get (makeShapeTorchCompatible (rank4Shape),
5784
+ inputTy.getElementType ()),
5785
+ input, tosa::getTosaConstShape (rewriter, loc, rank4Shape));
5786
+
5787
+ // Unsqueeze output type
5788
+ auto outRankedTy = cast<RankedTensorType>(outputTy);
5789
+ assert (outRankedTy.getRank () == 3 &&
5790
+ " Expected output rank to be same as input." );
5791
+ SmallVector<int64_t > rank4ShapeOut (outRankedTy.getShape ());
5792
+ rank4ShapeOut.insert (rank4ShapeOut.begin (), 1 );
5793
+ outputTy = outRankedTy.clone (rank4ShapeOut);
5794
+ }
5795
+
5767
5796
LogicalResult
5768
5797
matchAndRewrite (AtenOpT op, OpAdaptor adaptor,
5769
5798
ConversionPatternRewriter &rewriter) const override {
@@ -5778,6 +5807,13 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
5778
5807
return rewriter.notifyMatchFailure (
5779
5808
op, " Failed to process inputs for pooling" );
5780
5809
5810
+ // input has already been verified to be RankedTensorType
5811
+ auto inputTy = cast<RankedTensorType>(input.getType ());
5812
+ if (inputTy.getRank () != 4 ) {
5813
+ unsqueezeInputOutputFor2dPool (inputTy, input, outputTy, op->getLoc (),
5814
+ rewriter);
5815
+ }
5816
+
5781
5817
Value pooledOutput;
5782
5818
static_assert (std::is_same<TosaOpT, tosa::MaxPool2dOp>::value ||
5783
5819
std::is_same<TosaOpT, tosa::AvgPool2dOp>::value,
@@ -5805,14 +5841,14 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
5805
5841
op, rewriter, pooledOutput);
5806
5842
5807
5843
Value result = transposedOutput;
5808
- auto resultTy = dyn_cast<TensorType>(
5844
+ auto resultTy = cast<TensorType>(result.getType ());
5845
+ auto expectedResultTy = dyn_cast<TensorType>(
5809
5846
OpConversionPattern<AtenOpT>::getTypeConverter ()->convertType (
5810
5847
op.getType ()));
5811
5848
5812
- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
5813
- std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
5814
- auto resultShape = resultTy.getShape ();
5815
- auto resultElemTy = resultTy.getElementType ();
5849
+ if (resultTy.getRank () != expectedResultTy.getRank ()) {
5850
+ auto resultShape = expectedResultTy.getShape ();
5851
+ auto resultElemTy = expectedResultTy.getElementType ();
5816
5852
5817
5853
result = rewriter.create <tosa::ReshapeOp>(
5818
5854
op->getLoc (),
@@ -5823,7 +5859,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
5823
5859
makeShapeTorchCompatible (resultShape)));
5824
5860
}
5825
5861
5826
- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultTy , result);
5862
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, expectedResultTy , result);
5827
5863
5828
5864
return success ();
5829
5865
}
@@ -5851,7 +5887,7 @@ class ConvertAtenAdaptivePoolingOp
5851
5887
auto inputElemTy = inputTy.getElementType ();
5852
5888
5853
5889
// Rank sanity check.
5854
- if (inputTy. getRank () != 4 && inputRank != 3 )
5890
+ if (inputRank != 4 && inputRank != 3 )
5855
5891
return rewriter.notifyMatchFailure (
5856
5892
op, " NCHW->NHWC transpose requires 3D or 4D tensor" );
5857
5893
@@ -5944,6 +5980,22 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
5944
5980
inputElemTy);
5945
5981
}
5946
5982
5983
+ template <typename AtenOpT>
5984
+ void expandPoolParams (AtenOpT op, SmallVectorImpl<int64_t > ¶ms,
5985
+ int64_t val) {
5986
+ // Expand pooling parameter (kernel, stride) to size 2 to be compatible with
5987
+ // tosa::MaxPool2dOp or tosa::AvgPool2dOp
5988
+ if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
5989
+ std::is_same<AtenOpT, AtenAvgPool1dOp>())
5990
+ params.push_back (val);
5991
+
5992
+ if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
5993
+ std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
5994
+ if (params.size () == 1 )
5995
+ params.push_back (params[0 ]);
5996
+ }
5997
+ }
5998
+
5947
5999
// Checks the validity of pooling parameters and stores them in the respective
5948
6000
// vector. Also, gets the output type for the pooling op.
5949
6001
template <typename AtenOpT, typename tosaOp>
@@ -5969,12 +6021,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
5969
6021
m_TorchListOfConstantInts (kernelSizeInts)))
5970
6022
return rewriter.notifyMatchFailure (
5971
6023
op, " Non-const kernel_size for pooling op unsupported" );
5972
-
5973
- // Expand kernel size parameter to size 2 to be compatible with
5974
- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
5975
- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
5976
- std::is_same<AtenOpT, AtenAvgPool1dOp>())
5977
- kernelSizeInts.push_back (1 );
6024
+ expandPoolParams (op, kernelSizeInts, 1 );
5978
6025
5979
6026
if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts)))
5980
6027
return rewriter.notifyMatchFailure (
@@ -5986,22 +6033,13 @@ static LogicalResult getOutputTypeAndPoolingParameters(
5986
6033
if (strideInts.empty ()) {
5987
6034
strideInts.assign (kernelSizeInts);
5988
6035
} else {
5989
- // Expand stride parameter to size 2 to be compatible with
5990
- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
5991
- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
5992
- std::is_same<AtenOpT, AtenAvgPool1dOp>())
5993
- strideInts.push_back (1 );
6036
+ expandPoolParams (op, strideInts, 1 );
5994
6037
}
5995
6038
5996
6039
if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (paddingInts)))
5997
6040
return rewriter.notifyMatchFailure (
5998
6041
op, " Non-const padding factor for pooling op unsupported" );
5999
-
6000
- // Expand padding parameter to size 2 to be compatible with
6001
- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
6002
- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
6003
- std::is_same<AtenOpT, AtenAvgPool1dOp>())
6004
- paddingInts.push_back (0 );
6042
+ expandPoolParams (op, paddingInts, 0 );
6005
6043
6006
6044
if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
6007
6045
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
@@ -6033,6 +6071,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
6033
6071
return rewriter.notifyMatchFailure (
6034
6072
op, " only support constant bool ceil_mode for pooling op" );
6035
6073
6074
+ expandPoolParams (op, dilationArray, 1 );
6036
6075
outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
6037
6076
inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
6038
6077
ceilMode);
0 commit comments