Skip to content

Commit

Permalink
[mlir][arith] Align shift Ops with LLVM instructions on allowed shift…
Browse files Browse the repository at this point in the history
… amounts (llvm#82133)

This patch aligns the shift Ops in `arith` with respective LLVM instructions.
Specifically, shifting by an amount equal to the bitwidth of the operand
is now defined to return poison.

Relevant discussion:
https://discourse.llvm.org/t/some-question-on-the-semantics-of-the-arith-dialect/74861/10
Relevant issue: llvm#80960
  • Loading branch information
ubfx authored Feb 18, 2024
1 parent 833fea4 commit 1a8c613
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 11 deletions.
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
The `shli` operation shifts the integer value of the first operand to the left
by the integer value of the second operand. The second operand is interpreted as
unsigned. The low order bits are filled with zeros. If the value of the second
operand is greater than the bitwidth of the first operand, then the
operand is greater or equal than the bitwidth of the first operand, then the
operation returns poison.

This op supports `nuw`/`nsw` overflow flags which stands stand for
Expand Down Expand Up @@ -818,8 +818,8 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
The `shrui` operation shifts an integer value of the first operand to the right
by the value of the second operand. The first operand is interpreted as unsigned,
and the second operand is interpreted as unsigned. The high order bits are always
filled with zeros. If the value of the second operand is greater than the bitwidth
of the first operand, then the operation returns poison.
filled with zeros. If the value of the second operand is greater or equal than the
bitwidth of the first operand, then the operation returns poison.

Example:

Expand All @@ -844,8 +844,8 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
and the second operand is interpreter as unsigned. The high order bits in the
output are filled with copies of the most-significant bit of the shifted value
(which means that the sign of the value is preserved). If the value of the second
operand is greater than bitwidth of the first operand, then the operation returns
poison.
operand is greater or equal than bitwidth of the first operand, then the operation
returns poison.

Example:

Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2379,11 +2379,11 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
// shli(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
// Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
bounded = b.ult(b.getBitWidth());
return a.shl(b);
});
return bounded ? result : Attribute();
Expand All @@ -2397,11 +2397,11 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
// shrui(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
// Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
bounded = b.ult(b.getBitWidth());
return a.lshr(b);
});
return bounded ? result : Attribute();
Expand All @@ -2415,11 +2415,11 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
// shrsi(x, 0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
// Don't fold if shifting more or equal than the bit width.
bool bounded = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
bounded = b.ule(b.getBitWidth());
bounded = b.ult(b.getBitWidth());
return a.ashr(b);
});
return bounded ? result : Attribute();
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2179,6 +2179,17 @@ func.func @nofoldShl2() -> i64 {
return %r : i64
}

// CHECK-LABEL: @nofoldShl3(
// CHECK: %[[res:.+]] = arith.shli
// CHECK: return %[[res]]
func.func @nofoldShl3() -> i64 {
%c1 = arith.constant 1 : i64
%c64 = arith.constant 64 : i64
// Note that this should return Poison in the future.
%r = arith.shli %c1, %c64 : i64
return %r : i64
}

// CHECK-LABEL: @foldShru(
// CHECK: %[[res:.+]] = arith.constant 2 : i64
// CHECK: return %[[res]]
Expand Down Expand Up @@ -2219,6 +2230,17 @@ func.func @nofoldShru2() -> i64 {
return %r : i64
}

// CHECK-LABEL: @nofoldShru3(
// CHECK: %[[res:.+]] = arith.shrui
// CHECK: return %[[res]]
func.func @nofoldShru3() -> i64 {
%c1 = arith.constant 8 : i64
%c64 = arith.constant 64 : i64
// Note that this should return Poison in the future.
%r = arith.shrui %c1, %c64 : i64
return %r : i64
}

// CHECK-LABEL: @foldShrs(
// CHECK: %[[res:.+]] = arith.constant 2 : i64
// CHECK: return %[[res]]
Expand Down Expand Up @@ -2259,6 +2281,17 @@ func.func @nofoldShrs2() -> i64 {
return %r : i64
}

// CHECK-LABEL: @nofoldShrs3(
// CHECK: %[[res:.+]] = arith.shrsi
// CHECK: return %[[res]]
func.func @nofoldShrs3() -> i64 {
%c1 = arith.constant 8 : i64
%c64 = arith.constant 64 : i64
// Note that this should return Poison in the future.
%r = arith.shrsi %c1, %c64 : i64
return %r : i64
}

// -----

// CHECK-LABEL: @test_negf(
Expand Down

0 comments on commit 1a8c613

Please sign in to comment.