Skip to content

Commit 3d506e5

Browse files
vulkan: fix COUNT_EQUAL memset using a fillBuffer command
1 parent 941efc0 commit 3d506e5

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

+11 -1
Original file line number | Diff line number | Diff line change
@@ -3730,6 +3730,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
37303730
}
37313731
}
37323732

3733+
static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
3734+
VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
3735+
3736+
ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
3737+
}
3738+
37333739
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
37343740
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
37353741

@@ -5717,6 +5723,11 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
57175723
// im2col uses only src1 and dst buffers
57185724
ggml_vk_sync_buffers(subctx);
57195725
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
5726+
} else if (op == GGML_OP_COUNT_EQUAL) {
5727+
ggml_vk_sync_buffers(subctx);
5728+
// count_equal assumes that destination buffer is initialized with zeroes
5729+
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
5730+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
57205731
} else if (use_src2) {
57215732
ggml_vk_sync_buffers(subctx);
57225733
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@@ -6331,7 +6342,6 @@ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, co
63316342
}
63326343

63336344
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6334-
ggml_backend_tensor_memset(dst, 0, 0, ggml_nbytes(dst));
63356345
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
63366346
}
63376347

0 commit comments

Comments
 (0)