Skip to content

Commit 84712b6

Browse files
authored
vulkan: fix rms_norm_mul to handle broadcasting dim0 (ggml-org#14817)
1 parent d4d1522 commit 84712b6

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10248,7 +10248,7 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
1024810248
}
1024910249
// if rms_norm is the B operand, then we don't handle broadcast
1025010250
if (rms_norm == mul->src[1] &&
10251-
mul->src[0]->ne[1] != rms_norm->ne[1]) {
10251+
!ggml_are_same_shape(mul->src[0], rms_norm)) {
1025210252
return false;
1025310253
}
1025410254
// rms_norm shader assumes contiguous rows

ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,14 @@ void main() {
5050
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
5151

5252
if (do_multiply) {
53-
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
54-
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
53+
if (ncols > p.ne10) {
54+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
55+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
56+
}
57+
} else {
58+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
59+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
60+
}
5561
}
5662
} else {
5763
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {

0 commit comments

Comments
 (0)