Skip to content

Commit a8bfe24

Browse files
committed
[AutoBump] Merge with fixes of 8388040 (Jan 23)
2 parents d21dbbb + 8388040 commit a8bfe24

File tree

5 files changed

+156
-24
lines changed

5 files changed

+156
-24
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
4242

4343
let arguments = (ins
4444
Tosa_Tensor: $input,
45-
I32Attr: $axis
45+
I32Attr: $axis,
46+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
4647
);
4748

4849
let results = (outs
@@ -287,7 +288,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
287288

288289
Tosa_IntArrayAttr2:$kernel,
289290
Tosa_IntArrayAttr2:$stride,
290-
Tosa_IntArrayAttr4:$pad
291+
Tosa_IntArrayAttr4:$pad,
292+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
291293
);
292294

293295
let results = (outs
@@ -388,7 +390,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
388390
I64Attr:$min_int,
389391
I64Attr:$max_int,
390392
Tosa_FloatAttr:$min_fp,
391-
Tosa_FloatAttr:$max_fp
393+
Tosa_FloatAttr:$max_fp,
394+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
392395
);
393396

394397
let results = (outs
@@ -752,7 +755,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
752755

753756
let arguments = (ins
754757
Tosa_Tensor:$input1,
755-
Tosa_Tensor:$input2
758+
Tosa_Tensor:$input2,
759+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
756760
);
757761

758762
let results = (outs
@@ -777,7 +781,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
777781

778782
let arguments = (ins
779783
Tosa_Tensor:$input1,
780-
Tosa_Tensor:$input2
784+
Tosa_Tensor:$input2,
785+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
781786
);
782787

783788
let results = (outs
@@ -1390,7 +1395,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
13901395

13911396
let arguments = (ins
13921397
Tosa_Tensor:$input,
1393-
I32Attr:$axis
1398+
I32Attr:$axis,
1399+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
13941400
);
13951401

13961402
let results = (outs
@@ -1430,7 +1436,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
14301436

14311437
let arguments = (ins
14321438
Tosa_Tensor:$input,
1433-
I32Attr:$axis
1439+
I32Attr:$axis,
1440+
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
14341441
);
14351442

14361443
let results = (outs

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
208208
//===----------------------------------------------------------------------===//
209209
// Iterable attributes.
210210
//===----------------------------------------------------------------------===//
211+
// Defined in `section 3. Enumerations` of the TOSA specification.
212+
211213
// Supported regimes for tosa.resize.
212214
def Tosa_ResizeTypeAttr : StringBasedAttr<
213215
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
214216
"::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
215217
"Supported resize/upsampling strategies">;
216218

219+
// Supported NaN propagation strategies: a string attribute whose value must be either "PROPAGATE" or "IGNORE".
220+
def Tosa_NanPropagationAttr : StringBasedAttr<
221+
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
222+
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
223+
"Supported NaN propagation strategies">;
224+
217225
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
218226

219227
// Tensor to buffer types.

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -658,33 +658,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
658658
}
659659
};
660660

661+
// Attempts the following transformation:
662+
//
663+
// For bounds a, b, a', and b' (integer or floating-point) such that
664+
// [a, b] ∩ [a', b'] ≠ ∅ and input tensor X the following identity holds:
665+
//
666+
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
667+
//
668+
// subject to the following valid NaN propagation semantics:
669+
// --------------------------------------------
670+
// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
671+
// |-------------|--------------|-------------|
672+
// | PROPAGATE | PROPAGATE | PROPAGATE |
673+
// | PROPAGATE | IGNORE | IGNORE |
674+
// | IGNORE | PROPAGATE | INVALID |
675+
// | IGNORE | IGNORE | IGNORE |
676+
// |------------------------------------------|
677+
661678
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
662679
using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
663680

681+
// Helper structure describing the closed range [start, end] of a clamp
// operation. T must be copyable and provide operator<=.
template <typename T>
struct ClampRange {
  ClampRange(const T &start, const T &end) : start(start), end(end) {}
  T start; // Lower bound of the range (inclusive).
  T end;   // Upper bound of the range (inclusive).

  // Returns true when this range and `otherRange` have at least one value in
  // common. Both ranges are closed, so ranges that merely share an endpoint
  // (e.g. [0, 5] and [5, 10]) intersect; a strict `<` comparison would
  // wrongly report them as disjoint. The query does not modify either range.
  bool intersects(const ClampRange<T> &otherRange) const {
    return start <= otherRange.end && otherRange.start <= end;
  }
};
693+
664694
LogicalResult matchAndRewrite(tosa::ClampOp op,
665695
PatternRewriter &rewriter) const override {
666-
Value input = op.getInput();
667-
668-
Operation *definingOp = input.getDefiningOp();
669-
if (!definingOp)
696+
// Check the input to the CLAMP op is itself a CLAMP.
697+
auto clampOp =
698+
dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
699+
if (!clampOp)
670700
return failure();
671701

672-
if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
673-
auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
674-
auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
702+
// Check we have a valid NaN propagation combination.
703+
const auto opNanMode = op.getNanMode();
704+
const auto clampNanMode = clampOp.getNanMode();
705+
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
706+
return failure();
675707

676-
auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
677-
auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
708+
// Check we have intersecting ranges.
709+
const auto opMinInt = op.getMinInt();
710+
const auto opMaxInt = op.getMaxInt();
711+
const auto clampOpMinInt = clampOp.getMinInt();
712+
const auto clampOpMaxInt = clampOp.getMaxInt();
713+
ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
714+
ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
715+
if (!opRangeIntRange.intersects(clampRangeIntRange))
716+
return failure();
678717

679-
rewriter.replaceOpWithNewOp<ClampOp>(
680-
op, {op->getLoc(), clampOp->getLoc()}, op.getType(),
681-
clampOp.getInput(), rewriter.getI64IntegerAttr(minInt),
682-
rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
683-
rewriter.getF32FloatAttr(maxFp));
684-
return success();
685-
}
718+
const auto opMinFloat = op.getMinFp();
719+
const auto opMaxFloat = op.getMaxFp();
720+
const auto clampOpMinFloat = clampOp.getMinFp();
721+
const auto clampOpMaxFloat = clampOp.getMaxFp();
722+
ClampRange opRangeFloatRange(opMinFloat, opMaxFloat);
723+
ClampRange clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
724+
if (!opRangeFloatRange.intersects(clampRangeFloatRange))
725+
return failure();
686726

687-
return failure();
727+
// Run the transformation.
728+
const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
729+
const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
730+
const auto minInt = std::max(opMinInt, clampOpMinInt);
731+
const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
732+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
733+
op, {op->getLoc(), clampOp->getLoc()}, op.getType(), clampOp.getInput(),
734+
rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
735+
rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
736+
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
737+
: opNanMode));
738+
return success();
688739
}
689740
};
690741

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,58 @@ func.func @concat_fold_zero_size(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>,
233233

234234
// -----
235235

236+
// CHECK: @disjoint_clamp_twice_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
237+
func.func @disjoint_clamp_twice_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
238+
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = -5.000000e+00 : f32, max_int = -5 : i64, min_fp = -1.000000e+00 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
239+
// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 1.000000e+00 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
240+
%0 = tosa.clamp %arg0 {max_fp = -5.0 : f32, max_int = -5 : i64, min_fp = -1.0 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
241+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 5 : i64, min_fp = 1.0 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
242+
return %1 : tensor<4xi8>
243+
}
244+
245+
// -----
246+
247+
// CHECK-LABEL: @clamp_twice_with_nan_propagate_is_single_clamp
248+
func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
249+
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
250+
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
251+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
252+
return %1 : tensor<4xi8>
253+
}
254+
255+
// -----
256+
257+
// CHECK-LABEL: @clamp_twice_with_nan_ignore_is_single_clamp
258+
func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
259+
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
260+
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
261+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
262+
return %1 : tensor<4xi8>
263+
}
264+
265+
// -----
266+
267+
// CHECK-LABEL: @clamp_twice_with_nan_ignore_propagate_is_single_clamp
268+
func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
269+
// CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
270+
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
271+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
272+
return %1 : tensor<4xi8>
273+
}
274+
275+
// -----
276+
277+
// CHECK: @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
278+
func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
279+
// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = 3.000000e+00 : f32, max_int = 4 : i64, min_fp = -5.000000e+00 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8>
280+
// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
281+
%0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
282+
%1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
283+
return %1 : tensor<4xi8>
284+
}
285+
286+
// -----
287+
236288
// CHECK-LABEL: @concat_fold
237289
func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
238290
// CHECK: return %arg0

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
180180
return %0 : tensor<13x21x3xf32>
181181
}
182182

183+
// -----
184+
// CHECK-LABEL: clamp_propagate
185+
func.func @test_clamp_propagate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
186+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "PROPAGATE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
187+
return %0 : tensor<13x21x3xf32>
188+
}
189+
190+
// -----
191+
// CHECK-LABEL: clamp_ignore
192+
func.func @test_clamp_ignore(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
193+
%0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "IGNORE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
194+
return %0 : tensor<13x21x3xf32>
195+
}
196+
183197
// -----
184198
// CHECK-LABEL: clamp_f16
185199
func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {

0 commit comments

Comments
 (0)