Determine indirectly defined constant for clamp op #2268

Open · wants to merge 1 commit into base: main
76 changes: 65 additions & 11 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1905,8 +1905,8 @@ class StableHLOToTTIROpClampOpConversionPattern
RankedTensorType outputType = mlir::cast<RankedTensorType>(
this->getTypeConverter()->convertType(srcOp.getResult().getType()));

if (std::optional<float> minValue = getConstantValue(adaptor.getMin()),
maxValue = getConstantValue(adaptor.getMax());
if (std::optional<float> minValue = getConstantValue(srcOp.getMin()),
maxValue = getConstantValue(srcOp.getMax());
minValue && maxValue) {
ttmlir::utils::replaceOpWithNewDPSOp<ttir::ClampOp>(
rewriter, srcOp, outputType, adaptor.getOperand(),
@@ -1915,27 +1915,81 @@ class StableHLOToTTIROpClampOpConversionPattern
return success();
}

mlir::Value min =
broadcastAttr(adaptor.getMin(), outputType, srcOp, rewriter);
mlir::Value max =
broadcastAttr(adaptor.getMax(), outputType, srcOp, rewriter);

ttir::MaximumOp maximumOp = ttmlir::utils::createDPSOp<ttir::MaximumOp>(
rewriter, srcOp->getLoc(), outputType, adaptor.getMin(),
adaptor.getOperand());
rewriter, srcOp->getLoc(), outputType, min, adaptor.getOperand());
ttmlir::utils::replaceOpWithNewDPSOp<ttir::MinimumOp>(
rewriter, srcOp, outputType, maximumOp.getResult(0), adaptor.getMax());
rewriter, srcOp, outputType, maximumOp.getResult(0), max);

return success();
}

private:
std::optional<float> getConstantValue(Value value) const {
Review comment from @mrakitaTT (Contributor), Mar 6, 2025:

Oh, you actually already have the logic here to extract the constant values to use. I would then avoid delaying the proper solution and just do it right now, since it doesn't require much more work (add two separate clamp ops to ttir and ttnn, map here to the one with constant-value attributes, and remove this decomposition logic). One model being blocked in tt-torch doesn't seem like that big of a priority to me; please let me know if you disagree.
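A minimal sketch of the mapping that comment describes, assuming hypothetical `ttir::ClampScalarOp` (attribute min/max) and `ttir::ClampTensorOp` (operand min/max) ops were added; neither op exists in this PR, and the helpers it calls are the ones already defined in this pattern:

```cpp
// Sketch only: ClampScalarOp / ClampTensorOp are hypothetical op names, not
// ops defined by this PR. getConstantValue, broadcastAttr and
// ttmlir::utils::replaceOpWithNewDPSOp are the helpers from this file.
LogicalResult
matchAndRewrite(mlir::stablehlo::ClampOp srcOp, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  RankedTensorType outputType = mlir::cast<RankedTensorType>(
      getTypeConverter()->convertType(srcOp.getResult().getType()));

  std::optional<float> minValue = getConstantValue(srcOp.getMin());
  std::optional<float> maxValue = getConstantValue(srcOp.getMax());
  if (minValue && maxValue) {
    // Both bounds trace back to splat constants: emit the attribute form.
    ttmlir::utils::replaceOpWithNewDPSOp<ttir::ClampScalarOp>(
        rewriter, srcOp, outputType, adaptor.getOperand(),
        rewriter.getF32FloatAttr(*minValue),
        rewriter.getF32FloatAttr(*maxValue));
    return success();
  }

  // Otherwise keep min/max as tensor operands instead of decomposing into
  // maximum + minimum here; later lowering decides how to handle them.
  mlir::Value min = broadcastAttr(adaptor.getMin(), outputType, srcOp, rewriter);
  mlir::Value max = broadcastAttr(adaptor.getMax(), outputType, srcOp, rewriter);
  ttmlir::utils::replaceOpWithNewDPSOp<ttir::ClampTensorOp>(
      rewriter, srcOp, outputType, adaptor.getOperand(), min, max);
  return success();
}
```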

if (auto constantOp = value.getDefiningOp<ttir::ConstantOp>()) {
Operation *op = value.getDefiningOp();
while (op &&
(isa<stablehlo::BroadcastInDimOp>(op) ||
isa<stablehlo::ReshapeOp>(op) || isa<stablehlo::ConvertOp>(op))) {
op = op->getOperand(0).getDefiningOp();
}
if (!op) {
return std::nullopt;
}

if (auto constantOp = mlir::dyn_cast<stablehlo::ConstantOp>(op)) {
auto attr = constantOp.getValueAttr();
if (!attr.isSplat()) {
return {};
return std::nullopt;
}
return attr.getElementType().isInteger()
? static_cast<float>(attr.getSplatValue<int>())
: attr.getSplatValue<float>();
mlir::Type elementType = attr.getElementType();
mlir::APFloat fillValue(mlir::APFloat::IEEEsingle());
if (isa<IntegerType>(elementType)) {
fillValue.convertFromAPInt(attr.getSplatValue<llvm::APInt>(),
attr.getElementType().isSignedInteger(),
llvm::RoundingMode::TowardZero);
return fillValue.convertToFloat();
}
if (isa<FloatType>(elementType)) {
return static_cast<float>(
attr.getSplatValue<mlir::APFloat>().convertToDouble());
}
assert(false && "Unsupported data type.");
}
return {};
return std::nullopt;
}

mlir::Value broadcastAttr(mlir::Value input, RankedTensorType desiredType,
mlir::stablehlo::ClampOp srcOp,
ConversionPatternRewriter &rewriter) const {
auto inputType = mlir::cast<RankedTensorType>(input.getType());
if (inputType.getShape() == desiredType.getShape()) {
return input;
}

SmallVector<int64_t> unsqueezeShape(desiredType.getRank(), 1);
for (int64_t i = 0; i < inputType.getRank(); i++) {
unsqueezeShape[i] = inputType.getDimSize(i);
}
SmallVector<int32_t> reshapeDim(unsqueezeShape.begin(),
unsqueezeShape.end());

auto reshapeDimAttr = rewriter.getI32ArrayAttr(reshapeDim);
ttir::ReshapeOp reshapeOp = ttmlir::utils::createDPSOp<ttir::ReshapeOp>(
rewriter, srcOp.getLoc(), unsqueezeShape, desiredType.getElementType(),
desiredType.getEncoding(), input, reshapeDimAttr);

::llvm::ArrayRef<int64_t> inputShape = unsqueezeShape;
::llvm::ArrayRef<int64_t> outputShape = desiredType.getShape();

SmallVector<int64_t> broadcastShape =
ttmlir::utils::getBroadcastDimensions<int64_t>(inputShape, outputShape);

return ttmlir::utils::createDPSOp<ttir::BroadcastOp>(
rewriter, srcOp->getLoc(), desiredType, reshapeOp, broadcastShape);
}
};
} // namespace
65 changes: 64 additions & 1 deletion test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
@@ -1,7 +1,8 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_transpose attributes {} {
module @jit_clamp attributes {} {
func.func public @test_clamp_constant(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func.func public @test_clamp_constant
%cst = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
%cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
@@ -12,7 +13,38 @@ module @jit_transpose attributes {} {
return %0 : tensor<4xf32>
}

func.func public @test_clamp_indirect_constant_reshape(%arg0: tensor<1x16xbf16>) -> tensor<1x16xbf16> {
// CHECK-LABEL: func.func public @test_clamp_indirect_constant_reshape
%cst = arith.constant dense<3.0> : tensor<1xf64>
%cst_0 = arith.constant dense<6> : tensor<1xi64>
%0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
%1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
%2 = stablehlo.convert %cst_0 : (tensor<1xi64>) -> tensor<1xbf16>
%3 = stablehlo.reshape %2 : (tensor<1xbf16>) -> tensor<bf16>
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : [[TENSOR:tensor<1x16xbf16>]]
// CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
// CHECK-SAME: max = 6.000000e+00 : f32, min = 3.000000e+00 : f32
// CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%4 = stablehlo.clamp %1, %arg0, %3 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>
return %4 : tensor<1x16xbf16>
}

func.func public @test_clamp_indirect_constant_broadcast(%arg0: tensor<1x32xbf16>) -> (tensor<1x32xbf16>) {
// CHECK-LABEL: func.func public @test_clamp_indirect_constant_broadcast
%cst = stablehlo.constant dense<2.000000e+00> : tensor<bf16>
%cst_0 = stablehlo.constant dense<5.000000e+00> : tensor<bf16>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
%1 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : [[TENSOR:tensor<1x32xbf16>]]
// CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
// CHECK-SAME: max = 5.000000e+00 : f32, min = 2.000000e+00 : f32
// CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%2 = stablehlo.clamp %0, %arg0, %1 : tensor<1x32xbf16>
return %2 : tensor<1x32xbf16>
}

func.func public @test_clamp_tensor(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func.func public @test_clamp_tensor
// CHECK: %[[EMPTY0:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
// CHECK: %[[MAX:.*]] = "ttir.maximum"(%arg1, %arg0, %[[EMPTY0]])
// CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
@@ -23,4 +55,35 @@ module @jit_transpose attributes {} {
// CHECK: return %[[MIN]] : [[TENSOR]]
return %0 : tensor<4xf32>
}

func.func public @test_clamp_tensor_constant(%arg0: tensor<1x16xbf16>, %arg1: tensor<bf16>) -> tensor<1x16xbf16> {
// CHECK-LABEL: func.func public @test_clamp_tensor_constant(
// CHECK: %[[CONSTANT:[0-9]+]] = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%cst = arith.constant dense<3.0> : tensor<1xf64>
// CHECK: %[[CAST:[0-9]+]] = "ttir.typecast"(%[[CONSTANT]],
// CHECK-SAME: (tensor<1xf32>, tensor<1xbf16>) -> tensor<1xbf16>
%0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
// CHECK: %[[RESHAPE0:[0-9]+]] = "ttir.reshape"(%[[CAST]],
// CHECK-SAME: shape = [1 : i32]
// CHECK-SAME: (tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
%1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
// CHECK: %[[RESHAPE1:[0-9]+]] = "ttir.reshape"(%[[RESHAPE0]],
// CHECK-SAME: shape = [1 : i32, 1 : i32]
// CHECK-SAME: (tensor<1xbf16>, tensor<1x1xbf16>) -> tensor<1x1xbf16>
// CHECK: %[[MIN:[0-9]+]] = "ttir.broadcast"(%[[RESHAPE1]]
// CHECK-SAME: {broadcast_dimensions = array<i64: 1, 16>}
// CHECK-SAME: (tensor<1x1xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
// CHECK: %[[RESHAPE2:[0-9]+]] = "ttir.reshape"(%arg1,
// CHECK-SAME: <{shape = [1 : i32, 1 : i32]}>
// CHECK-SAME: (tensor<1xbf16>, tensor<1x1xbf16>) -> tensor<1x1xbf16>
// CHECK: %[[MAX:[0-9]+]] = "ttir.broadcast"(%[[RESHAPE2]],
// CHECK-SAME: <{broadcast_dimensions = array<i64: 1, 16>}>
// CHECK-SAME: (tensor<1x1xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
// CHECK: %[[ARG:[0-9]+]] = "ttir.maximum"(%[[MIN]], %arg0,
// CHECK-SAME: (tensor<1x16xbf16>, tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
// CHECK: "ttir.minimum"(%[[ARG]], %[[MAX]],
// CHECK-SAME: (tensor<1x16xbf16>, tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
%2 = stablehlo.clamp %1, %arg0, %arg1 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>
return %2 : tensor<1x16xbf16>
}
}
55 changes: 51 additions & 4 deletions test/ttmlir/Silicon/StableHLO/n150/Unary/clamp_op.mlir
@@ -1,14 +1,14 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s \
// RUN: --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

module @jit_transpose attributes {} {
func.func public @test_clamp_constant(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK-LABEL: func.func public @test_clamp_constant
// CHECK-LABEL: func.func public @test_clamp_constant(
// CHECK: ttnn.clamp
// CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
// CHECK-SAME: tensor<64x128xf32,
@@ -19,8 +19,38 @@ module @jit_transpose attributes {} {
return %0 : tensor<64x128xf32>
}

func.func public @test_clamp_indirect_constant_reshape(%arg0: tensor<1x16xbf16>) -> tensor<1x16xbf16> {
// CHECK-LABEL: func.func public @test_clamp_indirect_constant_reshape
%cst = arith.constant dense<3.0> : tensor<1xf64>
%cst_0 = arith.constant dense<6> : tensor<1xi64>
%0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
%1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
%2 = stablehlo.convert %cst_0 : (tensor<1xi64>) -> tensor<1xbf16>
%3 = stablehlo.reshape %2 : (tensor<1xbf16>) -> tensor<bf16>
// CHECK: ttnn.clamp
// CHECK-SAME: {max = 6.000000e+00 : f32, min = 3.000000e+00 : f32}
// CHECK-SAME: tensor<1x16xbf16,
// CHECK-SAME: -> tensor<1x16xbf16,
%4 = stablehlo.clamp %1, %arg0, %3 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>
return %4 : tensor<1x16xbf16>
}

func.func public @test_clamp_indirect_constant_broadcast(%arg0: tensor<1x32xbf16>) -> (tensor<1x32xbf16>) {
// CHECK-LABEL: func.func public @test_clamp_indirect_constant_broadcast
%cst = stablehlo.constant dense<2.000000e+00> : tensor<bf16>
%cst_0 = stablehlo.constant dense<5.000000e+00> : tensor<bf16>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
%1 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
// CHECK: ttnn.clamp
// CHECK-SAME: {max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}
// CHECK-SAME: tensor<1x32xbf16,
// CHECK-SAME: -> tensor<1x32xbf16,
%2 = stablehlo.clamp %0, %arg0, %1 : tensor<1x32xbf16>
return %2 : tensor<1x32xbf16>
}

func.func public @test_clamp_tensor(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>, %arg2: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK-LABEL: func.func public @test_clamp_tensor
// CHECK-LABEL: func.func public @test_clamp_tensor(
// CHECK: %[[MAX:.*]] = "ttnn.maximum"
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xf32,
@@ -32,4 +62,21 @@ module @jit_transpose attributes {} {
%0 = stablehlo.clamp %arg1, %arg0, %arg2 : tensor<64x128xf32>
return %0 : tensor<64x128xf32>
}

func.func public @test_clamp_tensor_constant(%arg0: tensor<32x32xbf16>, %arg1: tensor<bf16>) -> tensor<32x32xbf16> {
// CHECK-LABEL: func.func public @test_clamp_tensor_constant(
%cst = arith.constant dense<3.0> : tensor<1xf64>
%0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
%1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
// CHECK: %[[MAX:.*]] = "ttnn.maximum"
// CHECK-SAME: tensor<32x32xbf16,
// CHECK-SAME: tensor<32x32xbf16,
// CHECK-SAME: -> tensor<32x32xbf16,
// CHECK: "ttnn.minimum"(%[[MAX]]
// CHECK-SAME: tensor<32x32xbf16,
// CHECK-SAME: tensor<32x32xbf16,
// CHECK-SAME: -> tensor<32x32xbf16,
%2 = stablehlo.clamp %1, %arg0, %arg1 : (tensor<bf16>, tensor<32x32xbf16>, tensor<bf16>) -> tensor<32x32xbf16>
return %2 : tensor<32x32xbf16>
}
}