diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index d3518a3bc4..9f8aed3fa6 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -4,6 +4,7 @@
 #include
 #include
+#include
 #include

 #include "mlir/Dialect/Traits.h"
@@ -1921,8 +1922,8 @@ class StableHLOToTTIROpClampOpConversionPattern
     RankedTensorType outputType = mlir::cast<RankedTensorType>(
         this->getTypeConverter()->convertType(srcOp.getResult().getType()));

-    if (std::optional<float> minValue = getConstantValue(adaptor.getMin()),
-                             maxValue = getConstantValue(adaptor.getMax());
+    if (std::optional<float> minValue = getConstantValue(srcOp.getMin()),
+                             maxValue = getConstantValue(srcOp.getMax());
         minValue && maxValue) {
       ttmlir::utils::replaceOpWithNewDPSOp<ttir::ClampOp>(
           rewriter, srcOp, outputType, adaptor.getOperand(),
@@ -1931,27 +1932,81 @@
       return success();
     }

+    mlir::Value min =
+        broadcastAttr(adaptor.getMin(), outputType, srcOp, rewriter);
+    mlir::Value max =
+        broadcastAttr(adaptor.getMax(), outputType, srcOp, rewriter);
+
     ttir::MaximumOp maximumOp = ttmlir::utils::createDPSOp<ttir::MaximumOp>(
-        rewriter, srcOp->getLoc(), outputType, adaptor.getMin(),
-        adaptor.getOperand());
+        rewriter, srcOp->getLoc(), outputType, min, adaptor.getOperand());

     ttmlir::utils::replaceOpWithNewDPSOp<ttir::MinimumOp>(
-        rewriter, srcOp, outputType, maximumOp.getResult(0), adaptor.getMax());
+        rewriter, srcOp, outputType, maximumOp.getResult(0), max);

     return success();
   }

 private:
   std::optional<float> getConstantValue(Value value) const {
-    if (auto constantOp = value.getDefiningOp<mlir::stablehlo::ConstantOp>()) {
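+    // The min/max operand may be produced indirectly, e.g. through
+    // stablehlo.convert, stablehlo.reshape or stablehlo.broadcast_in_dim;
+    // walk up the chain of defining ops to the underlying constant, if any.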
+    Operation *op = value.getDefiningOp();
+    while (op &&
+           (isa<mlir::stablehlo::ConvertOp>(op) ||
+            isa<mlir::stablehlo::ReshapeOp>(op) ||
+            isa<mlir::stablehlo::BroadcastInDimOp>(op))) {
+      op = op->getOperand(0).getDefiningOp();
+    }
+    if (!op) {
+      return std::nullopt;
+    }
+
+    if (auto constantOp = mlir::dyn_cast<mlir::stablehlo::ConstantOp>(op)) {
       auto attr = constantOp.getValueAttr();
       if (!attr.isSplat()) {
-        return {};
+        return std::nullopt;
       }
-      return attr.getElementType().isInteger()
-                 ? static_cast<float>(attr.getSplatValue<int>())
-                 : attr.getSplatValue<float>();
+      mlir::Type elementType = attr.getElementType();
+      mlir::APFloat fillValue(mlir::APFloat::IEEEsingle());
+      if (isa<IntegerType>(elementType)) {
+        fillValue.convertFromAPInt(attr.getSplatValue<llvm::APInt>(),
+                                   attr.getElementType().isSignedInteger(),
+                                   llvm::RoundingMode::TowardZero);
+        return fillValue.convertToFloat();
+      }
+      if (isa<FloatType>(elementType)) {
+        return static_cast<float>(
+            attr.getSplatValue<mlir::APFloat>().convertToDouble());
+      }
+      assert(false && "Unsupported data type.");
     }
-    return {};
+    return std::nullopt;
+  }
+
+  mlir::Value broadcastAttr(mlir::Value input, RankedTensorType desiredType,
+                            mlir::stablehlo::ClampOp srcOp,
+                            ConversionPatternRewriter &rewriter) const {
+    auto inputType = mlir::cast<RankedTensorType>(input.getType());
+    if (inputType.getShape() == desiredType.getShape()) {
+      return input;
+    }
+
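+    // The value has a lower rank (or smaller, broadcastable dims) than the
+    // output: copy its dims into the leading positions of a rank-expanded
+    // shape padded with trailing 1s, reshape to that shape, then broadcast
+    // every dimension out to the desired output shape.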
+    SmallVector<int64_t> unsqueezeShape(desiredType.getRank(), 1);
+    for (int64_t i = 0; i < inputType.getRank(); i++) {
+      unsqueezeShape[i] = inputType.getDimSize(i);
+    }
+    SmallVector<int32_t> reshapeDim(unsqueezeShape.begin(),
+                                    unsqueezeShape.end());
+
+    auto reshapeDimAttr = rewriter.getI32ArrayAttr(reshapeDim);
+    ttir::ReshapeOp reshapeOp = ttmlir::utils::createDPSOp<ttir::ReshapeOp>(
+        rewriter, srcOp.getLoc(), unsqueezeShape, desiredType.getElementType(),
+        desiredType.getEncoding(), input, reshapeDimAttr);
+
+    ::llvm::ArrayRef<int64_t> inputShape = unsqueezeShape;
+    ::llvm::ArrayRef<int64_t> outputShape = desiredType.getShape();
+
+    SmallVector<int64_t> broadcastShape =
+        ttmlir::utils::getBroadcastDimensions<int64_t>(inputShape,
+                                                       outputShape);
+
+    return ttmlir::utils::createDPSOp<ttir::BroadcastOp>(
+        rewriter, srcOp->getLoc(), desiredType, reshapeOp, broadcastShape);
   }
 };
 } // namespace
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
index d46b00e6a6..92b7316828 100644
--- a/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
@@ -1,7 +1,8 @@
 // REQUIRES: stablehlo
 // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
-module @jit_transpose attributes {} {
+module @jit_clamp attributes {} {
   func.func public @test_clamp_constant(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+    // CHECK-LABEL: func.func public @test_clamp_constant
     %cst = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
     %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32>
     // CHECK: %[[EMPTY:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
@@ -12,7 +13,38 @@ module @jit_transpose attributes {} {
     return %0 : tensor<4xf32>
   }

+  func.func public @test_clamp_indirect_constant_reshape(%arg0: tensor<1x16xbf16>) -> tensor<1x16xbf16> {
+    // CHECK-LABEL: func.func public @test_clamp_indirect_constant_reshape
+    %cst = arith.constant dense<3.0> : tensor<1xf64>
+    %cst_0 = arith.constant dense<6> : tensor<1xi64>
+    %0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
+    %1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
+    %2 = stablehlo.convert %cst_0 : (tensor<1xi64>) -> tensor<1xbf16>
+    %3 = stablehlo.reshape %2 : (tensor<1xbf16>) -> tensor<bf16>
+    // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : [[TENSOR:tensor<1x16xbf16>]]
+    // CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
+    // CHECK-SAME: max = 6.000000e+00 : f32, min = 3.000000e+00 : f32
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %4 = stablehlo.clamp %1, %arg0, %3 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>
+    return %4 : tensor<1x16xbf16>
+  }
+
+  func.func public @test_clamp_indirect_constant_broadcast(%arg0: tensor<1x32xbf16>) -> (tensor<1x32xbf16>) {
+    // CHECK-LABEL: func.func public @test_clamp_indirect_constant_broadcast
+    %cst = stablehlo.constant dense<2.000000e+00> : tensor<bf16>
+    %cst_0 = stablehlo.constant dense<5.000000e+00> : tensor<bf16>
+    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
+    %1 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
+    // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : [[TENSOR:tensor<1x32xbf16>]]
+    // CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
+    // CHECK-SAME: max = 5.000000e+00 : f32, min = 2.000000e+00 : f32
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %2 = stablehlo.clamp %0, %arg0, %1 : tensor<1x32xbf16>
+    return %2 : tensor<1x32xbf16>
+  }
+
   func.func public @test_clamp_tensor(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
+    // CHECK-LABEL: func.func public @test_clamp_tensor
     // CHECK: %[[EMPTY0:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
     // CHECK: %[[MAX:.*]] = "ttir.maximum"(%arg1, %arg0, %[[EMPTY0]])
     // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
@@ -23,4 +55,35 @@ module @jit_transpose attributes {} {
     // CHECK: return %[[MIN]] : [[TENSOR]]
     return %0 : tensor<4xf32>
   }
+
+  func.func public @test_clamp_tensor_constant(%arg0: tensor<1x16xbf16>, %arg1: tensor<1xbf16>) -> tensor<1x16xbf16> {
+    // CHECK-LABEL: func.func public @test_clamp_tensor_constant(
+    // CHECK: %[[CONSTANT:[0-9]+]] = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+    %cst = arith.constant dense<3.0> : tensor<1xf64>
+    // CHECK: %[[CAST:[0-9]+]] = "ttir.typecast"(%[[CONSTANT]],
+    // CHECK-SAME: (tensor<1xf32>, tensor<1xbf16>) -> tensor<1xbf16>
+    %0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
+    // CHECK: %[[RESHAPE0:[0-9]+]] = "ttir.reshape"(%[[CAST]],
+    // CHECK-SAME: shape = [1 : i32]
+    // CHECK-SAME: (tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
+    %1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
+    // CHECK: %[[RESHAPE1:[0-9]+]] = "ttir.reshape"(%[[RESHAPE0]],
+    // CHECK-SAME: shape = [1 : i32, 1 : i32]
+    // CHECK-SAME: (tensor<1xbf16>, tensor<1x1xbf16>) -> tensor<1x1xbf16>
+    // CHECK: %[[MIN:[0-9]+]] = "ttir.broadcast"(%[[RESHAPE1]]
+    // CHECK-SAME: {broadcast_dimensions = array<i64: 1, 16>}
+    // CHECK-SAME: (tensor<1x1xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
+    // CHECK: %[[RESHAPE2:[0-9]+]] = "ttir.reshape"(%arg1,
+    // CHECK-SAME: <{shape = [1 : i32, 1 : i32]}>
+    // CHECK-SAME: (tensor<1xbf16>, tensor<1x1xbf16>) -> tensor<1x1xbf16>
+    // CHECK: %[[MAX:[0-9]+]] = "ttir.broadcast"(%[[RESHAPE2]],
+    // CHECK-SAME: <{broadcast_dimensions = array<i64: 1, 16>}>
+    // CHECK-SAME: (tensor<1x1xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
+    // CHECK: %[[ARG:[0-9]+]] = "ttir.maximum"(%[[MIN]], %arg0,
+    // CHECK-SAME: (tensor<1x16xbf16>, tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
+    // CHECK: "ttir.minimum"(%[[ARG]], %[[MAX]],
+    // CHECK-SAME: (tensor<1x16xbf16>, tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
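+    // min traces back to a constant, but max (%arg1) is a block argument,
+    // so the pattern cannot fold both into ttir.clamp attributes and falls
+    // back to broadcasted ttir.maximum + ttir.minimum.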
+    %2 = stablehlo.clamp %1, %arg0, %arg1 : (tensor<bf16>, tensor<1x16xbf16>, tensor<1xbf16>) -> tensor<1x16xbf16>
+    return %2 : tensor<1x16xbf16>
+  }
 }
diff --git a/test/ttmlir/Silicon/StableHLO/n150/Unary/clamp_op.mlir b/test/ttmlir/Silicon/StableHLO/n150/Unary/clamp_op.mlir
index 1beab078d4..5b8b85067e 100644
--- a/test/ttmlir/Silicon/StableHLO/n150/Unary/clamp_op.mlir
+++ b/test/ttmlir/Silicon/StableHLO/n150/Unary/clamp_op.mlir
@@ -1,14 +1,14 @@
 // REQUIRES: stablehlo
 // RUN: rm -rf %t.ttnn
 // RUN: rm -rf %t.mlir
-// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s \
+// RUN: --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 // RUN: FileCheck --input-file=%t.mlir %s

 module @jit_transpose attributes {} {
   func.func public @test_clamp_constant(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK-LABEL: func.func public @test_clamp_constant
+    // CHECK-LABEL: func.func public @test_clamp_constant(
     // CHECK: ttnn.clamp
     // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
     // CHECK-SAME: tensor<64x128xf32,
@@ -19,8 +19,38 @@ module @jit_transpose attributes {} {
     return %0 : tensor<64x128xf32>
   }

+  func.func public @test_clamp_indirect_constant_reshape(%arg0: tensor<1x16xbf16>) -> tensor<1x16xbf16> {
+    // CHECK-LABEL: func.func public @test_clamp_indirect_constant_reshape
+    %cst = arith.constant dense<3.0> : tensor<1xf64>
+    %cst_0 = arith.constant dense<6> : tensor<1xi64>
+    %0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
+    %1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
+    %2 = stablehlo.convert %cst_0 : (tensor<1xi64>) -> tensor<1xbf16>
+    %3 = stablehlo.reshape %2 : (tensor<1xbf16>) -> tensor<bf16>
+    // CHECK: ttnn.clamp
+    // CHECK-SAME: {max = 6.000000e+00 : f32, min = 3.000000e+00 : f32}
+    // CHECK-SAME: tensor<1x16xbf16,
+    // CHECK-SAME: -> tensor<1x16xbf16,
+    %4 = stablehlo.clamp %1, %arg0, %3 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>
+    return %4 : tensor<1x16xbf16>
+  }
+
+  func.func public @test_clamp_indirect_constant_broadcast(%arg0: tensor<1x32xbf16>) -> (tensor<1x32xbf16>) {
+    // CHECK-LABEL: func.func public @test_clamp_indirect_constant_broadcast
+    %cst = stablehlo.constant dense<2.000000e+00> : tensor<bf16>
+    %cst_0 = stablehlo.constant dense<5.000000e+00> : tensor<bf16>
+    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
+    %1 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
+    // CHECK: ttnn.clamp
+    // CHECK-SAME: {max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}
+    // CHECK-SAME: tensor<1x32xbf16,
+    // CHECK-SAME: -> tensor<1x32xbf16,
+    %2 = stablehlo.clamp %0, %arg0, %1 : tensor<1x32xbf16>
+    return %2 : tensor<1x32xbf16>
+  }
+
   func.func public @test_clamp_tensor(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>, %arg2: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK-LABEL: func.func public @test_clamp_tensor
+    // CHECK-LABEL: func.func public @test_clamp_tensor(
     // CHECK: %[[MAX:.*]] = "ttnn.maximum"
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
@@ -32,4 +62,21 @@ module @jit_transpose attributes {} {
     return %0 : tensor<64x128xf32>
   }
+
+  func.func public @test_clamp_tensor_constant(%arg0: tensor<32x32xbf16>, %arg1: tensor<1xbf16>) -> tensor<32x32xbf16> {
+    // CHECK-LABEL: func.func public @test_clamp_tensor_constant(
+    %cst = arith.constant dense<3.0> : tensor<1xf64>
+    %0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
+    %1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
+    // CHECK: %[[MAX:.*]] = "ttnn.maximum"
+    // CHECK-SAME: tensor<32x32xbf16,
+    // CHECK-SAME: tensor<32x32xbf16,
+    // CHECK-SAME: -> tensor<32x32xbf16,
+    // CHECK: "ttnn.minimum"(%[[MAX]]
+    // CHECK-SAME: tensor<32x32xbf16,
+    // CHECK-SAME: tensor<32x32xbf16,
+    // CHECK-SAME: -> tensor<32x32xbf16,
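+    // Mixed case on silicon: max (%arg1) is a runtime tensor, so the clamp
+    // lowers to ttnn.maximum / ttnn.minimum rather than a fused ttnn.clamp.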
+    %2 = stablehlo.clamp %1, %arg0, %arg1 : (tensor<bf16>, tensor<32x32xbf16>, tensor<1xbf16>) -> tensor<32x32xbf16>
+    return %2 : tensor<32x32xbf16>
+  }
 }