Output shape verification for reduce ops (#2374)
### Ticket
Closes #2369

### Problem description
Some erroneous output shapes would pass verification and only fail at runtime, due to a wrong interpretation of `dim_arg` when its value is `nullopt` (which means "reduce over all dimensions"). For example, reducing a `128x16x32` tensor with `dim_arg` omitted and `keep_dim = true` must produce shape `1x1x1`, yet an output typed `128x16x1` previously passed verification (see the new `shape_mismatch_2` test below). There is an example in the linked issue.

### What's changed
- Added verification of the case where `dim_arg` is `nullopt` in TTNN
- Added output shape verification in TTIR (previously a TODO); the shape rule both verifiers enforce is sketched right after this list
- Added negative tests
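
The rule both verifiers now enforce can be sketched as a standalone function. This is a minimal illustration of the shape computation only, with made-up names (`expectedReduceShape`), not the actual tt-mlir code:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Expected output shape of a reduce op: a reduced dimension is kept with
// size 1 when keepDim is true and dropped otherwise. A nullopt dims
// argument means "reduce over every dimension".
std::vector<int64_t>
expectedReduceShape(const std::vector<int64_t> &inputShape,
                    const std::optional<std::vector<int64_t>> &dims,
                    bool keepDim) {
  int64_t rank = static_cast<int64_t>(inputShape.size());
  std::vector<bool> reduced(rank, !dims.has_value());
  if (dims) {
    for (int64_t d : *dims) {
      reduced[(d + rank) % rank] = true; // normalize negative dims
    }
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    if (!reduced[i]) {
      out.push_back(inputShape[i]);
    } else if (keepDim) {
      out.push_back(1);
    }
  }
  // Edge case: all dims reduced with keepDim == false yields shape (1).
  if (out.empty()) {
    out.push_back(1);
  }
  return out;
}

int main() {
  // dim_arg = nullopt, keep_dim = true on a 128x16x32 tensor -> (1, 1, 1),
  // so an op whose result is typed 128x16x1 must be rejected.
  for (int64_t d : expectedReduceShape({128, 16, 32}, std::nullopt, true)) {
    std::cout << d << ' '; // prints: 1 1 1
  }
  std::cout << '\n';
}
```

The TTIR and TTNN verifiers in the diffs below implement this same rule with `llvm::BitVector` and `llvm::SmallVector`, then compare the result against the op's declared output shape.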

### Checklist
- [x] New/Existing tests provide coverage for changes
azecevicTT authored and odjuricicTT committed Mar 8, 2025
1 parent c7655a5 commit 7598a95
Showing 6 changed files with 151 additions and 74 deletions.
84 changes: 59 additions & 25 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -2983,31 +2983,57 @@ static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block,

 // Common verifier for all Reduce ops.
 static mlir::LogicalResult
-verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
-               const std::optional<mlir::ArrayAttr> &reduceDims) {
-  if (!reduceDims) {
-    return mlir::success();
-  }
+verifyReduceOp(llvm::function_ref<mlir::InFlightDiagnostic()> emitOpError,
+               mlir::RankedTensorType inputType,
+               const std::optional<mlir::ArrayAttr> &reduceDims, bool keepDim,
+               ::llvm::ArrayRef<int64_t> specifiedOutputShape) {

   int64_t inputTensorRank = inputType.getRank();

-  llvm::SmallSet<int64_t, 4> uniqueReduceDims;
-  for (mlir::Attribute reduceDim : *reduceDims) {
-    int64_t reduceDimInt = mlir::cast<mlir::IntegerAttr>(reduceDim).getInt();
-    if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) {
-      return reduceOp->emitOpError("Reduce dimensions are out of range");
+  llvm::BitVector reduceDimsMask(inputTensorRank, false);
+  if (!reduceDims) {
+    reduceDimsMask.set();
+  } else {
+    llvm::SmallSet<int64_t, 4> uniqueReduceDims;
+    for (mlir::Attribute reduceDim : *reduceDims) {
+      int64_t reduceDimInt = mlir::cast<mlir::IntegerAttr>(reduceDim).getInt();
+      if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) {
+        return emitOpError() << "Reduce dimension " << reduceDimInt
+                             << " is out of range for input tensor of rank "
+                             << inputTensorRank;
+      }
+      uniqueReduceDims.insert(reduceDimInt);
+      reduceDimsMask.set((reduceDimInt + inputTensorRank) % inputTensorRank);
     }
-    uniqueReduceDims.insert(reduceDimInt);
-  }
-
-  if (uniqueReduceDims.size() != reduceDims->size()) {
-    return reduceOp->emitOpError("Reduce dimensions are not unique");
+
+    if (uniqueReduceDims.size() != reduceDims->size()) {
+      return emitOpError() << "Reduce dimensions are not unique";
+    }
   }

-  // TODO(mrakita): Add a check that depending on inputShape, reduceDims and
-  // keepDim computes the expected output shape and checks if it matches the
-  // actual output shape. Tracked by:
-  // https://github.com/tenstorrent/tt-mlir/issues/1639
+  // Check that the output shape is valid.
+  llvm::SmallVector<int64_t> expectedOutputShape;
+  for (int64_t index = 0; index < inputTensorRank; ++index) {
+    if (!reduceDimsMask[index]) {
+      expectedOutputShape.push_back(inputType.getDimSize(index));
+    } else if (keepDim) {
+      expectedOutputShape.push_back(1);
+    }
+  }
+
+  // Cover edge case where all dims are reduced, and keepDim==false.
+  if (expectedOutputShape.empty() && !keepDim) {
+    expectedOutputShape.push_back(1);
+  }
+
+  // Finally, compare shapes.
+  if (!llvm::equal(specifiedOutputShape, expectedOutputShape)) {
+    return emitOpError() << "Expected output shape ("
+                         << ttmlir::utils::join(expectedOutputShape, ", ")
+                         << "), got ("
+                         << ttmlir::utils::join(specifiedOutputShape, ", ")
+                         << ")";
+  }

   return mlir::success();
 }
@@ -3025,7 +3051,8 @@ void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

 // MaxOp verification.
 ::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(), getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -3041,7 +3068,8 @@ void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

 // MeanOp verification.
 ::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(), getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -3057,7 +3085,8 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

 // SumOp verification.
 ::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(), getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -3073,7 +3102,8 @@ void mlir::tt::ttir::MinOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

 // MinOp verification.
 ::mlir::LogicalResult mlir::tt::ttir::MinOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(), getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -3089,7 +3119,8 @@ void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

 // ProdOp verification.
 ::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(), getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -3105,7 +3136,8 @@ void mlir::tt::ttir::ReduceAndOp::buildGenericRegion(

 // ReduceAndOp verification.
 ::mlir::LogicalResult mlir::tt::ttir::ReduceAndOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(), getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -3121,7 +3153,8 @@ void mlir::tt::ttir::ReduceOrOp::buildGenericRegion(

 // ReduceOrOp verification.
 ::mlir::LogicalResult mlir::tt::ttir::ReduceOrOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(), getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -3144,7 +3177,8 @@ ::mlir::LogicalResult mlir::tt::ttir::ArgMaxOp::verify() {
                          << dimArg->size() << ".";
   }

-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(), getType().getShape());
 }

 //===----------------------------------------------------------------------===//
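Note the signature change in both dialects: the shared verifier now receives the diagnostic emitter as a callback (`llvm::function_ref<mlir::InFlightDiagnostic()>`) instead of an `mlir::Operation *`, so each op's `verify()` simply forwards its own `emitOpError()`. Below is a minimal sketch of that callback pattern with stand-in types rather than MLIR's (`Diagnostic` and `verifyDim` are hypothetical):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <sstream>

// Stand-in for mlir::InFlightDiagnostic: something you can stream into.
struct Diagnostic {
  std::ostringstream message;
  template <typename T> Diagnostic &operator<<(const T &value) {
    message << value;
    return *this;
  }
};

// The shared verifier never needs the op itself, only a way to start an
// error diagnostic; the callback defers that choice to the caller.
bool verifyDim(const std::function<Diagnostic &()> &emitOpError, int64_t dim,
               int64_t rank) {
  if (dim < -rank || dim >= rank) {
    emitOpError() << "Reduce dimension " << dim
                  << " is out of range for input tensor of rank " << rank;
    return false;
  }
  return true;
}

int main() {
  Diagnostic diag;
  // Each op's verify() passes something like [&] { return emitOpError(); }.
  if (!verifyDim([&]() -> Diagnostic & { return diag; }, 2, 2)) {
    std::cout << diag.message.str() << '\n';
  }
}
```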
80 changes: 41 additions & 39 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -5,7 +5,6 @@
 #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
-#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
 #include "ttmlir/Dialect/TTNN/Types/Types.h"
 #include "ttmlir/Utils.h"
@@ -1703,49 +1702,48 @@ ::mlir::LogicalResult UpsampleOp::verify() {

 // Common verifier for all Reduction ops.
 static mlir::LogicalResult
-verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
+verifyReduceOp(llvm::function_ref<mlir::InFlightDiagnostic()> emitOpError,
+               mlir::RankedTensorType inputType,
                const std::optional<mlir::ArrayAttr> &reduceDims, bool keepDim,
                ::llvm::ArrayRef<int64_t> specifiedOutputShape) {
-  if (!reduceDims) {
-    return mlir::success();
-  }
-
   int64_t inputTensorRank = inputType.getRank();

-  // Calculate output shape for given args.
-  //
-  llvm::SmallVector<int64_t> calculatedOutputShape;
-  for (int64_t i = 0; i < inputType.getRank(); ++i) {
-    bool isDimInReduceDims =
-        llvm::any_of(*reduceDims, [i, inputTensorRank](mlir::Attribute attr) {
-          int64_t reduceDim = mlir::cast<mlir::IntegerAttr>(attr).getInt();
-          // Check for match even if negative dim is used.
-          //
-          return reduceDim == i || (reduceDim + inputTensorRank) == i;
-        });
-
-    // If dim is being reduced on, the dim will have size of 1 if keepDim==true,
-    // otherwise the dim is erased.
-    //
-    if (!isDimInReduceDims) {
-      calculatedOutputShape.push_back(inputType.getDimSize(i));
+  llvm::BitVector reduceDimsMask(inputTensorRank, false);
+  if (reduceDims) {
+    for (mlir::Attribute attr : *reduceDims) {
+      int64_t reduceDim = mlir::cast<mlir::IntegerAttr>(attr).getInt();
+      // Normalize range to [0, inputTensorRank).
+      if (reduceDim < 0) {
+        reduceDim += inputTensorRank;
+      }
+      reduceDimsMask.set(reduceDim);
+    }
+  } else {
+    reduceDimsMask.set();
+  }
+
+  llvm::SmallVector<int64_t> expectedOutputShape;
+  for (int64_t index = 0; index < inputTensorRank; ++index) {
+    if (!reduceDimsMask[index]) {
+      expectedOutputShape.push_back(inputType.getDimSize(index));
     } else if (keepDim) {
-      calculatedOutputShape.push_back(1);
+      expectedOutputShape.push_back(1);
     }
   }

   // Cover edge case where all dims are reduced, and keepDim==false.
-  if (calculatedOutputShape.size() == 0 && keepDim == false) {
-    calculatedOutputShape.push_back(1);
+  if (expectedOutputShape.empty() && !keepDim) {
+    expectedOutputShape.push_back(1);
   }

   // Finally, compare shapes.
-  //
-  if (!llvm::equal(specifiedOutputShape, calculatedOutputShape)) {
-    return reduceOp->emitOpError(
-        "Expected output shape (" +
-        ttmlir::utils::join(specifiedOutputShape, ", ") + "), got (" +
-        ttmlir::utils::join(calculatedOutputShape, ", ") + ")");
+  if (!llvm::equal(specifiedOutputShape, expectedOutputShape)) {
+    return emitOpError() << "Expected output shape ("
+                         << ttmlir::utils::join(expectedOutputShape, ", ")
+                         << "), got ("
+                         << ttmlir::utils::join(specifiedOutputShape, ", ")
+                         << ")";
   }

   return mlir::success();
@@ -1779,8 +1777,9 @@ static mlir::LogicalResult verifyReduceProdOp(mlir::Operation *reduceOp,

 // MaxOp verification.
 ::mlir::LogicalResult MaxOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(),
-                        getKeepDim(), getResult().getType().getShape());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(),
+                        getResult().getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -1789,8 +1788,9 @@ ::mlir::LogicalResult MaxOp::verify() {

 // MeanOp verification.
 ::mlir::LogicalResult MeanOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(),
-                        getKeepDim(), getResult().getType().getShape());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(),
+                        getResult().getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -1799,8 +1799,9 @@ ::mlir::LogicalResult MeanOp::verify() {

 // SumOp verification.
 ::mlir::LogicalResult SumOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(),
-                        getKeepDim(), getResult().getType().getShape());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(),
+                        getResult().getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -1809,8 +1810,9 @@ ::mlir::LogicalResult SumOp::verify() {

 // MinOp verification.
 ::mlir::LogicalResult MinOp::verify() {
-  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg(),
-                        getKeepDim(), getResult().getType().getShape());
+  return verifyReduceOp([&]() { return emitOpError(); }, getInput().getType(),
+                        getDimArg(), getKeepDim(),
+                        getResult().getType().getShape());
 }

 //===----------------------------------------------------------------------===//
@@ -1,7 +1,7 @@
 // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
 // Negative tests for reduce ops

-// CHECK: error: 'ttir.sum' op Reduce dimensions are out of range
+// CHECK: error: 'ttir.sum' op Reduce dimension 2 is out of range for input tensor of rank 2
 func.func public @test_reduce_add_invalid_dim_high(%arg0: tensor<128x10xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> {
   %0 = tensor.empty() : tensor<128xf32>
   %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
@@ -1,7 +1,7 @@
 // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
 // Negative tests for reduce ops

-// CHECK: error: 'ttir.sum' op Reduce dimensions are out of range
+// CHECK: error: 'ttir.sum' op Reduce dimension -3 is out of range for input tensor of rank 2
 func.func public @test_reduce_add_invalid_dim_low(%arg0: tensor<128x10xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> {
   %0 = tensor.empty() : tensor<128xf32>
   %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-3 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
41 changes: 41 additions & 0 deletions test/ttmlir/Dialect/TTIR/reduction/negative_shape_mismatch.mlir
@@ -0,0 +1,41 @@
+// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
+// Negative tests for reduce ops expected shape mismatch.
+
+module {
+  // CHECK: error: 'ttir.sum' op Expected output shape (128, 16), got (128, 16, 1)
+  func.func public @shape_mismatch_0(%arg0: tensor<128x16x32xf32>) -> tensor<128x16x1xf32> {
+    %0 = tensor.empty() : tensor<128x16x1xf32>
+    %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<128x16x32xf32>, tensor<128x16x1xf32>) -> tensor<128x16x1xf32>
+    return %1 : tensor<128x16x1xf32>
+  }
+}
+
+// -----
+module {
+  // CHECK: error: 'ttir.sum' op Expected output shape (128, 16, 1), got (128, 16)
+  func.func public @shape_mismatch_1(%arg0: tensor<128x16x32xf32>) -> tensor<128x16xf32> {
+    %0 = tensor.empty() : tensor<128x16xf32>
+    %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = true}> : (tensor<128x16x32xf32>, tensor<128x16xf32>) -> tensor<128x16xf32>
+    return %1 : tensor<128x16xf32>
+  }
+}
+
+// -----
+module {
+  // CHECK: error: 'ttir.sum' op Expected output shape (1, 1, 1), got (128, 16, 1)
+  func.func public @shape_mismatch_2(%arg0: tensor<128x16x32xf32>) -> tensor<128x16x1xf32> {
+    %0 = tensor.empty() : tensor<128x16x1xf32>
+    %1 = "ttir.sum"(%arg0, %0) <{keep_dim = true}> : (tensor<128x16x32xf32>, tensor<128x16x1xf32>) -> tensor<128x16x1xf32>
+    return %1 : tensor<128x16x1xf32>
+  }
+}
+
+// -----
+module {
+  // CHECK: error: 'ttir.sum' op Expected output shape (1), got (128, 16, 1)
+  func.func public @shape_mismatch_3(%arg0: tensor<128x16x32xf32>) -> tensor<128x16x1xf32> {
+    %0 = tensor.empty() : tensor<128x16x1xf32>
+    %1 = "ttir.sum"(%arg0, %0) <{keep_dim = false}> : (tensor<128x16x32xf32>, tensor<128x16x1xf32>) -> tensor<128x16x1xf32>
+    return %1 : tensor<128x16x1xf32>
+  }
+}
16 changes: 8 additions & 8 deletions test/ttmlir/Dialect/TTMetal/simple_attach_metal_layout.mlir
@@ -17,15 +17,15 @@ func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tens
 #layout2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>>
 // CHECK-LABEL: func.func @reduceW(
 // CHECK-SAME: %arg0: tensor<256x384xf32, #[[LAYOUT1:layout1]]>
-// CHECK-SAME: ) -> tensor<256x32xf32, #[[LAYOUT2:layout2]]>
-func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, #layout2> {
-  // CHECK: %[[C0:.*]] = tensor.empty() : tensor<256x32xf32, #[[LAYOUT2]]>
-  // CHECK: %[[C1:.*]] = "ttir.sum"(%arg0, %0) <{dim_arg = [-1 : i32], keep_dim = true}> : (tensor<256x384xf32, #[[LAYOUT1]]>, tensor<256x32xf32, #[[LAYOUT2]]>) -> tensor<256x32xf32, #[[LAYOUT2]]>
-  %0 = tensor.empty() : tensor<256x32xf32, #layout2>
+// CHECK-SAME: ) -> tensor<256x1xf32, #[[LAYOUT2:layout2]]>
+func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x1xf32, #layout2> {
+  // CHECK: %[[C0:.*]] = tensor.empty() : tensor<256x1xf32, #[[LAYOUT2]]>
+  // CHECK: %[[C1:.*]] = "ttir.sum"(%arg0, %0) <{dim_arg = [-1 : i32], keep_dim = true}> : (tensor<256x384xf32, #[[LAYOUT1]]>, tensor<256x1xf32, #[[LAYOUT2]]>) -> tensor<256x1xf32, #[[LAYOUT2]]>
+  %0 = tensor.empty() : tensor<256x1xf32, #layout2>
   %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>,
                                dim_arg = [-1: i32],
                                keep_dim = true}> :
-                              (tensor<256x384xf32, #layout1>, tensor<256x32xf32, #layout2>) -> tensor<256x32xf32, #layout2>
-  // CHECK: return %[[C1]] : tensor<256x32xf32, #[[LAYOUT2]]>
-  return %1 : tensor<256x32xf32, #layout2>
+                              (tensor<256x384xf32, #layout1>, tensor<256x1xf32, #layout2>) -> tensor<256x1xf32, #layout2>
+  // CHECK: return %[[C1]] : tensor<256x1xf32, #[[LAYOUT2]]>
+  return %1 : tensor<256x1xf32, #layout2>
 }
