Skip to content

Commit 3e6b7bf

Browse files
committed
[AutoBump] Merge with fixes of 729f958 (Jan 22)
2 parents 11819d6 + 729f958 commit 3e6b7bf

File tree

9 files changed

+126
-378
lines changed

9 files changed

+126
-378
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
231231
//===----------------------------------------------------------------------===//
232232

233233
class Tosa_Op<string mnemonic, list<Trait> traits = []> :
234-
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
234+
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
235235
TosaResolvableShapeOperands])> {
236236
}
237237

@@ -241,6 +241,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
241241
["inferReturnTypeComponents"]>,
242242
ResultsBroadcastableShape,
243243
TosaElementwiseOperator,
244+
SameOperandsAndResultRank,
244245
Pure])> {
245246
let assemblyFormat =
246247
"operands attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
303303
}
304304
}
305305

306+
/// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
307+
/// and all nested regions
308+
void validateSameOperandsAndResultRankTrait(Region &region) {
309+
int errs = 0;
310+
for (auto &block : region) {
311+
for (auto &op : block) {
312+
if (!op.getDialect() ||
313+
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
314+
continue;
315+
if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
316+
if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
317+
errs++;
318+
}
319+
}
320+
WhileOp whileOp = dyn_cast<WhileOp>(op);
321+
IfOp ifOp = dyn_cast<IfOp>(op);
322+
if (whileOp || ifOp) {
323+
// recurse into whileOp's regions
324+
for (auto &next : op.getRegions()) {
325+
validateSameOperandsAndResultRankTrait(next);
326+
}
327+
}
328+
}
329+
}
330+
}
331+
306332
/// Pass that performs shape propagation across TOSA operations. This includes
307333
/// migrating to within the regions of if/while operations.
308334
struct TosaInferShapes
@@ -313,6 +339,8 @@ struct TosaInferShapes
313339
TypeModificationState state;
314340
propagateShapesInRegion(func.getBody(), state);
315341
state.commit();
342+
343+
validateSameOperandsAndResultRankTrait(func.getBody());
316344
}
317345
};
318346
} // namespace

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,11 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
4545
%0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
4646
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
4747
}
48+
49+
// -----
50+
51+
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
52+
// expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
53+
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
54+
return %0 : tensor<2x3x4xf32>
55+
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
9494
// CHECK: } -> tensor<f32>
9595
%0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
9696

97+
9798
// CHECK: return [[RESULT]] : tensor<f32>
9899
return %0 : tensor<f32>
99100
}
@@ -123,20 +124,20 @@ func.func @test_add_uint8(%arg0: tensor<ui8>, %arg1: tensor<ui8>) -> tensor<ui8>
123124
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
124125
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, d1)>
125126

126-
// CHECK-LABEL: func.func @test_add_0d_broadcast(
127+
// CHECK-LABEL: func.func @test_add_2d_broadcast(
127128
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
128-
// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
129-
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
129+
// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1xf32>) -> tensor<2x1xf32> {
130130
// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
131-
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
131+
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
132132
// CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
133133
// CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
134134
// CHECK: linalg.yield %[[ADD]] : f32
135135
// CHECK: } -> tensor<2x1xf32>
136136
// CHECK: return %[[RESULT]] : tensor<2x1xf32>
137137
// CHECK: }
138-
func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
139-
%0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
138+
func.func @test_add_2d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x1xf32> {
139+
// tosa element-wise operators now require operands of equal ranks
140+
%0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1x1xf32>) -> tensor<2x1xf32>
140141
return %0 : tensor<2x1xf32>
141142
}
142143

@@ -403,28 +404,6 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
403404

404405
// -----
405406

406-
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
407-
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
408-
// CHECK-LABEL: @test_add_2d_different_ranks
409-
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
410-
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
411-
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
412-
413-
// CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
414-
// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
415-
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
416-
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
417-
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
418-
// CHECK: linalg.yield %[[VAL_4]] : f32
419-
// CHECK: } -> tensor<2x3x4xf32>
420-
%0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
421-
422-
// CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
423-
return %0 : tensor<2x3x4xf32>
424-
}
425-
426-
// -----
427-
428407
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
429408
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
430409
// CHECK-LABEL: @test_select_2d_one_dynamic

0 commit comments

Comments
 (0)