Skip to content

Commit e172d2d

Browse files
vulkan: fix COUNT_EQUAL memset using a fillBuffer command
1 parent e22079c commit e172d2d

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -3796,6 +3796,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
37963796
}
37973797
}
37983798

3799+
static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
3800+
VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
3801+
3802+
ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
3803+
}
3804+
37993805
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
38003806
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
38013807

@@ -5800,6 +5806,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
58005806
// im2col uses only src1 and dst buffers
58015807
ggml_vk_sync_buffers(subctx);
58025808
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
5809+
} else if (op == GGML_OP_COUNT_EQUAL) {
5810+
ggml_vk_sync_buffers(subctx);
5811+
// count_equal assumes that destination buffer is initialized with zeroes
5812+
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
5813+
ggml_vk_sync_buffers(subctx);
5814+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
58035815
} else if (use_src2) {
58045816
ggml_vk_sync_buffers(subctx);
58055817
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@@ -6422,7 +6434,6 @@ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, co
64226434
}
64236435

64246436
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6425-
ggml_backend_tensor_memset(dst, 0, 0, ggml_nbytes(dst));
64266437
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
64276438
}
64286439

0 commit comments

Comments
 (0)