
Commit 89d1029

vulkan : add fp16 support for the conv_2d kernel (ggml-org#14872)
* add f16 to conv_2d testing
* weaken conv2d test error threshold
1 parent f1a4e72 commit 89d1029
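The change is limited to the kernel (weights) tensor of the direct CONV_2D op on the Vulkan backend: activations and output stay F32, and an F16 kernel is routed to a new conv2d_f16_f32 pipeline. A minimal sketch of the kind of graph this enables follows; it assumes the ggml_conv_2d_direct() builder that the direct CONV_2D op is normally created with, which is not part of this diff, so treat the builder name, argument order, and shapes as illustrative assumptions.

// Sketch only: a direct CONV_2D node with an F16 kernel and F32 activations,
// the combination this commit maps to the conv2d_f16_f32 pipeline.
// ggml_conv_2d_direct() and its argument order are assumed, not shown in this diff.
#include "ggml.h"

static ggml_tensor * build_f16_conv(ggml_context * ctx) {
    // kernel: KW=3, KH=3, Cin=16, Cout=32, stored as F16
    ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, 16, 32);
    // input: W=64, H=64, Cin=16, N=1, kept in F32
    ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 16, 1);

    // stride 1x1, padding 0, dilation 1 -- the defaults the tests below use
    return ggml_conv_2d_direct(ctx, kernel, input, 1, 1, 0, 0, 1, 1);
}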

3 files changed: +49 −20 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 19 additions & 5 deletions
@@ -484,6 +484,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
     vk_pipeline pipeline_conv2d_f32;
+    vk_pipeline pipeline_conv2d_f16_f32;
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
@@ -3074,12 +3075,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
             device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
             sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
             { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
+        ggml_vk_create_pipeline(
+            device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
+            sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+            { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
     } else {
         ggml_vk_create_pipeline(
             device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
             sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
             { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
             false);
+        ggml_vk_create_pipeline(
+            device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
+            sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+            { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
+            false);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -6958,9 +6968,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_CONV_2D:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
             ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            return ctx->device->pipeline_conv2d_f32;
+            if (src0->type == GGML_TYPE_F32) {
+                return ctx->device->pipeline_conv2d_f32;
+            } else if (src0->type == GGML_TYPE_F16) {
+                return ctx->device->pipeline_conv2d_f16_f32;
+            }
         }
         return nullptr;
     case GGML_OP_CONV_2D_DW:
@@ -8185,13 +8199,13 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
 
 static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
                             const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
     GGML_ASSERT(nb0 == sizeof(float));
 
@@ -10874,7 +10888,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             const vk_device& device = ggml_vk_get_device(ctx->device);
             bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
             // Channel-contiguous format is not supported yet.
-            return (op->src[0]->type == GGML_TYPE_F32 &&
+            return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                     op->src[1]->type == GGML_TYPE_F32 &&
                     op->type == GGML_TYPE_F32 &&
                     ggml_is_contiguous(op->src[0]) &&

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -656,6 +656,7 @@ void process_shaders() {
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
+    string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
 
     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

tests/test-backend-ops.cpp

Lines changed: 29 additions & 15 deletions
@@ -3734,6 +3734,7 @@ struct test_im2col : public test_case {
 struct test_conv_2d : public test_case {
     const std::array<int64_t, 4> ne_input;
     const std::array<int64_t, 4> ne_kernel;
+    const ggml_type type_kernel;
     const int stride0;
     const int stride1;
     const int padding0;
@@ -3751,7 +3752,11 @@ struct test_conv_2d : public test_case {
     // IM2COL -> MUL_MM graph will be built.
 
     std::string vars() override {
-        return VARS_TO_STR9(ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
+        return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
     }
 
     uint64_t op_flops(ggml_tensor * t) override {
@@ -3782,10 +3787,11 @@ struct test_conv_2d : public test_case {
     }
 
     test_conv_2d(std::array<int64_t, 4> ne_input = { 64, 64, 16, 1 },
-                 std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, int stride0 = 1, int stride1 = 1, int padding0 = 0,
-                 int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
+                 std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1,
+                 int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
         ne_input(ne_input),
         ne_kernel(ne_kernel),
+        type_kernel(type_kernel),
         stride0(stride0),
         stride1(stride1),
         padding0(padding0),
@@ -3798,7 +3804,7 @@ struct test_conv_2d : public test_case {
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
         ggml_set_name(input, "input");
 
-        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
         ggml_set_name(kernel, "kernel");
 
         if (cwhn) {
@@ -5165,10 +5171,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         { 16, 3, 256, 128, 8 }
     };
 
-    for (auto act_case : cases) {
-        test_cases.emplace_back(new test_conv_2d(
-            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
-            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (auto act_case : cases) {
+            test_cases.emplace_back(new test_conv_2d(
+                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
+                kernel_type, 1, 1, 0, 0, 1, 1, false));
+        }
     }
 #endif
 
@@ -5194,8 +5203,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 for (uint32_t W : { 1, 141 }) {
                     if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 &&
                         calc_conv_output_size(H, KH, s1, p1, d1) > 0) {
-                        test_cases.emplace_back(new test_conv_2d(
-                            { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false));
+                        for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                            test_cases.emplace_back(new test_conv_2d(
+                                { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, kernel_type, s0, s1, p0, p1, d0, d1, false));
+                        }
                     }
                 }
             }
@@ -5840,11 +5851,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         { 16, 3, 512, 128, 8 },
     };
 
-    for (auto act_case : cases) {
-        // Direct CONV_2D
-        test_cases.emplace_back(new test_conv_2d(
-            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
-            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (auto act_case : cases) {
+            // Direct CONV_2D
+            test_cases.emplace_back(new test_conv_2d(
+                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
+                kernel_type, 1, 1, 0, 0, 1, 1, false));
+        }
     }
 
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
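For orientation, a hedged illustration of the widened test constructor: the kernel type now sits between ne_kernel and stride0, so one geometry can be exercised with either an F32 or an F16 kernel. The shape below is a made-up example, not one of the cases added by this commit.

// Illustration only: explicit type_kernel argument in test_conv_2d.
// The shape values are hypothetical, chosen just to show the parameter order.
test_cases.emplace_back(new test_conv_2d(
    /*ne_input   =*/ { 64, 64, 16, 1 },
    /*ne_kernel  =*/ { 3, 3, 16, 32 },
    /*type_kernel=*/ GGML_TYPE_F16,
    /*stride0=*/ 1, /*stride1=*/ 1,
    /*padding0=*/ 0, /*padding1=*/ 0,
    /*dilation0=*/ 1, /*dilation1=*/ 1,
    /*cwhn=*/ false));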
