diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
index 167c473558..fe9a5cb865 100644
--- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
+++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
@@ -831,69 +831,99 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
 public:
   using OpConversionPattern<ttir::PoolingOp>::OpConversionPattern;
 
-  std::vector<int64_t> getIndicesOfSpatialDims(ttir::PoolingOp op) const {
-    std::vector<int64_t> spatialDims;
-    for (int64_t i = 0;
-         i < static_cast<int64_t>(op.getWindowDimensions().size()); i++) {
-      if (op.getWindowDimensions()[i] > 1) {
-        spatialDims.push_back(i);
+  LogicalResult
+  matchAndRewrite(ttir::PoolingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<int64_t> spatialDimIndices =
+        getIndicesOfElementsLargerThanOne(op.getWindowDimensions());
+    size_t numSpatialDimIndices = spatialDimIndices.size();
+    if (numSpatialDimIndices > 2) {
+      return rewriter.notifyMatchFailure(
+          op, "No decompositions for a pooling op with " +
+                  std::to_string(numSpatialDimIndices) + " spatial dimensions");
+    }
+
+    LogicalResult legalityResult =
+        canDecompose2DPoolingOp(op, rewriter, spatialDimIndices);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
+    }
+
+    switch (op.getPoolingMethod()) {
+    case ttir::PoolingMethod::Max: {
+      rewritePool2d<ttir::MaxPool2dOp>(op, adaptor, rewriter,
+                                       spatialDimIndices);
+      return success();
+    }
+    default: {
+      return rewriter.notifyMatchFailure(
+          op, "Failed to match pooling method: " +
+                  stringifyPoolingMethod(op.getPoolingMethod()));
+    }
+    }
+  }
+
+private:
+  llvm::SmallVector<int64_t>
+  getIndicesOfElementsLargerThanOne(llvm::ArrayRef<int64_t> input) const {
+    llvm::SmallVector<int64_t> result;
+    for (size_t i = 0; i < input.size(); i++) {
+      if (input[i] > 1) {
+        result.push_back(i);
       }
     }
-    return spatialDims;
+    return result;
   }
 
-  LogicalResult canDecompose2DPoolingOp(ttir::PoolingOp op) const {
+  LogicalResult
+  canDecompose2DPoolingOp(ttir::PoolingOp op,
+                          ConversionPatternRewriter &rewriter,
+                          llvm::SmallVector<int64_t> spatialDimIndices) const {
     // Window dimensions must be 4 in length
     if (op.getWindowDimensions().size() != 4) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "Pooling 2D op is only supported for 4D tensors.");
     }
 
     // Window strides must be 4 in length
     if (op.getWindowStrides().size() != 4) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "Pooling 2D op is only supported for 4D tensors.");
    }
 
     // Operand rank(s) must be 4
     for (Value operand : op.getInputs()) {
       auto operandType = mlir::cast<RankedTensorType>(operand.getType());
       if (operandType.getRank() != 4) {
-        return failure();
-      }
-    }
-
-    // Exactly two of the window dimensions must be greater than 1
-    std::vector<int64_t> trueWindowDimensionsIndices =
-        getIndicesOfSpatialDims(op);
-
-    if (trueWindowDimensionsIndices.size() != 2) {
-      return failure();
-    }
-
-    // Exactly two of the window strides must be greater than 1
-    std::vector<int64_t> trueWindowStrideIndices;
-    for (int64_t i = 0; i < static_cast<int64_t>(op.getWindowStrides().size());
-         i++) {
-      if (op.getWindowStrides()[i] > 1) {
-        trueWindowStrideIndices.push_back(i);
+        return rewriter.notifyMatchFailure(
+            op, "Pooling 2D op is only supported for 4D tensors.");
       }
     }
 
-    if (trueWindowStrideIndices.size() != 2) {
-      return failure();
+    // Window dimensions must have at most two non-1 elements; these
+    // represent the kernel size for the max pooling operation.
+    size_t numSpatialDimIndices = spatialDimIndices.size();
+    if (numSpatialDimIndices > 2) {
+      return rewriter.notifyMatchFailure(
+          op, "Rank of kernel_size for pooling 2D op is greater than 2.");
     }
 
-    // The indices of the true window dimensions and strides must be the same
-    if ((trueWindowDimensionsIndices[0] != trueWindowStrideIndices[0] ||
-         trueWindowDimensionsIndices[1] != trueWindowStrideIndices[1]) &&
-        (trueWindowDimensionsIndices[0] != trueWindowStrideIndices[1] ||
-         trueWindowDimensionsIndices[1] != trueWindowStrideIndices[0])) {
-      return failure();
+    // Window strides must have at most two non-1 elements; these represent
+    // the strides for the max pooling operation.
+    llvm::SmallVector<int64_t> trueWindowStrideIndices =
+        getIndicesOfElementsLargerThanOne(op.getWindowStrides());
+    size_t windowStrideSize = trueWindowStrideIndices.size();
+    if (windowStrideSize > 2) {
+      return rewriter.notifyMatchFailure(
+          op, "Rank of strides for pooling 2D is greater than 2.");
     }
 
     // Padding must be 8 in length
     if (op.getPadding().size() != 8) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op,
+          "Number of elements in padding does not match with pooling 2D op.");
     }
 
     return success();
@@ -901,7 +931,8 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
 
   template <typename PoolOpType>
   void rewritePool2d(ttir::PoolingOp op, OpAdaptor adaptor,
-                     ConversionPatternRewriter &rewriter) const {
+                     ConversionPatternRewriter &rewriter,
+                     llvm::SmallVector<int64_t> spatialDimIndices) const {
 
     const int64_t SPATIAL_H = -3;
     const int64_t SPATIAL_W = -2;
@@ -922,11 +953,20 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
       }
     }
 
-    std::vector<int64_t> spatialDims = getIndicesOfSpatialDims(op);
+    int64_t numWinDims = op.getWindowDimensions().size();
+    // Use the default indices for a channel-first tensor if the window
+    // dimensions attribute does not contain two non-1 elements for the
+    // kernel size.
+    // [TODO] (mmanzoor) Add an option to distinguish channel-first vs
+    // channel-last tensors and support channel-last default indices.
+    // https://github.com/tenstorrent/tt-mlir/issues/2237
+    spatialDimIndices =
+        (spatialDimIndices.size() == 2)
+            ? spatialDimIndices
+            : llvm::SmallVector<int64_t>({numWinDims - 2, numWinDims - 1});
 
     std::vector<int64_t> currentLayout(inputType.getRank(), NON_SPATIAL);
-    currentLayout[spatialDims[0]] = SPATIAL_H;
-    currentLayout[spatialDims[1]] = SPATIAL_W;
+    currentLayout[spatialDimIndices[0]] = SPATIAL_H;
+    currentLayout[spatialDimIndices[1]] = SPATIAL_W;
 
     nonSpatialCount = 0;
     for (int64_t i = 0; i < static_cast<int64_t>(currentLayout.size()); i++) {
@@ -941,30 +981,30 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
     auto inverseOfPermutation = ttmlir::utils::inversePermutation(permutation);
 
     auto kernelHeightAttr = rewriter.getSI32IntegerAttr(
-        static_cast<int32_t>(op.getWindowDimensions()[spatialDims[0]]));
+        static_cast<int32_t>(op.getWindowDimensions()[spatialDimIndices[0]]));
     auto kernelWidthAttr = rewriter.getSI32IntegerAttr(
-        static_cast<int32_t>(op.getWindowDimensions()[spatialDims[1]]));
+        static_cast<int32_t>(op.getWindowDimensions()[spatialDimIndices[1]]));
     auto strideHeightAttr = rewriter.getSI32IntegerAttr(
-        static_cast<int32_t>(op.getWindowStrides()[spatialDims[0]]));
+        static_cast<int32_t>(op.getWindowStrides()[spatialDimIndices[0]]));
     auto strideWidthAttr = rewriter.getSI32IntegerAttr(
-        static_cast<int32_t>(op.getWindowStrides()[spatialDims[1]]));
+        static_cast<int32_t>(op.getWindowStrides()[spatialDimIndices[1]]));
     auto dilationHeightAttr = rewriter.getSI32IntegerAttr(
-        adaptor.getWindowDilations()[spatialDims[0]]);
+        adaptor.getWindowDilations()[spatialDimIndices[0]]);
     auto dilationWidthAttr = rewriter.getSI32IntegerAttr(
-        adaptor.getWindowDilations()[spatialDims[1]]);
+        adaptor.getWindowDilations()[spatialDimIndices[1]]);
     auto ceilModeAttr = rewriter.getBoolAttr(false);
 
     auto paddingTopAttr =
-        rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[0]]);
-    auto paddingBottomAttr =
-        rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[0] + 1]);
+        rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDimIndices[0]]);
+    auto paddingBottomAttr = rewriter.getSI32IntegerAttr(
+        op.getPadding()[2 * spatialDimIndices[0] + 1]);
     auto paddingLeftAttr =
-        rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1]]);
-    auto paddingRightAttr =
-        rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1] + 1]);
+        rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDimIndices[1]]);
+    auto paddingRightAttr = rewriter.getSI32IntegerAttr(
+        op.getPadding()[2 * spatialDimIndices[1] + 1]);
 
     llvm::SmallVector<Value> outputs;
     for (Value input : adaptor.getInputs()) {
@@ -999,45 +1039,6 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
 
     rewriter.replaceOp(op, outputs);
   }
-
-  uint32_t getNumSpatialDims(ttir::PoolingOp op) const {
-    uint32_t numSpatialDims = 0;
-    for (int64_t dim : op.getWindowDimensions()) {
-      if (dim > 1) {
-        numSpatialDims++;
-      }
-    }
-    return numSpatialDims;
-  }
-
-  LogicalResult
-  matchAndRewrite(ttir::PoolingOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    uint32_t numSpatialDims = getNumSpatialDims(op);
-    if (numSpatialDims == 2) {
-      if (failed(canDecompose2DPoolingOp(op))) {
-        return rewriter.notifyMatchFailure(
-            op, "2D pooling op with the given attributes is not supported "
-                "currently");
-      }
-
-      switch (op.getPoolingMethod()) {
-      case ttir::PoolingMethod::Max: {
-        rewritePool2d<ttir::MaxPool2dOp>(op, adaptor, rewriter);
-        return success();
-      }
-      default: {
-        return rewriter.notifyMatchFailure(
-            op, "Failed to match pooling method: " +
-                    stringifyPoolingMethod(op.getPoolingMethod()));
-      }
-      }
-    }
-    return rewriter.notifyMatchFailure(
-        op, "No decompositions for a pooling op with " +
-                std::to_string(numSpatialDims) + " spatial dimensions");
-  }
 };
 } // namespace
diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp
index 1fe4d9b69e..0ef642bbed 100644
--- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp
+++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp
@@ -40,7 +40,7 @@ struct TTIRToTTIRDecompositionPass
                                                // func.func and
                                                // func.call as legal ops
     target.addLegalDialect<BuiltinDialect>(); // This contains the "module" op
-                                              // which is necesarry
+                                              // which is necessary
     target.addLegalOp<mlir::tensor::EmptyOp>(); // DPS operands are create with
                                                 // tensor::EmptyOp
diff --git a/test/ttmlir/Decomposition/TTIR/pooling/max_pool2d.mlir b/test/ttmlir/Decomposition/TTIR/pooling/max_pool2d.mlir
new file mode 100644
index 0000000000..252c29bba0
--- /dev/null
+++ b/test/ttmlir/Decomposition/TTIR/pooling/max_pool2d.mlir
@@ -0,0 +1,103 @@
+// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s
+
+module attributes {} {
+  // Kernel size = 1; stride = 1
+  func.func @test_maxpool2d_kernel_1x1_stride_1x1(%arg0: tensor<1x192x28x28xbf16>) -> tensor<1x192x28x28xbf16> {
+    // CHECK-LABEL: func.func @test_maxpool2d_kernel_1x1_stride_1x1(
+    %0 = tensor.empty() : tensor<1x192x28x28xbf16>
+    // CHECK: %[[PERMUTE:[0-9]+]] = "ttir.permute"(%arg0
+    // CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
+    // CHECK-SAME: (tensor<1x192x28x28xbf16>, tensor<1x28x28x192xbf16>)
+    // CHECK-SAME: -> tensor<1x28x28x192xbf16>
+    // CHECK: %[[MAXPOOL:[0-9]+]] = "ttir.max_pool2d"(%[[PERMUTE]],
+    // CHECK-SAME: ceil_mode = false,
+    // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
+    // CHECK-SAME: kernel_height = 1 : si32, kernel_width = 1 : si32,
+    // CHECK-SAME: padding_bottom = 0 : si32, padding_left = 0 : si32, padding_right = 0 : si32, padding_top = 0 : si32,
+    // CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32
+    // CHECK-SAME: (tensor<1x28x28x192xbf16>, tensor<1x28x28x192xbf16>)
+    // CHECK-SAME: -> tensor<1x28x28x192xbf16>
+    %1 = "ttir.pooling"(%arg0, %0) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>, pooling_method = #ttir<pooling_method Max>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 1, 1>, window_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1x192x28x28xbf16>, tensor<1x192x28x28xbf16>) -> tensor<1x192x28x28xbf16>
+    // CHECK: %[[RET:[0-9]+]] = "ttir.permute"(%[[MAXPOOL]],
+    // CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
+    // CHECK-SAME: (tensor<1x28x28x192xbf16>, tensor<1x192x28x28xbf16>)
+    // CHECK-SAME: -> tensor<1x192x28x28xbf16>
+    // CHECK: return %[[RET]] : tensor<1x192x28x28xbf16>
+    return %1 : tensor<1x192x28x28xbf16>
+  }
+
+  // Kernel size = 3; stride = 1
+  func.func @test_maxpool2d_kernel_3x3_stride_1x1(%arg0: tensor<1x256x28x28xbf16>) -> tensor<1x256x28x28xbf16> {
+    // CHECK-LABEL: func.func @test_maxpool2d_kernel_3x3_stride_1x1(
+    %0 = tensor.empty() : tensor<1x256x28x28xbf16>
+    // CHECK: %[[PERMUTE:[0-9]+]] = "ttir.permute"(%arg0
+    // CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
+    // CHECK-SAME: (tensor<1x256x28x28xbf16>, tensor<1x28x28x256xbf16>)
+    // CHECK-SAME: -> tensor<1x28x28x256xbf16>
+    // CHECK: %[[MAXPOOL:[0-9]+]] = "ttir.max_pool2d"(%[[PERMUTE]],
+    // CHECK-SAME: ceil_mode = false,
+    // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
+    // CHECK-SAME: kernel_height = 3 : si32, kernel_width = 3 : si32,
+    // CHECK-SAME: padding_bottom = 1 : si32, padding_left = 1 : si32, padding_right = 1 : si32, padding_top = 1 : si32,
+    // CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32
+    // CHECK-SAME: (tensor<1x28x28x256xbf16>, tensor<1x28x28x256xbf16>)
+    // CHECK-SAME: -> tensor<1x28x28x256xbf16>
+    %1 = "ttir.pooling"(%arg0, %0) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 1, 1, 1, 1>, pooling_method = #ttir<pooling_method Max>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1x256x28x28xbf16>, tensor<1x256x28x28xbf16>) -> tensor<1x256x28x28xbf16>
+    // CHECK: %[[RET:[0-9]+]] = "ttir.permute"(%[[MAXPOOL]],
+    // CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
+    // CHECK-SAME: (tensor<1x28x28x256xbf16>, tensor<1x256x28x28xbf16>)
+    // CHECK-SAME: -> tensor<1x256x28x28xbf16>
+    // CHECK: return %[[RET]] : tensor<1x256x28x28xbf16>
+    return %1 : tensor<1x256x28x28xbf16>
+  }
+
+  // Kernel size = (2, 1); stride = 1
+  func.func @test_maxpool2d_kernel_2x1_stride_1x1(%arg0: tensor<1x192x28x28xbf16>) -> tensor<1x192x27x28xbf16> {
+    // CHECK-LABEL: func.func @test_maxpool2d_kernel_2x1_stride_1x1(
+    %0 = tensor.empty() : tensor<1x192x27x28xbf16>
+    // CHECK: %[[PERMUTE:[0-9]+]] = "ttir.permute"(%arg0
+    // CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
+    // CHECK-SAME: (tensor<1x192x28x28xbf16>, tensor<1x28x28x192xbf16>)
+    // CHECK-SAME: -> tensor<1x28x28x192xbf16>
+    // CHECK: %[[MAXPOOL:[0-9]+]] = "ttir.max_pool2d"(%[[PERMUTE]],
+    // CHECK-SAME: ceil_mode = false,
+    // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
+    // CHECK-SAME: kernel_height = 2 : si32, kernel_width = 1 : si32,
+    // CHECK-SAME: padding_bottom = 0 : si32, padding_left = 0 : si32, padding_right = 0 : si32, padding_top = 0 : si32,
+    // CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32
+    // CHECK-SAME: (tensor<1x28x28x192xbf16>, tensor<1x27x28x192xbf16>)
+    // CHECK-SAME: -> tensor<1x27x28x192xbf16>
+    %1 = "ttir.pooling"(%arg0, %0) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>, pooling_method = #ttir<pooling_method Max>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 2, 1>, window_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1x192x28x28xbf16>, tensor<1x192x27x28xbf16>) -> tensor<1x192x27x28xbf16>
+    // CHECK: %[[RET:[0-9]+]] = "ttir.permute"(%[[MAXPOOL]],
+    // CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
+    // CHECK-SAME: (tensor<1x27x28x192xbf16>, tensor<1x192x27x28xbf16>)
+    // CHECK-SAME: -> tensor<1x192x27x28xbf16>
+    // CHECK: return %[[RET]] : tensor<1x192x27x28xbf16>
+    return %1 : tensor<1x192x27x28xbf16>
+  }
+
+  // Kernel size = (1, 2); stride = (3, 1)
+  func.func @test_maxpool2d_kernel_1x2_stride_3x1(%arg0: tensor<1x192x28x28xbf16>) -> tensor<1x192x10x27xbf16> {
+    // CHECK-LABEL: func.func @test_maxpool2d_kernel_1x2_stride_3x1(
+    %0 = tensor.empty() : tensor<1x192x10x27xbf16>
+    // CHECK: %[[PERMUTE:[0-9]+]] = "ttir.permute"(%arg0
+    // CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
+    // CHECK-SAME: (tensor<1x192x28x28xbf16>, tensor<1x28x28x192xbf16>)
+    // CHECK-SAME: -> tensor<1x28x28x192xbf16>
+    // CHECK: %[[MAXPOOL:[0-9]+]] = "ttir.max_pool2d"(%[[PERMUTE]],
+    // CHECK-SAME: ceil_mode = false,
+    // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
+    // CHECK-SAME: kernel_height = 1 : si32, kernel_width = 2 : si32,
+    // CHECK-SAME: padding_bottom = 0 : si32, padding_left = 0 : si32, padding_right = 0 : si32, padding_top = 0 : si32,
+    // CHECK-SAME: stride_height = 3 : si32, stride_width = 1 : si32
+    // CHECK-SAME: (tensor<1x28x28x192xbf16>, tensor<1x10x27x192xbf16>)
+    // CHECK-SAME: -> tensor<1x10x27x192xbf16>
+    %1 = "ttir.pooling"(%arg0, %0) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>, pooling_method = #ttir<pooling_method Max>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 1, 2>, window_strides = array<i64: 1, 1, 3, 1>}> : (tensor<1x192x28x28xbf16>, tensor<1x192x10x27xbf16>) -> tensor<1x192x10x27xbf16>
+    // CHECK: %[[RET:[0-9]+]] = "ttir.permute"(%[[MAXPOOL]],
+    // CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
+    // CHECK-SAME: (tensor<1x10x27x192xbf16>, tensor<1x192x10x27xbf16>)
+    // CHECK-SAME: -> tensor<1x192x10x27xbf16>
+    // CHECK: return %[[RET]] : tensor<1x192x10x27xbf16>
+    return %1 : tensor<1x192x10x27xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir b/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir
index e20bb0375a..7b2c50f967 100644
--- a/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir
+++ b/test/ttmlir/Silicon/StableHLO/n150/maxpool2d_op.mlir
@@ -1,19 +1,132 @@
 // REQUIRES: stablehlo
 // RUN: rm -rf %t.ttnn
 // RUN: rm -rf %t.mlir
-// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s \
+// RUN:   --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 // RUN: FileCheck --input-file=%t.mlir %s
-// UNSUPPORTED: true
-func.func public @test_maxpool2d(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> {
-  %0 = stablehlo.constant dense<0xFF80> : tensor<bf16>
-  %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<bf16>) -> tensor<bf16>
-  %2 = "stablehlo.reduce_window"(%arg0, %1) <{padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array<i64: 1, 3, 3, 1>, window_strides = array<i64: 1, 2, 2, 1>}> ({
-  ^bb0(%arg2: tensor<bf16>, %arg3: tensor<bf16>):
-    %3 = stablehlo.maximum %arg2, %arg3 : tensor<bf16>
-    stablehlo.return %3 : tensor<bf16>
-  }) : (tensor<1x128x128x32xbf16>, tensor<bf16>) -> tensor<1x64x64x32xbf16>
-  return %2 : tensor<1x64x64x32xbf16>
+module @max_pool2d attributes {} {
+  // Kernel size = 3; Stride = 3; Padding = 1
+  func.func @test_maxpool2d_kernel_3x3_stride_3x3_padding_1(%arg0: tensor<1x128x32x32xbf16>) -> tensor<1x128x11x11xbf16> {
+    // CHECK-LABEL: func.func @test_maxpool2d_kernel_3x3_stride_3x3_padding_1(
+    %cst = stablehlo.constant dense<0xFF80> : tensor<bf16>
+    // CHECK: %[[PERMUTE:[0-9]+]] = "ttnn.permute"(%arg0)
+    // CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
+    // CHECK-SAME: tensor<1x128x32x32xbf16
+    // CHECK-SAME: -> tensor<1x32x32x128xbf16
+    // CHECK: "ttnn.reshape"(%[[PERMUTE]])
+    // CHECK-SAME: shape = [1 : i32, 1 : i32, 1024 : i32, 128 : i32]
+    // CHECK-SAME: tensor<1x32x32x128xbf16
+    // CHECK-SAME: -> tensor<1x1x1024x128xbf16
+    // CHECK: "ttnn.max_pool2d"
+    // CHECK-SAME: batch_size = 1 : si32,
+    // CHECK-SAME: ceil_mode = false,
+    // CHECK-SAME: channels = 128 : si32,
+    // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
+    // CHECK-SAME: input_height = 32 : si32, input_width = 32 : si32,
+    // CHECK-SAME: kernel_height = 3 : si32, kernel_width = 3 : si32,
+    // CHECK-SAME: padding_height = 1 : si32, padding_width = 1 : si32,
+    // CHECK-SAME: stride_height = 3 : si32, stride_width = 3 : si32}
+    // CHECK-SAME: tensor<1x1x1024x128xbf16
+    // CHECK-SAME: tensor<1x1x121x128xbf16
+    // CHECK-SAME: -> tensor<1x1x121x128xbf16
+    %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 3, 3>}> ({
+    ^bb0(%arg1: tensor<bf16>, %arg2: tensor<bf16>):
+      %1 = stablehlo.maximum %arg1, %arg2 : tensor<bf16>
+      stablehlo.return %1 : tensor<bf16>
+    }) : (tensor<1x128x32x32xbf16>, tensor<bf16>) -> tensor<1x128x11x11xbf16>
+    // CHECK: %[[RESHAPE:[0-9]+]] = "ttnn.reshape"
+    // CHECK-SAME: shape = [1 : i32, 11 : i32, 11 : i32, 128 : i32]
+    // CHECK-SAME: tensor<1x1x121x128xbf16
+    // CHECK-SAME: -> tensor<1x11x11x128xbf16
+    // CHECK: %[[RET:[0-9]+]] = "ttnn.permute"(%[[RESHAPE]])
+    // CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
+    // CHECK-SAME: tensor<1x11x11x128xbf16
+    // CHECK-SAME: -> tensor<1x128x11x11xbf16
+    // CHECK: return %[[RET]] : tensor<1x128x11x11xbf16
+    return %0 : tensor<1x128x11x11xbf16>
+  }
+
+  // Kernel size = 1; Stride = 1; Padding = 0
+  func.func @test_maxpool2d_kernel_1x1_stride_1x1_padding_0(%arg0: tensor<1x192x28x28xbf16>) -> tensor<1x192x28x28xbf16> {
+    // CHECK-LABEL: func.func @test_maxpool2d_kernel_1x1_stride_1x1_padding_0(
+    %cst = stablehlo.constant dense<0xFF80> : tensor<bf16>
+    // CHECK: %[[PERMUTE:[0-9]+]] = "ttnn.permute"(%arg0)
+    // CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
+    // CHECK-SAME: tensor<1x192x28x28xbf16
+    // CHECK-SAME: -> tensor<1x28x28x192xbf16
+    // CHECK: "ttnn.reshape"(%[[PERMUTE]])
+    // CHECK-SAME: shape = [1 : i32, 1 : i32, 784 : i32, 192 : i32]
+    // CHECK-SAME: tensor<1x28x28x192xbf16
+    // CHECK-SAME: -> tensor<1x1x784x192xbf16
+    // CHECK: "ttnn.max_pool2d"
+    // CHECK-SAME: batch_size = 1 : si32,
+    // CHECK-SAME: ceil_mode = false,
+    // CHECK-SAME: channels = 192 : si32,
+    // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
+    // CHECK-SAME: input_height = 28 : si32, input_width = 28 : si32,
+    // CHECK-SAME: kernel_height = 1 : si32, kernel_width = 1 : si32,
+    // CHECK-SAME: padding_height = 0 : si32, padding_width = 0 : si32,
+    // CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32}
+    // CHECK-SAME: tensor<1x1x784x192xbf16
+    // CHECK-SAME: tensor<1x1x784x192xbf16
+    // CHECK-SAME: -> tensor<1x1x784x192xbf16
+    %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 1, 1>, window_strides = array<i64: 1, 1, 1, 1>}> ({
+    ^bb0(%arg1: tensor<bf16>, %arg2: tensor<bf16>):
+      %1 = stablehlo.maximum %arg1, %arg2 : tensor<bf16>
+      stablehlo.return %1 : tensor<bf16>
+    }) : (tensor<1x192x28x28xbf16>, tensor<bf16>) -> tensor<1x192x28x28xbf16>
+    // CHECK: %[[RESHAPE:[0-9]+]] = "ttnn.reshape"
+    // CHECK-SAME: shape = [1 : i32, 28 : i32, 28 : i32, 192 : i32]
+    // CHECK-SAME: tensor<1x1x784x192xbf16
+    // CHECK-SAME: -> tensor<1x28x28x192xbf16
+    // CHECK: %[[RET:[0-9]+]] = "ttnn.permute"(%[[RESHAPE]])
+    // CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
+    // CHECK-SAME: tensor<1x28x28x192xbf16
+    // CHECK-SAME: -> tensor<1x192x28x28xbf16
+    // CHECK: return %[[RET]] : tensor<1x192x28x28xbf16
+    return %0 : tensor<1x192x28x28xbf16>
+  }
+
+  // Kernel size = (1, 2); Stride = (3, 1); Padding = 0
+  func.func @test_maxpool2d_kernel_1x2_stride_3x1_padding_0(%arg0: tensor<1x192x28x28xbf16>) -> tensor<1x192x10x27xbf16> {
+    // CHECK-LABEL: func.func @test_maxpool2d_kernel_1x2_stride_3x1_padding_0(
+    %cst = stablehlo.constant dense<0xFF80> : tensor<bf16>
+    // CHECK: %[[PERMUTE:[0-9]+]] = "ttnn.permute"(%arg0)
+    // CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
+    // CHECK-SAME: tensor<1x192x28x28xbf16
+    // CHECK-SAME: -> tensor<1x28x28x192xbf16
+    // CHECK: "ttnn.reshape"(%[[PERMUTE]])
+    // CHECK-SAME: shape = [1 : i32, 1 : i32, 784 : i32, 192 : i32]
+    // CHECK-SAME: tensor<1x28x28x192xbf16
+    // CHECK-SAME: -> tensor<1x1x784x192xbf16
+    // CHECK: "ttnn.max_pool2d"
+    // CHECK-SAME: batch_size = 1 : si32,
+    // CHECK-SAME: ceil_mode = false,
+    // CHECK-SAME: channels = 192 : si32,
+    // CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
+    // CHECK-SAME: input_height = 28 : si32, input_width = 28 : si32,
+    // CHECK-SAME: kernel_height = 1 : si32, kernel_width = 2 : si32,
+    // CHECK-SAME: padding_height = 0 : si32, padding_width = 0 : si32,
+    // CHECK-SAME: stride_height = 3 : si32, stride_width = 1 : si32}
+    // CHECK-SAME: tensor<1x1x784x192xbf16
+    // CHECK-SAME: tensor<1x1x270x192xbf16
+    // CHECK-SAME: -> tensor<1x1x270x192xbf16
+    %0 = "stablehlo.reduce_window"(%arg0, %cst) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 1, 2>, window_strides = array<i64: 1, 1, 3, 1>}> ({
+    ^bb0(%arg1: tensor<bf16>, %arg2: tensor<bf16>):
+      %1 = stablehlo.maximum %arg1, %arg2 : tensor<bf16>
+      stablehlo.return %1 : tensor<bf16>
+    }) : (tensor<1x192x28x28xbf16>, tensor<bf16>) -> tensor<1x192x10x27xbf16>
+    // CHECK: %[[RESHAPE:[0-9]+]] = "ttnn.reshape"
+    // CHECK-SAME: shape = [1 : i32, 10 : i32, 27 : i32, 192 : i32]
+    // CHECK-SAME: tensor<1x1x270x192xbf16
+    // CHECK-SAME: -> tensor<1x10x27x192xbf16
+    // CHECK: %[[RET:[0-9]+]] = "ttnn.permute"(%[[RESHAPE]])
+    // CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
+    // CHECK-SAME: tensor<1x10x27x192xbf16
+    // CHECK-SAME: -> tensor<1x192x10x27xbf16
+    // CHECK: return %[[RET]] : tensor<1x192x10x27xbf16
+    return %0 : tensor<1x192x10x27xbf16>
+  }
 }
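
Reviewer note (not part of the patch): a minimal standalone C++ sketch of the
spatial-dimension defaulting that the new matchAndRewrite/rewritePool2d path
implements. The helper mirrors getIndicesOfElementsLargerThanOne from the
pattern; everything else here is hypothetical illustration, assuming a
channel-first (NCHW) 4D tensor.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Collect the indices of all window entries larger than 1, mirroring
    // getIndicesOfElementsLargerThanOne in the pattern above.
    static std::vector<std::int64_t>
    indicesLargerThanOne(const std::vector<std::int64_t> &v) {
      std::vector<std::int64_t> out;
      for (std::size_t i = 0; i < v.size(); ++i) {
        if (v[i] > 1) {
          out.push_back(static_cast<std::int64_t>(i));
        }
      }
      return out;
    }

    int main() {
      // Kernel (1, 2): only one window dimension exceeds 1, so the pattern
      // falls back to the last two dims (H = 2, W = 3) of an NCHW tensor,
      // as exercised by the kernel_1x2_stride_3x1 tests above.
      std::vector<std::int64_t> windowDims = {1, 1, 1, 2};
      std::vector<std::int64_t> spatial = indicesLargerThanOne(windowDims);
      if (spatial.size() != 2) {
        const std::int64_t n = static_cast<std::int64_t>(windowDims.size());
        spatial = {n - 2, n - 1};
      }
      std::printf("spatial dims: %lld, %lld\n",
                  static_cast<long long>(spatial[0]),
                  static_cast<long long>(spatial[1])); // spatial dims: 2, 3
      return 0;
    }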