From fd55dc0c89ea8360dfb6dcb237b33258a7d55192 Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Wed, 5 Mar 2025 17:30:38 +0000 Subject: [PATCH 1/5] TTNN verification --- lib/Dialect/TTNN/IR/TTNNOps.cpp | 52 +++++++++++++++------------------ 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 07d9ac3825..235843f818 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -5,7 +5,6 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/Types/Types.h" #include "ttmlir/Utils.h" @@ -1706,46 +1705,43 @@ static mlir::LogicalResult verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, const std::optional &reduceDims, bool keepDim, ::llvm::ArrayRef specifiedOutputShape) { - if (!reduceDims) { - return mlir::success(); - } int64_t inputTensorRank = inputType.getRank(); - // Calculate output shape for given args. - // - llvm::SmallVector calculatedOutputShape; - for (int64_t i = 0; i < inputType.getRank(); ++i) { - bool isDimInReduceDims = - llvm::any_of(*reduceDims, [i, inputTensorRank](mlir::Attribute attr) { - int64_t reduceDim = mlir::cast(attr).getInt(); - // Check for match even if negative dim is used. - // - return reduceDim == i || (reduceDim + inputTensorRank) == i; - }); - - // If dim is being reduced on, the dim will have size of 1 if keepDim==true, - // otherwise the dim is erased. - // - if (!isDimInReduceDims) { - calculatedOutputShape.push_back(inputType.getDimSize(i)); + llvm::BitVector reduceDimsMask(inputTensorRank, false); + if (reduceDims) { + for (mlir::Attribute attr : *reduceDims) { + int64_t reduceDim = mlir::cast(attr).getInt(); + // Normalize range to [0, inputTensorRank). + if (reduceDim < 0) { + reduceDim += inputTensorRank; + } + reduceDimsMask.set(reduceDim); + } + } else { + reduceDimsMask.set(); + } + + llvm::SmallVector expectedOutputShape; + for (int64_t index = 0; index < inputTensorRank; ++index) { + if (!reduceDimsMask[index]) { + expectedOutputShape.push_back(inputType.getDimSize(index)); } else if (keepDim) { - calculatedOutputShape.push_back(1); + expectedOutputShape.push_back(1); } } // Cover edge case where all dims are reduced, and keepDim==false. - if (calculatedOutputShape.size() == 0 && keepDim == false) { - calculatedOutputShape.push_back(1); + if (expectedOutputShape.empty() && !keepDim) { + expectedOutputShape.push_back(1); } // Finally, compare shapes. - // - if (!llvm::equal(specifiedOutputShape, calculatedOutputShape)) { + if (!llvm::equal(specifiedOutputShape, expectedOutputShape)) { return reduceOp->emitOpError( "Expected output shape (" + - ttmlir::utils::join(specifiedOutputShape, ", ") + "), got (" + - ttmlir::utils::join(calculatedOutputShape, ", ") + ")"); + ttmlir::utils::join(expectedOutputShape, ", ") + "), got (" + + ttmlir::utils::join(specifiedOutputShape, ", ") + ")"); } return mlir::success(); From ecc7acbd85f53a303814a1d4116e3f9d50299434 Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Wed, 5 Mar 2025 17:38:40 +0000 Subject: [PATCH 2/5] TTIR verification --- lib/Dialect/TTIR/IR/TTIROps.cpp | 78 +++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 1c268a9c1c..691e9875dd 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -2992,30 +2992,52 @@ static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block, // Common verifier for all Reduce ops. static mlir::LogicalResult verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, - const std::optional &reduceDims) { - if (!reduceDims) { - return mlir::success(); - } + const std::optional &reduceDims, bool keepDim, + ::llvm::ArrayRef specifiedOutputShape) { int64_t inputTensorRank = inputType.getRank(); - llvm::SmallSet uniqueReduceDims; - for (mlir::Attribute reduceDim : *reduceDims) { - int64_t reduceDimInt = mlir::cast(reduceDim).getInt(); - if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) { - return reduceOp->emitOpError("Reduce dimensions are out of range"); + llvm::BitVector reduceDimsMask(inputTensorRank, false); + if (reduceDims) { + llvm::SmallSet uniqueReduceDims; + for (mlir::Attribute reduceDim : *reduceDims) { + int64_t reduceDimInt = mlir::cast(reduceDim).getInt(); + if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) { + return reduceOp->emitOpError("Reduce dimensions are out of range"); + } + uniqueReduceDims.insert(reduceDimInt); + reduceDimsMask.set((reduceDimInt + inputTensorRank) % inputTensorRank); + } + + if (uniqueReduceDims.size() != reduceDims->size()) { + return reduceOp->emitOpError("Reduce dimensions are not unique"); } - uniqueReduceDims.insert(reduceDimInt); + } else { + reduceDimsMask.set(); } - if (uniqueReduceDims.size() != reduceDims->size()) { - return reduceOp->emitOpError("Reduce dimensions are not unique"); + // Check that the output shape is valid. + llvm::SmallVector expectedOutputShape; + for (int64_t index = 0; index < inputTensorRank; ++index) { + if (!reduceDimsMask[index]) { + expectedOutputShape.push_back(inputType.getDimSize(index)); + } else if (keepDim) { + expectedOutputShape.push_back(1); + } } - // TODO(mrakita): Add a check that depending on inputShape, reduceDims and - // keepDim computes the expected output shape and checks if it matches the - // actual output shape. Tracked by: - // https://github.com/tenstorrent/tt-mlir/issues/1639 + // Cover edge case where all dims are reduced, and keepDim==false. + if (expectedOutputShape.empty() && !keepDim) { + expectedOutputShape.push_back(1); + } + + // Finally, compare shapes. + if (!llvm::equal(specifiedOutputShape, expectedOutputShape)) { + return reduceOp->emitOpError( + "Expected output shape (" + + ttmlir::utils::join(expectedOutputShape, ", ") + "), got (" + + ttmlir::utils::join(specifiedOutputShape, ", ") + ")"); + } return mlir::success(); } @@ -3033,7 +3055,8 @@ void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MaxOp verification. ::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3049,7 +3072,8 @@ void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MeanOp verification. ::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3065,7 +3089,8 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // SumOp verification. ::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3081,7 +3106,8 @@ void mlir::tt::ttir::MinOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MinOp verification. ::mlir::LogicalResult mlir::tt::ttir::MinOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3097,7 +3123,8 @@ void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // ProdOp verification. ::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3113,7 +3140,8 @@ void mlir::tt::ttir::ReduceAndOp::buildGenericRegion( // ReduceAndOp verification. ::mlir::LogicalResult mlir::tt::ttir::ReduceAndOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3129,7 +3157,8 @@ void mlir::tt::ttir::ReduceOrOp::buildGenericRegion( // ReduceOrOp verification. ::mlir::LogicalResult mlir::tt::ttir::ReduceOrOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3152,7 +3181,8 @@ ::mlir::LogicalResult mlir::tt::ttir::ArgMaxOp::verify() { << dimArg->size() << "."; } - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// From f118a37033ae1beb5b76b9abfcc67bb7f6b18a08 Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Wed, 5 Mar 2025 18:17:22 +0000 Subject: [PATCH 3/5] Negative tests --- .../reduction/negative_shape_mismatch.mlir | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 test/ttmlir/Dialect/TTIR/reduction/negative_shape_mismatch.mlir diff --git a/test/ttmlir/Dialect/TTIR/reduction/negative_shape_mismatch.mlir b/test/ttmlir/Dialect/TTIR/reduction/negative_shape_mismatch.mlir new file mode 100644 index 0000000000..a7fd19d5ce --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/reduction/negative_shape_mismatch.mlir @@ -0,0 +1,41 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for reduce ops expected shape mismatch. + +module { + // CHECK: error: 'ttir.sum' op Expected output shape (128, 16), got (128, 16, 1) + func.func public @shape_mismatch_0(%arg0: tensor<128x16x32xf32>) -> tensor<128x16x1xf32> { + %0 = tensor.empty() : tensor<128x16x1xf32> + %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<128x16x32xf32>, tensor<128x16x1xf32>) -> tensor<128x16x1xf32> + return %1 : tensor<128x16x1xf32> + } +} + +// ----- +module { + // CHECK: error: 'ttir.sum' op Expected output shape (128, 16, 1), got (128, 16) + func.func public @shape_mismatch_1(%arg0: tensor<128x16x32xf32>) -> tensor<128x16xf32> { + %0 = tensor.empty() : tensor<128x16xf32> + %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = true}> : (tensor<128x16x32xf32>, tensor<128x16xf32>) -> tensor<128x16xf32> + return %1 : tensor<128x16xf32> + } +} + +// ----- +module { + // CHECK: error: 'ttir.sum' op Expected output shape (1, 1, 1), got (128, 16, 1) + func.func public @shape_mismatch_2(%arg0: tensor<128x16x32xf32>) -> tensor<128x16x1xf32> { + %0 = tensor.empty() : tensor<128x16x1xf32> + %1 = "ttir.sum"(%arg0, %0) <{keep_dim = true}> : (tensor<128x16x32xf32>, tensor<128x16x1xf32>) -> tensor<128x16x1xf32> + return %1 : tensor<128x16x1xf32> + } +} + +// ----- +module { + // CHECK: error: 'ttir.sum' op Expected output shape (1), got (128, 16, 1) + func.func public @shape_mismatch_3(%arg0: tensor<128x16x32xf32>) -> tensor<128x16x1xf32> { + %0 = tensor.empty() : tensor<128x16x1xf32> + %1 = "ttir.sum"(%arg0, %0) <{keep_dim = false}> : (tensor<128x16x32xf32>, tensor<128x16x1xf32>) -> tensor<128x16x1xf32> + return %1 : tensor<128x16x1xf32> + } +} From 1e9e8779255a288ffa19f361d0260e098eaa422c Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Thu, 6 Mar 2025 16:53:32 +0000 Subject: [PATCH 4/5] address comment --- lib/Dialect/TTIR/IR/TTIROps.cpp | 38 ++++++++++--------- .../TTMetal/simple_attach_metal_layout.mlir | 16 ++++---- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 691e9875dd..bc85d6cbba 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -2991,29 +2991,30 @@ static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block, // Common verifier for all Reduce ops. static mlir::LogicalResult -verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, +verifyReduceOp(mlir::InFlightDiagnostic &&emitOpError, + mlir::RankedTensorType inputType, const std::optional &reduceDims, bool keepDim, ::llvm::ArrayRef specifiedOutputShape) { int64_t inputTensorRank = inputType.getRank(); llvm::BitVector reduceDimsMask(inputTensorRank, false); - if (reduceDims) { + if (!reduceDims) { + reduceDimsMask.set(); + } else { llvm::SmallSet uniqueReduceDims; for (mlir::Attribute reduceDim : *reduceDims) { int64_t reduceDimInt = mlir::cast(reduceDim).getInt(); if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) { - return reduceOp->emitOpError("Reduce dimensions are out of range"); + return emitOpError << "Reduce dimensions are out of range"; } uniqueReduceDims.insert(reduceDimInt); reduceDimsMask.set((reduceDimInt + inputTensorRank) % inputTensorRank); } if (uniqueReduceDims.size() != reduceDims->size()) { - return reduceOp->emitOpError("Reduce dimensions are not unique"); + return emitOpError << "Reduce dimensions are not unique"; } - } else { - reduceDimsMask.set(); } // Check that the output shape is valid. @@ -3033,10 +3034,11 @@ verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, // Finally, compare shapes. if (!llvm::equal(specifiedOutputShape, expectedOutputShape)) { - return reduceOp->emitOpError( - "Expected output shape (" + - ttmlir::utils::join(expectedOutputShape, ", ") + "), got (" + - ttmlir::utils::join(specifiedOutputShape, ", ") + ")"); + return emitOpError << "Expected output shape (" + << ttmlir::utils::join(expectedOutputShape, ", ") + << "), got (" + << ttmlir::utils::join(specifiedOutputShape, ", ") + << ")"; } return mlir::success(); @@ -3055,7 +3057,7 @@ void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MaxOp verification. ::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), getKeepDim(), getType().getShape()); } @@ -3072,7 +3074,7 @@ void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MeanOp verification. ::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), getKeepDim(), getType().getShape()); } @@ -3089,7 +3091,7 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // SumOp verification. ::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), getKeepDim(), getType().getShape()); } @@ -3106,7 +3108,7 @@ void mlir::tt::ttir::MinOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MinOp verification. ::mlir::LogicalResult mlir::tt::ttir::MinOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), getKeepDim(), getType().getShape()); } @@ -3123,7 +3125,7 @@ void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // ProdOp verification. ::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), getKeepDim(), getType().getShape()); } @@ -3140,7 +3142,7 @@ void mlir::tt::ttir::ReduceAndOp::buildGenericRegion( // ReduceAndOp verification. ::mlir::LogicalResult mlir::tt::ttir::ReduceAndOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), getKeepDim(), getType().getShape()); } @@ -3157,7 +3159,7 @@ void mlir::tt::ttir::ReduceOrOp::buildGenericRegion( // ReduceOrOp verification. ::mlir::LogicalResult mlir::tt::ttir::ReduceOrOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), getKeepDim(), getType().getShape()); } @@ -3181,7 +3183,7 @@ ::mlir::LogicalResult mlir::tt::ttir::ArgMaxOp::verify() { << dimArg->size() << "."; } - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), + return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), getKeepDim(), getType().getShape()); } diff --git a/test/ttmlir/Dialect/TTMetal/simple_attach_metal_layout.mlir b/test/ttmlir/Dialect/TTMetal/simple_attach_metal_layout.mlir index 7d0ff83ddd..7dd901e586 100644 --- a/test/ttmlir/Dialect/TTMetal/simple_attach_metal_layout.mlir +++ b/test/ttmlir/Dialect/TTMetal/simple_attach_metal_layout.mlir @@ -17,15 +17,15 @@ func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tens #layout2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>> // CHECK-LABEL: func.func @reduceW( // CHECK-SAME: %arg0: tensor<256x384xf32, #[[LAYOUT1:layout1]]> -// CHECK-SAME: ) -> tensor<256x32xf32, #[[LAYOUT2:layout2]]> -func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, #layout2> { - // CHECK: %[[C0:.*]] = tensor.empty() : tensor<256x32xf32, #[[LAYOUT2]]> - // CHECK: %[[C1:.*]] = "ttir.sum"(%arg0, %0) <{dim_arg = [-1 : i32], keep_dim = true}> : (tensor<256x384xf32, #[[LAYOUT1]]>, tensor<256x32xf32, #[[LAYOUT2]]>) -> tensor<256x32xf32, #[[LAYOUT2]]> - %0 = tensor.empty() : tensor<256x32xf32, #layout2> +// CHECK-SAME: ) -> tensor<256x1xf32, #[[LAYOUT2:layout2]]> +func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x1xf32, #layout2> { + // CHECK: %[[C0:.*]] = tensor.empty() : tensor<256x1xf32, #[[LAYOUT2]]> + // CHECK: %[[C1:.*]] = "ttir.sum"(%arg0, %0) <{dim_arg = [-1 : i32], keep_dim = true}> : (tensor<256x384xf32, #[[LAYOUT1]]>, tensor<256x1xf32, #[[LAYOUT2]]>) -> tensor<256x1xf32, #[[LAYOUT2]]> + %0 = tensor.empty() : tensor<256x1xf32, #layout2> %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, dim_arg = [-1: i32], keep_dim = true}> : - (tensor<256x384xf32, #layout1>, tensor<256x32xf32, #layout2>) -> tensor<256x32xf32, #layout2> - // CHECK: return %[[C1]] : tensor<256x32xf32, #[[LAYOUT2]]> - return %1 : tensor<256x32xf32, #layout2> + (tensor<256x384xf32, #layout1>, tensor<256x1xf32, #layout2>) -> tensor<256x1xf32, #layout2> + // CHECK: return %[[C1]] : tensor<256x1xf32, #[[LAYOUT2]]> + return %1 : tensor<256x1xf32, #layout2> } From 43cad86cbf52bd240b8ef9148f7ce36c904d4343 Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Thu, 6 Mar 2025 17:26:11 +0000 Subject: [PATCH 5/5] reduce dim message + InFlightDiagnostic fix --- lib/Dialect/TTIR/IR/TTIROps.cpp | 50 ++++++++++--------- lib/Dialect/TTNN/IR/TTNNOps.cpp | 32 +++++++----- .../reduction/negative_invalid_dim_high.mlir | 2 +- .../reduction/negative_invalid_dim_low.mlir | 2 +- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index bc85d6cbba..c1f38c1d94 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -2991,7 +2991,7 @@ static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block, // Common verifier for all Reduce ops. static mlir::LogicalResult -verifyReduceOp(mlir::InFlightDiagnostic &&emitOpError, +verifyReduceOp(llvm::function_ref emitOpError, mlir::RankedTensorType inputType, const std::optional &reduceDims, bool keepDim, ::llvm::ArrayRef specifiedOutputShape) { @@ -3006,14 +3006,16 @@ verifyReduceOp(mlir::InFlightDiagnostic &&emitOpError, for (mlir::Attribute reduceDim : *reduceDims) { int64_t reduceDimInt = mlir::cast(reduceDim).getInt(); if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) { - return emitOpError << "Reduce dimensions are out of range"; + return emitOpError() << "Reduce dimension " << reduceDimInt + << " is out of range for input tensor of rank " + << inputTensorRank; } uniqueReduceDims.insert(reduceDimInt); reduceDimsMask.set((reduceDimInt + inputTensorRank) % inputTensorRank); } if (uniqueReduceDims.size() != reduceDims->size()) { - return emitOpError << "Reduce dimensions are not unique"; + return emitOpError() << "Reduce dimensions are not unique"; } } @@ -3034,11 +3036,11 @@ verifyReduceOp(mlir::InFlightDiagnostic &&emitOpError, // Finally, compare shapes. if (!llvm::equal(specifiedOutputShape, expectedOutputShape)) { - return emitOpError << "Expected output shape (" - << ttmlir::utils::join(expectedOutputShape, ", ") - << "), got (" - << ttmlir::utils::join(specifiedOutputShape, ", ") - << ")"; + return emitOpError() << "Expected output shape (" + << ttmlir::utils::join(expectedOutputShape, ", ") + << "), got (" + << ttmlir::utils::join(specifiedOutputShape, ", ") + << ")"; } return mlir::success(); @@ -3057,8 +3059,8 @@ void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MaxOp verification. ::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() { - return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), - getKeepDim(), getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3074,8 +3076,8 @@ void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MeanOp verification. ::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() { - return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), - getKeepDim(), getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3091,8 +3093,8 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // SumOp verification. ::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() { - return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), - getKeepDim(), getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3108,8 +3110,8 @@ void mlir::tt::ttir::MinOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MinOp verification. ::mlir::LogicalResult mlir::tt::ttir::MinOp::verify() { - return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), - getKeepDim(), getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3125,8 +3127,8 @@ void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // ProdOp verification. ::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() { - return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), - getKeepDim(), getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3142,8 +3144,8 @@ void mlir::tt::ttir::ReduceAndOp::buildGenericRegion( // ReduceAndOp verification. ::mlir::LogicalResult mlir::tt::ttir::ReduceAndOp::verify() { - return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), - getKeepDim(), getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3159,8 +3161,8 @@ void mlir::tt::ttir::ReduceOrOp::buildGenericRegion( // ReduceOrOp verification. ::mlir::LogicalResult mlir::tt::ttir::ReduceOrOp::verify() { - return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), - getKeepDim(), getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// @@ -3183,8 +3185,8 @@ ::mlir::LogicalResult mlir::tt::ttir::ArgMaxOp::verify() { << dimArg->size() << "."; } - return verifyReduceOp(emitOpError(), getInput().getType(), getDimArg(), - getKeepDim(), getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), getType().getShape()); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 235843f818..476530265c 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -1702,7 +1702,8 @@ ::mlir::LogicalResult UpsampleOp::verify() { // Common verifier for all Reduction ops. static mlir::LogicalResult -verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, +verifyReduceOp(llvm::function_ref emitOpError, + mlir::RankedTensorType inputType, const std::optional &reduceDims, bool keepDim, ::llvm::ArrayRef specifiedOutputShape) { @@ -1738,10 +1739,11 @@ verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, // Finally, compare shapes. if (!llvm::equal(specifiedOutputShape, expectedOutputShape)) { - return reduceOp->emitOpError( - "Expected output shape (" + - ttmlir::utils::join(expectedOutputShape, ", ") + "), got (" + - ttmlir::utils::join(specifiedOutputShape, ", ") + ")"); + return emitOpError() << "Expected output shape (" + << ttmlir::utils::join(expectedOutputShape, ", ") + << "), got (" + << ttmlir::utils::join(specifiedOutputShape, ", ") + << ")"; } return mlir::success(); @@ -1775,8 +1777,9 @@ static mlir::LogicalResult verifyReduceProdOp(mlir::Operation *reduceOp, // MaxOp verification. ::mlir::LogicalResult MaxOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), - getKeepDim(), getResult().getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), + getResult().getType().getShape()); } //===----------------------------------------------------------------------===// @@ -1785,8 +1788,9 @@ ::mlir::LogicalResult MaxOp::verify() { // MeanOp verification. ::mlir::LogicalResult MeanOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), - getKeepDim(), getResult().getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), + getResult().getType().getShape()); } //===----------------------------------------------------------------------===// @@ -1795,8 +1799,9 @@ ::mlir::LogicalResult MeanOp::verify() { // SumOp verification. ::mlir::LogicalResult SumOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), - getKeepDim(), getResult().getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), + getResult().getType().getShape()); } //===----------------------------------------------------------------------===// @@ -1805,8 +1810,9 @@ ::mlir::LogicalResult SumOp::verify() { // MinOp verification. ::mlir::LogicalResult MinOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(), - getKeepDim(), getResult().getType().getShape()); + return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(), + getDimArg(), getKeepDim(), + getResult().getType().getShape()); } //===----------------------------------------------------------------------===// diff --git a/test/ttmlir/Dialect/TTIR/reduction/negative_invalid_dim_high.mlir b/test/ttmlir/Dialect/TTIR/reduction/negative_invalid_dim_high.mlir index 565745d057..77f10fb099 100644 --- a/test/ttmlir/Dialect/TTIR/reduction/negative_invalid_dim_high.mlir +++ b/test/ttmlir/Dialect/TTIR/reduction/negative_invalid_dim_high.mlir @@ -1,7 +1,7 @@ // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s // Negative tests for reduce ops -// CHECK: error: 'ttir.sum' op Reduce dimensions are out of range +// CHECK: error: 'ttir.sum' op Reduce dimension 2 is out of range for input tensor of rank 2 func.func public @test_reduce_add_invalid_dim_high(%arg0: tensor<128x10xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> { %0 = tensor.empty() : tensor<128xf32> %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32> diff --git a/test/ttmlir/Dialect/TTIR/reduction/negative_invalid_dim_low.mlir b/test/ttmlir/Dialect/TTIR/reduction/negative_invalid_dim_low.mlir index bd4a237d46..a6a19ce9cb 100644 --- a/test/ttmlir/Dialect/TTIR/reduction/negative_invalid_dim_low.mlir +++ b/test/ttmlir/Dialect/TTIR/reduction/negative_invalid_dim_low.mlir @@ -1,7 +1,7 @@ // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s // Negative tests for reduce ops -// CHECK: error: 'ttir.sum' op Reduce dimensions are out of range +// CHECK: error: 'ttir.sum' op Reduce dimension -3 is out of range for input tensor of rank 2 func.func public @test_reduce_add_invalid_dim_low(%arg0: tensor<128x10xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> { %0 = tensor.empty() : tensor<128xf32> %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-3 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>