@@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
136
136
matchAndRewrite (gpu::ShuffleOp op, OpAdaptor adaptor,
137
137
ConversionPatternRewriter &rewriter) const override {
138
138
Location loc = op->getLoc ();
139
+ Value initShflValue = adaptor.getValue ();
140
+ Type shflType = initShflValue.getType ();
139
141
// TODO: Add support for non 32-bit shuffle values.
140
- if (adaptor.getValue ().getType ().getIntOrFloatBitWidth () != 32 )
141
- return failure ();
142
+ if (!shflType.isIntOrFloat () || shflType.getIntOrFloatBitWidth () != 32 )
143
+ return rewriter.notifyMatchFailure (
144
+ op, " only 32-bit int/float types are supported" );
145
+
142
146
const unsigned indexBitwidth = getTypeConverter ()->getIndexTypeBitwidth ();
143
147
Value srcLaneId = getLaneId (rewriter, loc, indexBitwidth);
144
148
@@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
175
179
Value two = rewriter.create <LLVM::ConstantOp>(loc, int32Type, 2 );
176
180
Value dwordAlignedDstLane =
177
181
rewriter.create <LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
178
- Value initShflValue = adaptor.getValue ();
179
- if (adaptor.getValue ().getType ().isF32 ()) {
182
+ if (shflType.isF32 ()) {
180
183
initShflValue =
181
184
rewriter.create <LLVM::BitcastOp>(loc, int32Type, initShflValue);
182
185
}
183
186
Value shflValue = rewriter.create <ROCDL::DsBpermuteOp>(
184
187
loc, int32Type, dwordAlignedDstLane, initShflValue);
185
- if (adaptor.getValue ().getType ().isF32 ()) {
186
- shflValue = rewriter.create <LLVM::BitcastOp>(
187
- loc, adaptor.getValue ().getType (), shflValue);
188
+ if (shflType.isF32 ()) {
189
+ shflValue = rewriter.create <LLVM::BitcastOp>(loc, shflType, shflValue);
188
190
}
189
191
rewriter.replaceOp (op, {shflValue, isActiveSrcLane});
190
192
return success ();
0 commit comments