
Commit 0224591

remyoudompheng authored and orca-zhang committed
vulkan: implement several ops relevant for ggml_opt (ggml-org#11769)
* vulkan: support memset_tensor
* vulkan: support GGML_OP_SUM
* vulkan: implement GGML_OP_ARGMAX
* vulkan: implement GGML_OP_SUB
* vulkan: implement GGML_OP_COUNT_EQUAL
* vulkan: implement GGML_OP_OPT_STEP_ADAMW
* vulkan: fix check_results RWKV_WKV6 crash and memory leaks
* vulkan: implement GGML_OP_REPEAT_BACK
* tests: remove invalid test-backend-ops REPEAT_BACK tests
* vulkan: fix COUNT_EQUAL memset using a fillBuffer command
1 parent 0775552 commit 0224591

File tree

8 files changed: +568 -222 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

+366 -217
Large diffs are not rendered by default.
ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp

+51

@@ -0,0 +1,51 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

layout (constant_id = 0) const uint BLOCK_SIZE = 32;

shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
shared uint tmp[BLOCK_SIZE];

void main() {
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint col = gl_LocalInvocationID.x;

    if (col >= p.KX) {
        return;
    }
    A_TYPE amax = data_a[row*p.KX + col];
    tmp[col] = col;

    for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
        A_TYPE val = data_a[row*p.KX + i];
        if (val > amax) {
            amax = val;
            tmp[col] = i;
        }
    }
    tmpmax[col] = amax;

    barrier();
    [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
        if (col < s && col + s < p.KX) {
            if (tmpmax[col] < tmpmax[col + s]) {
                tmpmax[col] = tmpmax[col + s];
                tmp[col] = tmp[col + s];
            }
        }
        barrier();
    }

    if (col == 0) {
        data_d[row] = D_TYPE(tmp[0]);
    }
}
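Note: the shader above finds, for each row, the column index of the maximum element: each of the BLOCK_SIZE threads scans a strided slice of the row, then a shared-memory tree reduction merges the per-thread candidates. A minimal host-side C++ sketch of the same computation (a hypothetical reference helper, not part of this commit) that a checker could compare against:

#include <cstdint>
#include <vector>

// Row-wise argmax over an n_rows x n_cols row-major matrix:
// out[r] is the column index of the largest value in row r.
static std::vector<int32_t> argmax_rows(const std::vector<float> & a,
                                        size_t n_rows, size_t n_cols) {
    std::vector<int32_t> out(n_rows);
    for (size_t r = 0; r < n_rows; ++r) {
        size_t best = 0;
        for (size_t c = 1; c < n_cols; ++c) {
            if (a[r*n_cols + c] > a[r*n_cols + best]) {
                best = c;
            }
        }
        out[r] = (int32_t) best;
    }
    return out;
}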
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp

+31

@@ -0,0 +1,31 @@
#version 450

#extension GL_EXT_control_flow_attributes : enable

#include "types.comp"
#include "generic_head.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};

const uint CHUNK_SIZE = 512;

void main() {
    const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
    const uint col = gl_LocalInvocationID.x;

    uint count = 0;
    [[unroll]]
    for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
        const uint idx = base + i + col;
        if (idx >= p.KX) {
            break;
        }
        count += uint(data_a[idx] == data_b[idx]);
    }

    atomicAdd(data_d[0], D_TYPE(count));
}
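Note: one workgroup handles each CHUNK_SIZE (512-element) chunk, and the partial counts are merged with atomicAdd into data_d[0], so the destination buffer must be zeroed before dispatch (the fillBuffer fix in the commit message addresses exactly that). A hedged host-side C++ equivalent of GGML_OP_COUNT_EQUAL, stated only to pin down the semantics:

#include <cstdint>
#include <cstddef>

// Count the positions where two int32 arrays agree; the Vulkan shader
// computes the same total, one 512-element chunk per workgroup,
// accumulating partial counts into a zero-initialized output cell.
static int32_t count_equal(const int32_t * a, const int32_t * b, size_t n) {
    int32_t count = 0;
    for (size_t i = 0; i < n; ++i) {
        count += (a[i] == b[i]) ? 1 : 0;
    }
    return count;
}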
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp

+42

@@ -0,0 +1,42 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) buffer X {A_TYPE x[];};
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
layout (binding = 2) buffer GM {A_TYPE gradm[];};
layout (binding = 3) buffer GV {A_TYPE gradv[];};
layout (binding = 4) readonly buffer P {float params[7];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float alpha  = params[0];
    const float beta1  = params[1];
    const float beta2  = params[2];
    const float eps    = params[3];
    const float wd     = params[4];
    const float beta1h = params[5];
    const float beta2h = params[6];

    const float gi  = grad[i];
    const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
    const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);

    gradm[i] = gmi;
    gradv[i] = gvi;

    const float mh = gmi*beta1h;
    const float vh = sqrt(gvi*beta2h) + eps;

    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
}
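For orientation, this is the standard AdamW update: gmi and gvi are the new first and second moment estimates, and beta1h/beta2h carry precomputed bias-correction factors, so the GLSL differs from the textbook form below only in how those factors (and alpha) are folded together on the host side:

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, &
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2, \\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, &
\hat v_t &= \frac{v_t}{1-\beta_2^t}, \\
x_t &= x_{t-1}\,(1-\alpha\,w_d) - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}.
\end{aligned}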
ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp

+37

@@ -0,0 +1,37 @@
#version 450

#include "types.comp"
#include "generic_unary_head.comp"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
    const uint idx = get_idx();

    if (idx >= p.ne) {
        return;
    }

    // Destination multi-index (inlined dst_idx)
    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
    const uint i12_offset = i12*p.ne11*p.ne10;
    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
    const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;

    // Accumulate from sources
    A_TYPE acc = A_TYPE(0);
    for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
        for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
            for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
                for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
                    acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
                }
            }
        }
    }

    data_d[get_doffset() + d_idx] = D_TYPE(acc);
}
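Note: GGML_OP_REPEAT_BACK is the adjoint of GGML_OP_REPEAT: every element of the (larger) incoming gradient is summed back into the source position it was tiled from, which is what the four strided loops above compute per destination element. A minimal 1-D host sketch (a hypothetical helper, not from this commit):

#include <cstddef>
#include <vector>

// Adjoint of 1-D repeat: src was produced by tiling a dst_n-element
// tensor, so every src element folds back onto index (i mod dst_n).
static std::vector<float> repeat_back_1d(const std::vector<float> & src,
                                         size_t dst_n) {
    std::vector<float> dst(dst_n, 0.0f);
    for (size_t i = 0; i < src.size(); ++i) {
        dst[i % dst_n] += src[i];
    }
    return dst;
}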
ggml/src/ggml-vulkan/vulkan-shaders/sub.comp

+29

@@ -0,0 +1,29 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require

#include "types.comp"
#include "generic_binary_head.comp"

const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

void main() {
    uint idx = get_idx();

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
        if (idx >= p.ne) {
            continue;
        }
        uint i00, i01, i02, i03;
        get_indices(idx, i00, i01, i02, i03);

        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));

        idx += num_threads;
    }
}
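Note: the loop structure mirrors the other generic binary shaders: the dispatch granularity assumes 512 elements per workgroup (the wg_denoms the comment refers to), so the 256 threads each process num_iter = 2 elements, striding by num_threads. A scalar C++ analogue of one tile (hypothetical, and ignoring the broadcast index math that get_indices/src0_idx/src1_idx perform):

#include <cstddef>

// One 512-element tile of an element-wise subtraction, iterated the
// way the shader does: thread `tid` touches tile_base + tid and
// tile_base + tid + 256, with a bounds guard at the tensor's end.
static void sub_tile(const float * a, const float * b, float * d,
                     size_t tile_base, size_t n_elems) {
    const size_t num_threads = 256, num_iter = 2;    // 256 * 2 == 512
    for (size_t tid = 0; tid < num_threads; ++tid) { // simulated threads
        for (size_t it = 0; it < num_iter; ++it) {
            const size_t idx = tile_base + tid + it*num_threads;
            if (idx < n_elems) {
                d[idx] = a[idx] - b[idx];
            }
        }
    }
}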

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

+7
@@ -443,6 +443,8 @@ void process_shaders() {
     string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
 
+    string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
     string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
@@ -452,6 +454,7 @@
     string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
@@ -501,7 +504,9 @@
 
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 
+    string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
 
     string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
@@ -513,6 +518,8 @@
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }

tests/test-backend-ops.cpp

+5 -5
@@ -1254,7 +1254,7 @@ struct test_count_equal : public test_case {
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(b, "b");
 
-        ggml_tensor * b_argmax = ggml_argmax(ctx, a);
+        ggml_tensor * b_argmax = ggml_argmax(ctx, b);
         ggml_set_name(b_argmax, "b_argmax");
 
         ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
@@ -1511,6 +1511,7 @@ struct test_cont : public test_case {
 };
 
 // GGML_OP_ADD
+// GGML_OP_SUB
 // GGML_OP_MUL
 // GGML_OP_DIV
 struct test_bin_bcast : public test_case {
@@ -3860,7 +3861,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_count_equal());
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
@@ -3885,8 +3887,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
-        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
-        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
     }
 
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
@@ -3938,7 +3938,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
 
     auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
-        for (auto op : {ggml_add, ggml_mul, ggml_div}) {
+        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
             test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
         }
     };
