[Task] : Handle CHW input for avgpool2d.
sahas3 committed Feb 24, 2025
1 parent a265d28 commit 2fb65f6
Showing 4 changed files with 189 additions and 24 deletions.
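For context (illustration only, not part of the commit): eager PyTorch already accepts an unbatched (C, H, W) input for torch.nn.AvgPool2d, and this change teaches the TorchToTosa lowering to match that by inserting an explicit batch dimension of 1 before calling tosa.avg_pool2d. A minimal sketch of the eager behavior being matched:

import torch

pool = torch.nn.AvgPool2d(kernel_size=[6, 8], stride=[2, 2])
x = torch.rand(4, 20, 20)  # (C, H, W) input, no batch dimension
y = pool(x)
print(y.shape)  # torch.Size([4, 8, 7])
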
87 changes: 63 additions & 24 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5764,6 +5764,35 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
: nhwcToNchw4DTransposeDims);
}

void
unsqueezeInputOutputFor2dPool(RankedTensorType inputTy, Value &input,
Type &outputTy, Location loc,
ConversionPatternRewriter &rewriter) const {
// 1d pool AtenOps mapped to TosaOp will already have the data in 4D format,
// here we can have 3D data only if the AtenOp itself is a 2d pool op with
// data in HWC format.

// Unsqueeze input tensor in HWC format to NHWC format to be
// compatible with tosa::AvgPool2dOp, batch is made explicitly 1.
SmallVector<int64_t> rank4Shape(inputTy.getShape());
assert(inputTy.getRank() == 3 &&
"Expected input to be atleast 3 dimensional.");
rank4Shape.insert(rank4Shape.begin(), 1);
input = rewriter.create<tosa::ReshapeOp>(
loc,
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
inputTy.getElementType()),
input, tosa::getTosaConstShape(rewriter, loc, rank4Shape));

// Unsqueeze output type
auto outRankedTy = cast<RankedTensorType>(outputTy);
assert(outRankedTy.getRank() == 3 &&
"Expected output rank to be same as input.");
SmallVector<int64_t> rank4ShapeOut(outRankedTy.getShape());
rank4ShapeOut.insert(rank4ShapeOut.begin(), 1);
outputTy = outRankedTy.clone(rank4ShapeOut);
}

LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -5778,6 +5807,13 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Failed to process inputs for pooling");

// input has already been verified to be RankedTensorType
auto inputTy = cast<RankedTensorType>(input.getType());
if (inputTy.getRank() != 4) {
unsqueezeInputOutputFor2dPool(inputTy, input, outputTy, op->getLoc(),
rewriter);
}

Value pooledOutput;
static_assert(std::is_same<TosaOpT, tosa::MaxPool2dOp>::value ||
std::is_same<TosaOpT, tosa::AvgPool2dOp>::value,
@@ -5805,14 +5841,14 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
op, rewriter, pooledOutput);

Value result = transposedOutput;
- auto resultTy = dyn_cast<TensorType>(
+ auto resultTy = cast<TensorType>(result.getType());
+ auto expectedResultTy = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));

- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
- std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
- auto resultShape = resultTy.getShape();
- auto resultElemTy = resultTy.getElementType();
+ if (resultTy.getRank() != expectedResultTy.getRank()) {
+ auto resultShape = expectedResultTy.getShape();
+ auto resultElemTy = expectedResultTy.getElementType();

result = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
Expand All @@ -5823,7 +5859,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
makeShapeTorchCompatible(resultShape)));
}

- rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, expectedResultTy, result);

return success();
}
@@ -5851,7 +5887,7 @@ class ConvertAtenAdaptivePoolingOp
auto inputElemTy = inputTy.getElementType();

// Rank sanity check.
- if (inputTy.getRank() != 4 && inputRank != 3)
+ if (inputRank != 4 && inputRank != 3)
return rewriter.notifyMatchFailure(
op, "NCHW->NHWC transpose requires 3D or 4D tensor");

@@ -5944,6 +5980,22 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
inputElemTy);
}

template <typename AtenOpT>
void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
int64_t val) {
// Expand pooling parameter (kernel, stride) to size 2 to be compatible with
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
params.push_back(val);

if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
if (params.size() == 1)
params.push_back(params[0]);
}
}

// Checks the validity of pooling parameters and stores them in the respective
// vector. Also, gets the output type for the pooling op.
template <typename AtenOpT, typename tosaOp>
@@ -5969,12 +6021,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
m_TorchListOfConstantInts(kernelSizeInts)))
return rewriter.notifyMatchFailure(
op, "Non-const kernel_size for pooling op unsupported");

- // Expand kernel size parameter to size 2 to be compatible with
- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
- std::is_same<AtenOpT, AtenAvgPool1dOp>())
- kernelSizeInts.push_back(1);
+ expandPoolParams(op, kernelSizeInts, 1);

if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
return rewriter.notifyMatchFailure(
@@ -5986,22 +6033,13 @@ static LogicalResult getOutputTypeAndPoolingParameters(
if (strideInts.empty()) {
strideInts.assign(kernelSizeInts);
} else {
- // Expand stride parameter to size 2 to be compatible with
- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
- std::is_same<AtenOpT, AtenAvgPool1dOp>())
- strideInts.push_back(1);
+ expandPoolParams(op, strideInts, 1);
}

if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
return rewriter.notifyMatchFailure(
op, "Non-const padding factor for pooling op unsupported");

- // Expand padding parameter to size 2 to be compatible with
- // tosa::MaxPool2dOp or tosa::AvgPool2dOp
- if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
- std::is_same<AtenOpT, AtenAvgPool1dOp>())
- paddingInts.push_back(0);
+ expandPoolParams(op, paddingInts, 0);

if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
@@ -6033,6 +6071,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
return rewriter.notifyMatchFailure(
op, "only support constant bool ceil_mode for pooling op");

expandPoolParams(op, dilationArray, 1);
outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
ceilMode);
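Illustration only (not part of the diff above): the new expandPoolParams helper mirrors PyTorch's handling of single-element kernel_size/stride/padding tuples for the 2d pooling ops, where the one value applies to both spatial dimensions. A hedged sketch of that eager-mode equivalence:

import torch

x = torch.rand(2, 4, 20, 20)
a = torch.nn.AvgPool2d(kernel_size=(6,), stride=(2,), padding=(1,))(x)
b = torch.nn.AvgPool2d(kernel_size=(6, 6), stride=(2, 2), padding=(1, 1))(x)
assert torch.allclose(a, b)  # both yield shape (2, 4, 9, 9)
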
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -405,6 +405,7 @@
"AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AvgPool2dCHWModule_basic",
"QuantizedReluInt32_basic",
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
@@ -528,6 +529,8 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
}

FX_IMPORTER_STABLEHLO_XFAIL_SET = {
@@ -952,6 +955,8 @@
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
@@ -2756,6 +2761,9 @@
"AtenTopKModule_basic",
"AtenTopKSmallestModule_basic",
"Aten_EmbeddingBagExample_basic",
"AvgPool2dCHWModule_basic",
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
"AvgPool2dWithoutPadModule_basic",
"BatchMlpLayerModule_basic",
"BincountMinlengthModule_basic",
@@ -3355,6 +3363,7 @@
"AtenSymConstrainRangeForSize_basic",
"AtenSymConstrainRange_basic",
"Aten_AssertScalar_basic",
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"ScatterAddDynamicModule_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
78 changes: 78 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1428,6 +1428,84 @@ def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))


class AvgPool2dCHWModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(
kernel_size=[6, 8],
stride=[2, 2],
)

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dCHWModule())
def AvgPool2dCHWModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 20, 20, low=0.5, high=1.0))


class AvgPool2dSingleIntTupleParamsModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(
kernel_size=(6,),
stride=(2,),
padding=(1,),
count_include_pad=False,
)

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dSingleIntTupleParamsModule())
def AvgPool2dSingleIntTupleParamsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))


class AvgPool2dSingleIntTupleParamsIncludePadModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(
kernel_size=(6,),
stride=(2,),
padding=(1,),
count_include_pad=True,
)

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.ap2d(x)


@register_test_case(
module_factory=lambda: AvgPool2dSingleIntTupleParamsIncludePadModule()
)
def AvgPool2dSingleIntTupleParamsIncludePadModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))


# ==============================================================================


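Illustration only (not part of the test suite changes above): the new AvgPool2dCHWModule exercises exactly the batch-1 unsqueeze/squeeze equivalence that the lowering relies on. A small sketch of that identity in eager PyTorch, assuming the same kernel and stride as the test:

import torch

pool = torch.nn.AvgPool2d(kernel_size=[6, 8], stride=[2, 2])
x = torch.rand(4, 20, 20)  # (C, H, W)
direct = pool(x)  # shape (4, 8, 7)
via_batch = pool(x.unsqueeze(0)).squeeze(0)  # add N=1, pool, drop N
assert torch.allclose(direct, via_batch)
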
39 changes: 39 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -2259,6 +2259,45 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc
return %3 : !torch.vtensor<[1,192,35,35],f32>
}

// -----
// CHECK-LABEL: func.func @avgPool2dCHWInput(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,64,56],f32>) -> !torch.vtensor<[1,59,51],f32> {
// CHECK: %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,64,56],f32> -> tensor<1x64x56xf32>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C6:.*]] = torch.constant.int 6
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[C6]], %[[C6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[L3:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMS_IN:.*]] = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[TRANSPOSE_IN:.*]] = tosa.transpose %[[TENSOR]], %[[PERMS_IN]] : (tensor<1x64x56xf32>, tensor<3xi32>) -> tensor<64x56x1xf32>
// CHECK: %[[CONST_SHAPE_IN:.*]] = tosa.const_shape {value = dense<[1, 64, 56, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[RESHAPE_IN:.*]] = tosa.reshape %[[TRANSPOSE_IN]], %[[CONST_SHAPE_IN]] : (tensor<64x56x1xf32>, !tosa.shape<4>) -> tensor<1x64x56x1xf32>
// CHECK: %[[POOL:.*]] = tosa.avg_pool2d %[[RESHAPE_IN]] {acc_type = f32, kernel = array<i64: 6, 6>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x64x56x1xf32>) -> tensor<1x59x51x1xf32>
// CHECK: %[[PERMS_OUT:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK: %[[TRANSPOSE_OUT:.*]] = tosa.transpose %[[POOL]], %[[PERMS_OUT]] : (tensor<1x59x51x1xf32>, tensor<4xi32>) -> tensor<1x1x59x51xf32>
// CHECK: %[[CONST_SHAPE_OUT:.*]] = tosa.const_shape {value = dense<[1, 59, 51]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[RESHAPE_OUT:.*]] = tosa.reshape %[[TRANSPOSE_OUT]], %[[CONST_SHAPE_OUT]] : (tensor<1x1x59x51xf32>, !tosa.shape<3>) -> tensor<1x59x51xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[RESHAPE_OUT]] : tensor<1x59x51xf32> to tensor<1x59x51xf32>
// CHECK: %[[TORCH:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x59x51xf32> -> !torch.vtensor<[1,59,51],f32>
// CHECK: return %[[TORCH]]
func.func @avgPool2dCHWInput(%arg0: !torch.vtensor<[1,64,56],f32>) -> !torch.vtensor<[1,59,51],f32> {
%none = torch.constant.none
%false = torch.constant.bool false
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int6 = torch.constant.int 6
%0 = torch.prim.ListConstruct %int6, %int6 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %true, %false, %none : !torch.vtensor<[1,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,59,51],f32>
return %3 : !torch.vtensor<[1,59,51],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
