From 1a8c6130f60fe517fb722ab4309997ce7b638234 Mon Sep 17 00:00:00 2001 From: Felix Schneider Date: Sun, 18 Feb 2024 10:17:03 +0100 Subject: [PATCH] [mlir][arith] Align shift Ops with LLVM instructions on allowed shift amounts (#82133) This patch aligns the shift Ops in `arith` with respective LLVM instructions. Specifically, shifting by an amount equal to the bitwidth of the operand is now defined to return poison. Relevant discussion: https://discourse.llvm.org/t/some-question-on-the-semantics-of-the-arith-dialect/74861/10 Relevant issue: https://github.com/llvm/llvm-project/issues/80960 --- .../include/mlir/Dialect/Arith/IR/ArithOps.td | 10 +++--- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 12 +++---- mlir/test/Dialect/Arith/canonicalize.mlir | 33 +++++++++++++++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 4babbe80e285f..c9df50d0395d1 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -788,7 +788,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> { The `shli` operation shifts the integer value of the first operand to the left by the integer value of the second operand. The second operand is interpreted as unsigned. The low order bits are filled with zeros. If the value of the second - operand is greater than the bitwidth of the first operand, then the + operand is greater or equal than the bitwidth of the first operand, then the operation returns poison. This op supports `nuw`/`nsw` overflow flags which stands stand for @@ -818,8 +818,8 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { The `shrui` operation shifts an integer value of the first operand to the right by the value of the second operand. The first operand is interpreted as unsigned, and the second operand is interpreted as unsigned. The high order bits are always - filled with zeros. If the value of the second operand is greater than the bitwidth - of the first operand, then the operation returns poison. + filled with zeros. If the value of the second operand is greater or equal than the + bitwidth of the first operand, then the operation returns poison. Example: @@ -844,8 +844,8 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> { and the second operand is interpreter as unsigned. The high order bits in the output are filled with copies of the most-significant bit of the shifted value (which means that the sign of the value is preserved). If the value of the second - operand is greater than bitwidth of the first operand, then the operation returns - poison. + operand is greater or equal than bitwidth of the first operand, then the operation + returns poison. Example: diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 275c2debe9a6f..0f71c19c23b65 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2379,11 +2379,11 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) { // shli(x, 0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); - // Don't fold if shifting more than the bit width. + // Don't fold if shifting more or equal than the bit width. bool bounded = false; auto result = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { - bounded = b.ule(b.getBitWidth()); + bounded = b.ult(b.getBitWidth()); return a.shl(b); }); return bounded ? result : Attribute(); @@ -2397,11 +2397,11 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) { // shrui(x, 0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); - // Don't fold if shifting more than the bit width. + // Don't fold if shifting more or equal than the bit width. bool bounded = false; auto result = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { - bounded = b.ule(b.getBitWidth()); + bounded = b.ult(b.getBitWidth()); return a.lshr(b); }); return bounded ? result : Attribute(); @@ -2415,11 +2415,11 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) { // shrsi(x, 0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); - // Don't fold if shifting more than the bit width. + // Don't fold if shifting more or equal than the bit width. bool bounded = false; auto result = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { - bounded = b.ule(b.getBitWidth()); + bounded = b.ult(b.getBitWidth()); return a.ashr(b); }); return bounded ? result : Attribute(); diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index f128b13e9f732..cb98a10048a30 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2179,6 +2179,17 @@ func.func @nofoldShl2() -> i64 { return %r : i64 } +// CHECK-LABEL: @nofoldShl3( +// CHECK: %[[res:.+]] = arith.shli +// CHECK: return %[[res]] +func.func @nofoldShl3() -> i64 { + %c1 = arith.constant 1 : i64 + %c64 = arith.constant 64 : i64 + // Note that this should return Poison in the future. + %r = arith.shli %c1, %c64 : i64 + return %r : i64 +} + // CHECK-LABEL: @foldShru( // CHECK: %[[res:.+]] = arith.constant 2 : i64 // CHECK: return %[[res]] @@ -2219,6 +2230,17 @@ func.func @nofoldShru2() -> i64 { return %r : i64 } +// CHECK-LABEL: @nofoldShru3( +// CHECK: %[[res:.+]] = arith.shrui +// CHECK: return %[[res]] +func.func @nofoldShru3() -> i64 { + %c1 = arith.constant 8 : i64 + %c64 = arith.constant 64 : i64 + // Note that this should return Poison in the future. + %r = arith.shrui %c1, %c64 : i64 + return %r : i64 +} + // CHECK-LABEL: @foldShrs( // CHECK: %[[res:.+]] = arith.constant 2 : i64 // CHECK: return %[[res]] @@ -2259,6 +2281,17 @@ func.func @nofoldShrs2() -> i64 { return %r : i64 } +// CHECK-LABEL: @nofoldShrs3( +// CHECK: %[[res:.+]] = arith.shrsi +// CHECK: return %[[res]] +func.func @nofoldShrs3() -> i64 { + %c1 = arith.constant 8 : i64 + %c64 = arith.constant 64 : i64 + // Note that this should return Poison in the future. + %r = arith.shrsi %c1, %c64 : i64 + return %r : i64 +} + // ----- // CHECK-LABEL: @test_negf(