@@ -257,3 +257,68 @@ hal.executable private @shared_memory_copy {
257
257
// CHECK: vector.transfer_read %[[ALLOC]]{{.*}} : memref<32xf32, #gpu.address_space<workgroup>>, vector<1xf32>
258
258
// CHECK: vector.transfer_write {{.*}} : vector<1xf32>, memref<128x32xf32>
259
259
// CHECK: return
260
+
261
+
262
+ // -----
263
+
264
+ // Check that we multi-row matvec gets distributed across subgoroup threads.
265
+
266
+ #executable_target_rocm_hsaco_fb = #hal.executable.target <" rocm" , " rocm-hsaco-fb" , {target_arch = " gfx940" }>
267
+ #pipeline_layout = #hal.pipeline.layout <push_constants = 0 , sets = [
268
+ #hal.descriptor_set.layout <0 , bindings = [
269
+ #hal.descriptor_set.binding <0 , storage_buffer >,
270
+ #hal.descriptor_set.binding <1 , storage_buffer >,
271
+ #hal.descriptor_set.binding <2 , storage_buffer >
272
+ ]>
273
+ ]>
274
+ hal.executable private @multirow {
275
+ hal.executable.variant @rocm target (#executable_target_rocm_hsaco_fb ) {
276
+ hal.executable.export @multirow layout (#pipeline_layout ) attributes {
277
+ workgroup_size = [64 : index , 1 : index , 1 : index ]
278
+ }
279
+ builtin.module {
280
+ func.func @multirow () {
281
+ %cst = arith.constant dense <0.000000e+00 > : vector <4 x512 xf16 >
282
+ %c0 = arith.constant 0 : index
283
+ %cst_0 = arith.constant dense <0.000000e+00 > : vector <1 x4 xf16 >
284
+ %c4096 = arith.constant 4096 : index
285
+ %c512 = arith.constant 512 : index
286
+ %cst_1 = arith.constant 0.000000e+00 : f16
287
+ %id = gpu.thread_id x
288
+ %0 = hal.interface.binding.subspan set (0 ) binding (0 ) type (storage_buffer ) alignment (64 ) offset (%c0 ) flags (ReadOnly ) : memref <1 x4096 xf16 , #hal.descriptor_type <storage_buffer >>
289
+ memref.assume_alignment %0 , 64 : memref <1 x4096 xf16 , #hal.descriptor_type <storage_buffer >>
290
+ %1 = hal.interface.binding.subspan set (0 ) binding (1 ) type (storage_buffer ) alignment (64 ) offset (%c0 ) flags (ReadOnly ) : memref <32000 x4096 xf16 , #hal.descriptor_type <storage_buffer >>
291
+ memref.assume_alignment %1 , 64 : memref <32000 x4096 xf16 , #hal.descriptor_type <storage_buffer >>
292
+ %2 = hal.interface.binding.subspan set (0 ) binding (2 ) type (storage_buffer ) alignment (64 ) offset (%c0 ) : memref <1 x32000 xf16 , #hal.descriptor_type <storage_buffer >>
293
+ memref.assume_alignment %2 , 64 : memref <1 x32000 xf16 , #hal.descriptor_type <storage_buffer >>
294
+ %workgroup_id_x = hal.interface.workgroup.id [0 ] : index
295
+ %3 = affine.apply affine_map <()[s0 ] -> (s0 * 4 )>()[%workgroup_id_x ]
296
+ %4 = scf.for %arg0 = %c0 to %c4096 step %c512 iter_args (%arg1 = %cst ) -> (vector <4 x512 xf16 >) {
297
+ %8 = vector.transfer_read %0 [%c0 , %arg0 ], %cst_1 {in_bounds = [true , true ], permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 )>} : memref <1 x4096 xf16 , #hal.descriptor_type <storage_buffer >>, vector <4 x512 xf16 >
298
+ %9 = vector.transfer_read %1 [%3 , %arg0 ], %cst_1 {in_bounds = [true , true ]} : memref <32000 x4096 xf16 , #hal.descriptor_type <storage_buffer >>, vector <4 x512 xf16 >
299
+ %10 = arith.mulf %8 , %9 : vector <4 x512 xf16 >
300
+ %11 = arith.addf %arg1 , %10 : vector <4 x512 xf16 >
301
+ scf.yield %11 : vector <4 x512 xf16 >
302
+ }
303
+ %5 = vector.broadcast %4 : vector <4 x512 xf16 > to vector <1 x4 x512 xf16 >
304
+ %6 = vector.multi_reduction <add >, %5 , %cst_0 [2 ] : vector <1 x4 x512 xf16 > to vector <1 x4 xf16 >
305
+ %7 = vector.extract %6 [0 ] : vector <4 xf16 > from vector <1 x4 xf16 >
306
+ vector.transfer_write %7 , %2 [%c0 , %3 ] {in_bounds = [true ]} : vector <4 xf16 >, memref <1 x32000 xf16 , #hal.descriptor_type <storage_buffer >>
307
+ return
308
+ }
309
+ }
310
+ }
311
+ }
312
+
313
+ // CHECK-LABEL: func.func @multirow() {
314
+ // CHECK: scf.for {{.*}} -> (vector<4x8xf16>) {
315
+ // CHECK: vector.transfer_read {{.*}} : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x8xf16>
316
+ // CHECK: vector.transfer_read {{.*}} : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x8xf16>
317
+ // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<4x8xf16>
318
+ // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4x8xf16>
319
+ // CHECK: }
320
+ // CHECK: gpu.shuffle xor
321
+ // CHECK: scf.if {{.*}} {
322
+ // CHECK: vector.transfer_write {{.*}} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
323
+ // CHECK: }
324
+ // CHECK-NEXT: return
0 commit comments