Commit 2fb65f6

[Task] : Handle CHW input for avgpool2d.
1 parent a265d28 commit 2fb65f6

4 files changed: +189 -24 lines changed
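For context, and not part of the commit itself: torch.nn.AvgPool2d also accepts an unbatched 3D CHW input and treats it as a batch of one, which is the equivalence the new lowering relies on when it unsqueezes such inputs to rank 4. A minimal PyTorch sketch of the case being handled (shapes match the new AvgPool2dCHWModule test):

import torch

# Unbatched CHW input -- the case this commit teaches the TorchToTosa lowering to handle.
x_chw = torch.rand(4, 20, 20)        # C=4, H=20, W=20
pool = torch.nn.AvgPool2d(kernel_size=[6, 8], stride=[2, 2])

out_chw = pool(x_chw)                # shape (4, 8, 7)
out_nchw = pool(x_chw.unsqueeze(0))  # same data with an explicit batch of 1
assert torch.allclose(out_chw, out_nchw.squeeze(0))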

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 63 additions & 24 deletions

@@ -5764,6 +5764,35 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
                                      : nhwcToNchw4DTransposeDims);
   }
 
+  void
+  unsqueezeInputOutputFor2dPool(RankedTensorType inputTy, Value &input,
+                                Type &outputTy, Location loc,
+                                ConversionPatternRewriter &rewriter) const {
+    // 1d pool AtenOps mapped to TosaOp will already have the data in 4D format,
+    // here we can have 3D data only if the AtenOp itself is a 2d pool op with
+    // data in HWC format.
+
+    // Unsqueeze input tensor in HWC format to NHWC format to be
+    // compatible with tosa::AvgPool2dOp, batch is made explicitly 1.
+    SmallVector<int64_t> rank4Shape(inputTy.getShape());
+    assert(inputTy.getRank() == 3 &&
+           "Expected input to be atleast 3 dimensional.");
+    rank4Shape.insert(rank4Shape.begin(), 1);
+    input = rewriter.create<tosa::ReshapeOp>(
+        loc,
+        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                              inputTy.getElementType()),
+        input, tosa::getTosaConstShape(rewriter, loc, rank4Shape));
+
+    // Unsqueeze output type
+    auto outRankedTy = cast<RankedTensorType>(outputTy);
+    assert(outRankedTy.getRank() == 3 &&
+           "Expected output rank to be same as input.");
+    SmallVector<int64_t> rank4ShapeOut(outRankedTy.getShape());
+    rank4ShapeOut.insert(rank4ShapeOut.begin(), 1);
+    outputTy = outRankedTy.clone(rank4ShapeOut);
+  }
+
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
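A rough Python analogue of the shape bookkeeping in unsqueezeInputOutputFor2dPool (a sketch for illustration, not code from the patch): by this point the base pattern has already transposed the 3D CHW input into HWC order, so prepending a 1 produces the NHWC layout tosa.avg_pool2d expects, and the rank-3 output type gets the same leading 1.

import torch

chw = torch.rand(4, 20, 20)   # CHW input of the aten 2d pool op
hwc = chw.permute(1, 2, 0)    # NCHW->NHWC transpose path, minus the missing batch dim
nhwc = hwc.unsqueeze(0)       # batch made explicitly 1, as in the helper
print(nhwc.shape)             # torch.Size([1, 20, 20, 4])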
@@ -5778,6 +5807,13 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
       return rewriter.notifyMatchFailure(
           op, "Failed to process inputs for pooling");
 
+    // input has already been verified to be RankedTensorType
+    auto inputTy = cast<RankedTensorType>(input.getType());
+    if (inputTy.getRank() != 4) {
+      unsqueezeInputOutputFor2dPool(inputTy, input, outputTy, op->getLoc(),
+                                    rewriter);
+    }
+
     Value pooledOutput;
     static_assert(std::is_same<TosaOpT, tosa::MaxPool2dOp>::value ||
                       std::is_same<TosaOpT, tosa::AvgPool2dOp>::value,

@@ -5805,14 +5841,14 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
         op, rewriter, pooledOutput);
 
     Value result = transposedOutput;
-    auto resultTy = dyn_cast<TensorType>(
+    auto resultTy = cast<TensorType>(result.getType());
+    auto expectedResultTy = dyn_cast<TensorType>(
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
 
-    if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
-                  std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
-      auto resultShape = resultTy.getShape();
-      auto resultElemTy = resultTy.getElementType();
+    if (resultTy.getRank() != expectedResultTy.getRank()) {
+      auto resultShape = expectedResultTy.getShape();
+      auto resultElemTy = expectedResultTy.getElementType();
 
       result = rewriter.create<tosa::ReshapeOp>(
           op->getLoc(),

@@ -5823,7 +5859,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
               makeShapeTorchCompatible(resultShape)));
     }
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, expectedResultTy, result);
 
     return success();
   }

@@ -5851,7 +5887,7 @@ class ConvertAtenAdaptivePoolingOp
     auto inputElemTy = inputTy.getElementType();
 
     // Rank sanity check.
-    if (inputTy.getRank() != 4 && inputRank != 3)
+    if (inputRank != 4 && inputRank != 3)
       return rewriter.notifyMatchFailure(
           op, "NCHW->NHWC transpose requires 3D or 4D tensor");
 

@@ -5944,6 +5980,22 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
       inputElemTy);
 }
 
+template <typename AtenOpT>
+void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
+                      int64_t val) {
+  // Expand pooling parameter (kernel, stride) to size 2 to be compatible with
+  // tosa::MaxPool2dOp or tosa::AvgPool2dOp
+  if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
+                std::is_same<AtenOpT, AtenAvgPool1dOp>())
+    params.push_back(val);
+
+  if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
+                std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
+    if (params.size() == 1)
+      params.push_back(params[0]);
+  }
+}
+
 // Checks the validity of pooling parameters and stores them in the respective
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
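For reference (not from the patch): PyTorch already interprets a single-element kernel_size, stride, or padding for a 2d pool as applying to both spatial dimensions, which is the behavior expandPoolParams mirrors on the TOSA side by duplicating a size-1 parameter list. A small eager-mode check of that equivalence:

import torch

x = torch.rand(2, 4, 20, 20)

single = torch.nn.AvgPool2d(kernel_size=(6,), stride=(2,), padding=(1,))
square = torch.nn.AvgPool2d(kernel_size=(6, 6), stride=(2, 2), padding=(1, 1))

assert torch.allclose(single(x), square(x))
print(single(x).shape)  # torch.Size([2, 4, 9, 9])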
@@ -5969,12 +6021,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
           m_TorchListOfConstantInts(kernelSizeInts)))
     return rewriter.notifyMatchFailure(
         op, "Non-const kernel_size for pooling op unsupported");
-
-  // Expand kernel size parameter to size 2 to be compatible with
-  // tosa::MaxPool2dOp or tosa::AvgPool2dOp
-  if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
-                std::is_same<AtenOpT, AtenAvgPool1dOp>())
-    kernelSizeInts.push_back(1);
+  expandPoolParams(op, kernelSizeInts, 1);
 
   if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
     return rewriter.notifyMatchFailure(

@@ -5986,22 +6033,13 @@ static LogicalResult getOutputTypeAndPoolingParameters(
   if (strideInts.empty()) {
     strideInts.assign(kernelSizeInts);
   } else {
-    // Expand stride parameter to size 2 to be compatible with
-    // tosa::MaxPool2dOp or tosa::AvgPool2dOp
-    if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
-                  std::is_same<AtenOpT, AtenAvgPool1dOp>())
-      strideInts.push_back(1);
+    expandPoolParams(op, strideInts, 1);
   }
 
   if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
     return rewriter.notifyMatchFailure(
         op, "Non-const padding factor for pooling op unsupported");
-
-  // Expand padding parameter to size 2 to be compatible with
-  // tosa::MaxPool2dOp or tosa::AvgPool2dOp
-  if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
-                std::is_same<AtenOpT, AtenAvgPool1dOp>())
-    paddingInts.push_back(0);
+  expandPoolParams(op, paddingInts, 0);
 
   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {

@@ -6033,6 +6071,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
     return rewriter.notifyMatchFailure(
         op, "only support constant bool ceil_mode for pooling op");
 
+  expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
       inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
       ceilMode);

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 9 additions & 0 deletions

@@ -405,6 +405,7 @@
     "AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
+    "AvgPool2dCHWModule_basic",
     "QuantizedReluInt32_basic",
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",

@@ -528,6 +529,8 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleSumdims_basic",
+    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
+    "AvgPool2dSingleIntTupleParamsModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_XFAIL_SET = {

@@ -952,6 +955,8 @@
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
+    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
+    "AvgPool2dSingleIntTupleParamsModule_basic",
     "BatchNorm1DModule_basic",
     "BatchNorm2DModule_basic",
     "BatchNorm3DModule_basic",

@@ -2756,6 +2761,9 @@
     "AtenTopKModule_basic",
     "AtenTopKSmallestModule_basic",
     "Aten_EmbeddingBagExample_basic",
+    "AvgPool2dCHWModule_basic",
+    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
+    "AvgPool2dSingleIntTupleParamsModule_basic",
     "AvgPool2dWithoutPadModule_basic",
     "BatchMlpLayerModule_basic",
     "BincountMinlengthModule_basic",

@@ -3355,6 +3363,7 @@
     "AtenSymConstrainRangeForSize_basic",
     "AtenSymConstrainRange_basic",
     "Aten_AssertScalar_basic",
+    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
     "ScatterAddDynamicModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 78 additions & 0 deletions

@@ -1428,6 +1428,84 @@ def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
 
 
+class AvgPool2dCHWModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=[6, 8],
+            stride=[2, 2],
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool2dCHWModule())
+def AvgPool2dCHWModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 20, 20, low=0.5, high=1.0))
+
+
+class AvgPool2dSingleIntTupleParamsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=(6,),
+            stride=(2,),
+            padding=(1,),
+            count_include_pad=False,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool2dSingleIntTupleParamsModule())
+def AvgPool2dSingleIntTupleParamsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
+
+
+class AvgPool2dSingleIntTupleParamsIncludePadModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool2d(
+            kernel_size=(6,),
+            stride=(2,),
+            padding=(1,),
+            count_include_pad=True,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool2dSingleIntTupleParamsIncludePadModule()
+)
+def AvgPool2dSingleIntTupleParamsIncludePadModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
+
+
 # ==============================================================================
 
 

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 39 additions & 0 deletions

@@ -2259,6 +2259,45 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc
   return %3 : !torch.vtensor<[1,192,35,35],f32>
 }
 
+// -----
+// CHECK-LABEL: func.func @avgPool2dCHWInput(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[1,64,56],f32>) -> !torch.vtensor<[1,59,51],f32> {
+// CHECK:         %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,64,56],f32> -> tensor<1x64x56xf32>
+// CHECK:         %[[NONE:.*]] = torch.constant.none
+// CHECK:         %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:         %[[TRUE:.*]] = torch.constant.bool true
+// CHECK:         %[[C0:.*]] = torch.constant.int 0
+// CHECK:         %[[C1:.*]] = torch.constant.int 1
+// CHECK:         %[[C6:.*]] = torch.constant.int 6
+// CHECK:         %[[L1:.*]] = torch.prim.ListConstruct %[[C6]], %[[C6]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[L2:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[L3:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[PERMS_IN:.*]] = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK:         %[[TRANSPOSE_IN:.*]] = tosa.transpose %[[TENSOR]], %[[PERMS_IN]] : (tensor<1x64x56xf32>, tensor<3xi32>) -> tensor<64x56x1xf32>
+// CHECK:         %[[CONST_SHAPE_IN:.*]] = tosa.const_shape {value = dense<[1, 64, 56, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:         %[[RESHAPE_IN:.*]] = tosa.reshape %[[TRANSPOSE_IN]], %[[CONST_SHAPE_IN]] : (tensor<64x56x1xf32>, !tosa.shape<4>) -> tensor<1x64x56x1xf32>
+// CHECK:         %[[POOL:.*]] = tosa.avg_pool2d %[[RESHAPE_IN]] {acc_type = f32, kernel = array<i64: 6, 6>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x64x56x1xf32>) -> tensor<1x59x51x1xf32>
+// CHECK:         %[[PERMS_OUT:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK:         %[[TRANSPOSE_OUT:.*]] = tosa.transpose %[[POOL]], %[[PERMS_OUT]] : (tensor<1x59x51x1xf32>, tensor<4xi32>) -> tensor<1x1x59x51xf32>
+// CHECK:         %[[CONST_SHAPE_OUT:.*]] = tosa.const_shape {value = dense<[1, 59, 51]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK:         %[[RESHAPE_OUT:.*]] = tosa.reshape %[[TRANSPOSE_OUT]], %[[CONST_SHAPE_OUT]] : (tensor<1x1x59x51xf32>, !tosa.shape<3>) -> tensor<1x59x51xf32>
+// CHECK:         %[[CAST:.*]] = tensor.cast %[[RESHAPE_OUT]] : tensor<1x59x51xf32> to tensor<1x59x51xf32>
+// CHECK:         %[[TORCH:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x59x51xf32> -> !torch.vtensor<[1,59,51],f32>
+// CHECK:         return %[[TORCH]]
+func.func @avgPool2dCHWInput(%arg0: !torch.vtensor<[1,64,56],f32>) -> !torch.vtensor<[1,59,51],f32> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int6 = torch.constant.int 6
+  %0 = torch.prim.ListConstruct %int6, %int6 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %true, %false, %none : !torch.vtensor<[1,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,59,51],f32>
+  return %3 : !torch.vtensor<[1,59,51],f32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
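As a quick sanity check of the shapes in the new @avgPool2dCHWInput test (a sketch, not part of the commit): a 6x6 kernel with stride 1 and no padding over a [1,64,56] CHW input yields (64 - 6) + 1 = 59 and (56 - 6) + 1 = 51 spatial outputs, matching the [1,59,51] result type above.

import torch

x = torch.rand(1, 64, 56)  # CHW, as in @avgPool2dCHWInput
out = torch.nn.AvgPool2d(kernel_size=6, stride=1)(x)
print(out.shape)  # torch.Size([1, 59, 51])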
