Skip to content

Commit d893d12

Browse files
authored
[mlir] GPUToROCDL: Fix crashes with unsupported shuffle datatypes (#135504)
Calling `getIntOrFloatBitWidth` on non-int/float types (`gpu.shuffle` also accepts vectors) will crash.
1 parent c0afb77 commit d893d12

File tree

4 files changed

+36
-9
lines changed

4 files changed

+36
-9
lines changed

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

+9-7
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
136136
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
137137
ConversionPatternRewriter &rewriter) const override {
138138
Location loc = op->getLoc();
139+
Value initShflValue = adaptor.getValue();
140+
Type shflType = initShflValue.getType();
139141
// TODO: Add support for non 32-bit shuffle values.
140-
if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
141-
return failure();
142+
if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
143+
return rewriter.notifyMatchFailure(
144+
op, "only 32-bit int/float types are supported");
145+
142146
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
143147
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
144148

@@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
175179
Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
176180
Value dwordAlignedDstLane =
177181
rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
178-
Value initShflValue = adaptor.getValue();
179-
if (adaptor.getValue().getType().isF32()) {
182+
if (shflType.isF32()) {
180183
initShflValue =
181184
rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
182185
}
183186
Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
184187
loc, int32Type, dwordAlignedDstLane, initShflValue);
185-
if (adaptor.getValue().getType().isF32()) {
186-
shflValue = rewriter.create<LLVM::BitcastOp>(
187-
loc, adaptor.getValue().getType(), shflValue);
188+
if (shflType.isF32()) {
189+
shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
188190
}
189191
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
190192
return success();

mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
4040
auto i64 = rewriter.getI64Type();
4141

4242
// Only 64-bit int/float values are rewritten; i32/f32 shuffles are already valid and all other types are left untouched.
43-
if (valueType.getIntOrFloatBitWidth() == 32)
44-
return failure();
43+
if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
44+
return rewriter.notifyMatchFailure(
45+
op, "only 64-bit int/float types are supported");
4546

4647
Value lo, hi;
4748

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt %s -convert-gpu-to-rocdl -verify-diagnostics
2+
3+
gpu.module @test_module {
4+
// ROCDL lowering only supports shuffles for 32-bit ints/floats, but it
5+
// shouldn't crash on unsupported types.
6+
func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
7+
%offset = arith.constant 4 : i32
8+
%width = arith.constant 64 : i32
9+
// expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
10+
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
11+
return %shfl : vector<4xf16>
12+
}
13+
}

mlir/test/Dialect/GPU/shuffle-rewrite.mlir

+11
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,14 @@ module {
4949
return
5050
}
5151
}
52+
53+
// -----
54+
55+
// CHECK-LABEL: @gpu_shuffle_unsupported
56+
func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
57+
%offset = arith.constant 4 : i32
58+
%width = arith.constant 64 : i32
59+
// CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : vector<4xf16>
60+
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
61+
return %shfl : vector<4xf16>
62+
}

0 commit comments

Comments
 (0)