@@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
136136  matchAndRewrite (gpu::ShuffleOp op, OpAdaptor adaptor,
137137                  ConversionPatternRewriter &rewriter) const  override  {
138138    Location loc = op->getLoc ();
139+     Value initShflValue = adaptor.getValue ();
140+     Type shflType = initShflValue.getType ();
139141    //  TODO: Add support for non 32-bit shuffle values.
140-     if  (adaptor.getValue ().getType ().getIntOrFloatBitWidth () != 32 )
141-       return  failure ();
142+     if  (!shflType.isIntOrFloat () || shflType.getIntOrFloatBitWidth () != 32 )
143+       return  rewriter.notifyMatchFailure (
144+           op, " only 32-bit int/float types are supported" 
145+ 
142146    const  unsigned  indexBitwidth = getTypeConverter ()->getIndexTypeBitwidth ();
143147    Value srcLaneId = getLaneId (rewriter, loc, indexBitwidth);
144148
@@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
175179    Value two = rewriter.create <LLVM::ConstantOp>(loc, int32Type, 2 );
176180    Value dwordAlignedDstLane =
177181        rewriter.create <LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
178-     Value initShflValue = adaptor.getValue ();
179-     if  (adaptor.getValue ().getType ().isF32 ()) {
182+     if  (shflType.isF32 ()) {
180183      initShflValue =
181184          rewriter.create <LLVM::BitcastOp>(loc, int32Type, initShflValue);
182185    }
183186    Value shflValue = rewriter.create <ROCDL::DsBpermuteOp>(
184187        loc, int32Type, dwordAlignedDstLane, initShflValue);
185-     if  (adaptor.getValue ().getType ().isF32 ()) {
186-       shflValue = rewriter.create <LLVM::BitcastOp>(
187-           loc, adaptor.getValue ().getType (), shflValue);
188+     if  (shflType.isF32 ()) {
189+       shflValue = rewriter.create <LLVM::BitcastOp>(loc, shflType, shflValue);
188190    }
189191    rewriter.replaceOp (op, {shflValue, isActiveSrcLane});
190192    return  success ();
0 commit comments