
Commit 3de29e0

kuhar authored and github-actions[bot] committed
[LLVMGPU] Add multi-row vector reduction configuration (#73)
This speeds up matvec. The new configuration is experimental and applied only on ROCm targets.
1 parent 89add3d commit 3de29e0

File tree

4 files changed: +134 -1 lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp

Lines changed: 1 addition & 1 deletion
@@ -195,7 +195,7 @@ class VectorReductionToGPUPass
       bool expandSubgroupReduction,
       std::function<int(func::FuncOp)> getWarpSize)
       : expandSubgroupReduction(expandSubgroupReduction),
-        getWarpSize(getWarpSize) {}
+        getWarpSize(std::move(getWarpSize)) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
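For context on the one-line fix above: taking a `std::function` by value and then copying it into the member pays for an extra copy of the type-erased callable and any captured state; moving it instead just transfers the internal storage. A minimal sketch of the same sink-argument pattern, with an illustrative `Pass` class standing in for the real `VectorReductionToGPUPass`:

#include <functional>
#include <utility>

// Illustrative stand-in for the pass constructor pattern in the diff above.
class Pass {
public:
  // Taking the std::function by value and moving it into the member means a
  // caller passing a temporary pays only for moves; an lvalue argument incurs
  // exactly one copy instead of two.
  explicit Pass(std::function<int(int)> getWarpSize)
      : getWarpSize(std::move(getWarpSize)) {}

  int warpSizeFor(int funcId) const { return getWarpSize(funcId); }

private:
  std::function<int(int)> getWarpSize;
};

int main() {
  Pass pass([](int) { return 64; }); // e.g., wave64 on gfx9-class targets
  return pass.warpSizeFor(0) == 64 ? 0 : 1;
}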

compiler/src/iree/compiler/Codegen/Common/GPU/test/vector_reduction_to_gpu.mlir

Lines changed: 65 additions & 0 deletions
@@ -257,3 +257,68 @@ hal.executable private @shared_memory_copy {
 // CHECK: vector.transfer_read %[[ALLOC]]{{.*}} : memref<32xf32, #gpu.address_space<workgroup>>, vector<1xf32>
 // CHECK: vector.transfer_write {{.*}} : vector<1xf32>, memref<128x32xf32>
 // CHECK: return
+
+
+// -----
+
+// Check that multi-row matvec gets distributed across subgroup threads.
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+hal.executable private @multirow {
+  hal.executable.variant @rocm target(#executable_target_rocm_hsaco_fb) {
+    hal.executable.export @multirow layout(#pipeline_layout) attributes {
+      workgroup_size = [64 : index, 1 : index, 1 : index]
+    }
+    builtin.module {
+      func.func @multirow() {
+        %cst = arith.constant dense<0.000000e+00> : vector<4x512xf16>
+        %c0 = arith.constant 0 : index
+        %cst_0 = arith.constant dense<0.000000e+00> : vector<1x4xf16>
+        %c4096 = arith.constant 4096 : index
+        %c512 = arith.constant 512 : index
+        %cst_1 = arith.constant 0.000000e+00 : f16
+        %id = gpu.thread_id x
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>
+        memref.assume_alignment %0, 64 : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>
+        memref.assume_alignment %1, 64 : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+        memref.assume_alignment %2, 64 : memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+        %workgroup_id_x = hal.interface.workgroup.id[0] : index
+        %3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
+        %4 = scf.for %arg0 = %c0 to %c4096 step %c512 iter_args(%arg1 = %cst) -> (vector<4x512xf16>) {
+          %8 = vector.transfer_read %0[%c0, %arg0], %cst_1 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (0, d1)>} : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x512xf16>
+          %9 = vector.transfer_read %1[%3, %arg0], %cst_1 {in_bounds = [true, true]} : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x512xf16>
+          %10 = arith.mulf %8, %9 : vector<4x512xf16>
+          %11 = arith.addf %arg1, %10 : vector<4x512xf16>
+          scf.yield %11 : vector<4x512xf16>
+        }
+        %5 = vector.broadcast %4 : vector<4x512xf16> to vector<1x4x512xf16>
+        %6 = vector.multi_reduction <add>, %5, %cst_0 [2] : vector<1x4x512xf16> to vector<1x4xf16>
+        %7 = vector.extract %6[0] : vector<4xf16> from vector<1x4xf16>
+        vector.transfer_write %7, %2[%c0, %3] {in_bounds = [true]} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: func.func @multirow() {
+// CHECK: scf.for {{.*}} -> (vector<4x8xf16>) {
+// CHECK: vector.transfer_read {{.*}} : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x8xf16>
+// CHECK: vector.transfer_read {{.*}} : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x8xf16>
+// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<4x8xf16>
+// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4x8xf16>
+// CHECK: }
+// CHECK: gpu.shuffle xor
+// CHECK: scf.if {{.*}} {
+// CHECK: vector.transfer_write {{.*}} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+// CHECK: }
+// CHECK-NEXT: return
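The CHECK lines above pin down the distribution arithmetic: the workgroup's vector<4x512xf16> accumulator is split across the 64 subgroup lanes into vector<4x8xf16> slices (512 / 64 = 8 elements per lane per row), and the per-lane partial sums are then combined with a gpu.shuffle xor butterfly. Below is a minimal host-side C++ simulation of that butterfly schedule, modeling the 64 lanes with an array; the variable names are mine, not from the test:

#include <cassert>
#include <cstdio>

// Simulate a 64-lane subgroup reducing one partial sum per lane with the
// xor-based butterfly that the gpu.shuffle xor checks above correspond to.
int main() {
  constexpr int kSubgroupSize = 64;
  float lanes[kSubgroupSize];
  for (int i = 0; i < kSubgroupSize; ++i) lanes[i] = 1.0f; // per-lane partials

  // At each step, lane i adds the value held by lane (i ^ offset). After
  // log2(64) = 6 steps, every lane holds the full subgroup sum.
  for (int offset = kSubgroupSize / 2; offset > 0; offset /= 2) {
    float shuffled[kSubgroupSize];
    for (int i = 0; i < kSubgroupSize; ++i) shuffled[i] = lanes[i ^ offset];
    for (int i = 0; i < kSubgroupSize; ++i) lanes[i] += shuffled[i];
  }

  assert(lanes[0] == 64.0f); // exact: sums of powers of two in float
  std::printf("subgroup sum: %f\n", lanes[0]);
  return 0;
}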

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 21 additions & 0 deletions
@@ -22,7 +22,9 @@
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
@@ -924,6 +926,25 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
   if ((groupSize / subgroupSize) > subgroupSize)
     return failure();
 
+  // With just one subgroup per workgroup, make each subgroup do more work and
+  // process a few reductions along the last parallel dimension.
+  // TODO: We should also check that this will result in data reuse for at
+  // least one argument.
+  // TODO: This is experimental for matvec (matmul_transpose_b) on ROCm only
+  // for now.
+  if (numDynamicReductionDims == 0 && numParallelDims == 2 &&
+      isRocmTarget(entryPoint)) {
+    if (*parallelSize && !parallelDims.empty() && groupSize == subgroupSize) {
+      int maxParallelFactor = 4; // Keeping this conservative for now.
+      int64_t lastParallelBound = bounds[parallelDims.back()];
+      if (!ShapedType::isDynamic(lastParallelBound) &&
+          (lastParallelBound % maxParallelFactor == 0) &&
+          lastParallelBound > maxParallelFactor) {
+        workgroupTileSizes.back() = maxParallelFactor;
+      }
+    }
+  }
+
   std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
   SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
   int64_t remainingGroupSize = groupSize;
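To make the new branch concrete: for the matvec tested below (1x32000 output, static 4096-element reduction, wave64), lastParallelBound = 32000 is a multiple of maxParallelFactor = 4 and greater than it, so workgroupTileSizes.back() becomes 4 and each workgroup computes four output rows. Here is a standalone C++ sketch of the same decision logic with the MLIR plumbing stripped out; the free function, its parameters, and the kDynamic sentinel are illustrative stand-ins, not IREE's API, and the `*parallelSize`/`parallelDims` checks are folded into the preconditions:

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative stand-in for ShapedType::kDynamic (INT64_MIN in recent MLIR).
constexpr int64_t kDynamic = -9223372036854775807LL - 1;

// Sketch of the multi-row tiling heuristic from the hunk above.
int64_t chooseLastParallelTile(const std::vector<int64_t> &parallelBounds,
                               int64_t groupSize, int64_t subgroupSize,
                               bool isRocm, bool hasDynamicReduction) {
  // Only fire for static-reduction, matvec-like ops (2 parallel dims) on ROCm
  // when the workgroup holds exactly one subgroup.
  if (hasDynamicReduction || parallelBounds.size() != 2 || !isRocm ||
      groupSize != subgroupSize)
    return 1;
  const int64_t maxParallelFactor = 4; // Conservative, as in the patch.
  const int64_t lastParallelBound = parallelBounds.back();
  if (lastParallelBound != kDynamic &&
      lastParallelBound % maxParallelFactor == 0 &&
      lastParallelBound > maxParallelFactor)
    return maxParallelFactor;
  return 1;
}

int main() {
  // The matvec in the tests: 1x32000 output, wave64, static reduction.
  // 32000 % 4 == 0 and 32000 > 4, so each workgroup takes 4 output rows.
  const int64_t tile = chooseLastParallelTile({1, 32000}, 64, 64,
                                              /*isRocm=*/true,
                                              /*hasDynamicReduction=*/false);
  std::printf("last parallel tile = %lld\n", (long long)tile); // prints 4
  return 0;
}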

compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir

Lines changed: 47 additions & 0 deletions
@@ -50,3 +50,50 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gf
 // CHECK: func.func @dynamic_batch_matvec()
 // CHECK: linalg.batch_matmul
 // CHECK-SAME: lowering_config = #[[$CONFIG]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>
+  ]>
+]>
+
+hal.executable @vmt {
+  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>) {
+    hal.executable.export @vmt layout(#pipeline_layout)
+    builtin.module {
+      func.func @vmt() {
+        %c0 = arith.constant 0 : index
+        %cst = arith.constant 0.000000e+00 : f16
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x32000xf16>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>> -> tensor<1x4096xf16>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
+        %5 = tensor.empty() : tensor<1x32000xf16>
+        %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<1x32000xf16>) -> tensor<1x32000xf16>
+        %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<1x4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<1x32000xf16>) {
+        ^bb0(%in: f16, %in_0: f16, %out: f16):
+          %8 = arith.mulf %in, %in_0 : f16
+          %9 = arith.addf %out, %8 : f16
+          linalg.yield %9 : f16
+        } -> tensor<1x32000xf16>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 32000], strides = [1, 1] : tensor<1x32000xf16> -> !flow.dispatch.tensor<writeonly:tensor<1x32000xf16>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 4], [0, 0, 512]{{\]}}>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUWarpReduction>
+// CHECK-LABEL: hal.executable.export public @vmt
+// CHECK-SAME: subgroup_size = 64 : index
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
+// CHECK: func.func @vmt()
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[$CONFIG]]
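The expected attributes follow directly from the heuristic and the wave size: the workgroup tile [1, 4] is maxParallelFactor applied to the 32000-wide parallel dimension, and the reduction tile of 512 gives each of the 64 lanes 512 / 64 = 8 contiguous f16 elements, i.e. one 128-bit load per lane per row per iteration. A small C++ arithmetic check of those relationships; the constant names are mine, not from the patch:

#include <cassert>
#include <cstdint>

int main() {
  // Shapes and sizes from the @vmt test above.
  const int64_t subgroupSize = 64;   // wave64 on gfx940
  const int64_t reductionTile = 512; // from tile_sizes [0, 0, 512]
  const int64_t parallelTile = 4;    // from tile_sizes [1, 4]
  const int64_t bitsPerF16 = 16;

  // Each lane handles a contiguous chunk of the reduction tile; this matches
  // the vector<4x8xf16> per-thread shape in the distribution test earlier.
  const int64_t elemsPerLane = reductionTile / subgroupSize;
  assert(elemsPerLane == 8);

  // 8 x f16 = 128 bits: one wide vector load per lane per row.
  assert(elemsPerLane * bitsPerF16 == 128);

  // 4096 reduction elements / 512 per tile = 8 scf.for iterations.
  assert(4096 / reductionTile == 8);

  // Each workgroup produces `parallelTile` = 4 output rows.
  assert(parallelTile == 4);
  return 0;
}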
