Commit 2cf6f8b

CPU/CUDA: fix GQA mul mat back, add CUDA support

1 parent 564804b commit 2cf6f8b

6 files changed: +167 -64 lines changed

ggml/include/ggml.h (+2 -1)

@@ -933,7 +933,8 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            bool                  adjacent); // sum up values that are adjacent in dims > 0 instead of repeated with same stride

     // concat a and b along dim
     // used in stable-diffusion
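
Note on the new flag: ggml_repeat tiles b across a with a fixed stride, so the
original reduction in ggml_repeat_back sums source elements that are a whole
dst-extent apart in each dim. The mul mat broadcast used for grouped-query
attention instead maps each block of adjacent src1 slices onto one src0 slice,
so its backward pass needs the adjacent grouping. A toy illustration of the two
groupings, collapsing 4 rows into 2 (not code from this commit):

#include <stdio.h>

int main(void) {
    const int ne1  = 2;          // dst rows
    const int ne01 = 4;          // src rows
    const int nr1  = ne01/ne1;   // repetitions per dst row
    const float src[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float strided[2]  = {0.0f, 0.0f};
    float adjacent[2] = {0.0f, 0.0f};

    for (int k1 = 0; k1 < ne1; k1++) {
        for (int i1 = 0; i1 < nr1; i1++) {
            strided[k1]  += src[i1*ne1 + k1]; // dst row 0 <- src rows {0,2}, dst row 1 <- {1,3}
            adjacent[k1] += src[k1*nr1 + i1]; // dst row 0 <- src rows {0,1}, dst row 1 <- {2,3}
        }
    }
    printf("strided:  %g %g\n", strided[0],  strided[1]);  // 4 6
    printf("adjacent: %g %g\n", adjacent[0], adjacent[1]); // 3 7
    return 0;
}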

ggml/src/ggml-cpu/ggml-cpu.c (+40 -18)

@@ -6046,6 +6046,8 @@ static void ggml_compute_forward_repeat_back_f32(
     GGML_ASSERT(nb0  == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));

+    const bool adjacent = dst->op_params[0] != 0;
+
     if (ggml_is_contiguous(dst)) {
         ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
     } else {
@@ -6060,22 +6062,42 @@
         }
     }

-    // TODO: maybe this is not optimal?
-    for (int i3 = 0; i3 < nr3; i3++) {
-        for (int k3 = 0; k3 < ne3; k3++) {
-            for (int i2 = 0; i2 < nr2; i2++) {
-                for (int k2 = 0; k2 < ne2; k2++) {
-                    for (int i1 = 0; i1 < nr1; i1++) {
-                        for (int k1 = 0; k1 < ne1; k1++) {
-                            for (int i0 = 0; i0 < nr0; i0++) {
-                                ggml_vec_acc_f32(ne0,
-                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
-                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
-                            }
-                        }
-                    }
-                }
-            }
+    if (adjacent) {
+        for (int i3 = 0; i3 < nr3; i3++) {
+            for (int k3 = 0; k3 < ne3; k3++) {
+                for (int i2 = 0; i2 < nr2; i2++) {
+                    for (int k2 = 0; k2 < ne2; k2++) {
+                        for (int i1 = 0; i1 < nr1; i1++) {
+                            for (int k1 = 0; k1 < ne1; k1++) {
+                                for (int i0 = 0; i0 < nr0; i0++) {
+                                    ggml_vec_acc_f32(ne0,
+                                            (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                            (float *) ((char *) src0->data + (k3*nr3 + i3)*nb03 + (k2*nr2 + i2)*nb02 + (k1*nr1 + i1)*nb01 + (i0*ne0)*nb00));
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        // TODO: maybe this is not optimal?
+        for (int i3 = 0; i3 < nr3; i3++) {
+            for (int k3 = 0; k3 < ne3; k3++) {
+                for (int i2 = 0; i2 < nr2; i2++) {
+                    for (int k2 = 0; k2 < ne2; k2++) {
+                        for (int i1 = 0; i1 < nr1; i1++) {
+                            for (int k1 = 0; k1 < ne1; k1++) {
+                                for (int i0 = 0; i0 < nr0; i0++) {
+                                    ggml_vec_acc_f32(ne0,
+                                            (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                            (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                                }
+                            }
+                        }
+                    }
+                }
+            }
         }
     }
 }
@@ -7883,7 +7905,7 @@ static void ggml_compute_forward_out_prod_f32(

                 float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
                 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));

                 ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
             }
@@ -7892,7 +7914,7 @@ static void ggml_compute_forward_out_prod_f32(

                 float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
                 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));

                 ggml_vec_mad_f32(ne0, d, s0, *s1);
             }
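
The adjacent branch above differs from the strided one only in the src0 slice
index: (k3*nr3 + i3)*nb03 reads the i3-th of the nr3 repeats stored contiguously
behind output slice k3, while the strided branch reads slices ne3 apart (and
likewise for dims 1 and 2). A hedged end-to-end check of both modes through the
new API; this sketch is not part of the commit and assumes the CPU backend
header for ggml_graph_compute_with_ctx:

#include <stdio.h>
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params params = {16*1024*1024, NULL, false};
    struct ggml_context * ctx = ggml_init(params);

    // a: 4 rows of length 1; b: target shape with 2 rows
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 4);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, 2);
    for (int i = 0; i < 4; i++) {
        ggml_set_f32_1d(a, i, (float)(i + 1)); // a = {1, 2, 3, 4}
    }

    struct ggml_tensor * strided  = ggml_repeat_back(ctx, a, b, false);
    struct ggml_tensor * adjacent = ggml_repeat_back(ctx, a, b, true);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, strided);
    ggml_build_forward_expand(gf, adjacent);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // expected: strided sums rows {0,2} and {1,3}, adjacent sums {0,1} and {2,3}
    printf("strided:  %g %g\n", ggml_get_f32_1d(strided,  0), ggml_get_f32_1d(strided,  1)); // 4 6
    printf("adjacent: %g %g\n", ggml_get_f32_1d(adjacent, 0), ggml_get_f32_1d(adjacent, 1)); // 3 7

    ggml_free(ctx);
    return 0;
}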

ggml/src/ggml-cpu/ggml-cpu.cpp (+1 -1)

@@ -416,7 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         default:
             return true;
     }

ggml/src/ggml-cuda/binbcast.cu (+27 -8)

@@ -91,11 +91,14 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }

-template <typename T>
+template <bool adjacent, typename T>
 static __global__ void k_repeat_back(
     const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
     const int64_t ne0, const int64_t ne1, const int64_t ne2) {

+    const int64_t nr1 = ne01 / ne1;
+    const int64_t nr2 = ne02 / ne2;
+
     const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
     const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
     const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
@@ -105,10 +108,20 @@ static __global__ void k_repeat_back(
     }

     T sum = 0;
-    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+    if (adjacent) {
+        for (int64_t i2 = tid2*nr2; i2 < (tid2 + 1)*nr2; ++i2) {
+            for (int64_t i1 = tid1*nr1; i1 < (tid1 + 1)*nr1; ++i1) {
+                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                    sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+                }
+            }
+        }
+    } else {
+        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                    sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+                }
             }
         }
     }
@@ -275,11 +288,15 @@ struct bin_bcast_cuda {
 template <typename T>
 static void repeat_back_cuda(
     const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const bool adjacent, cudaStream_t stream) {

     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
-    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    if (adjacent) {
+        k_repeat_back<true, T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    } else {
+        k_repeat_back<false, T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    }
 }

 template<class op>
@@ -342,11 +359,13 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int64_t ne2 = dst->ne[2];
     GGML_ASSERT(dst->ne[3] == 1);

+    const bool adjacent = dst->op_params[0] != 0;
+
     switch (dst->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             float * dst_d = (float *) dst->data;
-            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, adjacent, stream);
         } break;
         default: {
             GGML_ASSERT(false);
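
The kernel takes the mode as a template parameter rather than a runtime
argument, so each instantiation compiles only one branch and the per-element
loops stay free of the flag check; each thread owns exactly one dst element and
reduces over all of its repeats, so no atomics are needed. The same dispatch
pattern in a standalone, compilable form (an illustration, not the commit's
code):

#include <cstdio>
#include <cuda_runtime.h>

template <bool adjacent>
static __global__ void k_sum_repeats(const float * src, float * dst, const int n, const int nr) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    float sum = 0.0f;
    if (adjacent) { // dead branch is compiled out per instantiation
        for (int r = 0; r < nr; ++r) {
            sum += src[i*nr + r]; // repeats of element i are contiguous
        }
    } else {
        for (int r = 0; r < nr; ++r) {
            sum += src[r*n + i]; // repeats of element i are n apart
        }
    }
    dst[i] = sum;
}

static void sum_repeats(const float * src, float * dst, const int n, const int nr,
                        const bool adjacent, cudaStream_t stream) {
    const dim3 block_dims(256, 1, 1);
    const dim3 block_nums((n + 255)/256, 1, 1);
    if (adjacent) {
        k_sum_repeats<true><<<block_nums, block_dims, 0, stream>>>(src, dst, n, nr);
    } else {
        k_sum_repeats<false><<<block_nums, block_dims, 0, stream>>>(src, dst, n, nr);
    }
}

int main() {
    const int n = 2, nr = 2;
    const float h_src[n*nr] = {1.0f, 2.0f, 3.0f, 4.0f};
    float h_dst[n];
    float * d_src;
    float * d_dst;
    cudaMalloc(&d_src, sizeof(h_src));
    cudaMalloc(&d_dst, sizeof(h_dst));
    cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);

    sum_repeats(d_src, d_dst, n, nr, /*adjacent =*/ true, 0);
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
    printf("adjacent: %g %g\n", h_dst[0], h_dst[1]); // 3 7

    sum_repeats(d_src, d_dst, n, nr, /*adjacent =*/ false, 0);
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
    printf("strided:  %g %g\n", h_dst[0], h_dst[1]); // 4 6

    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}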

ggml/src/ggml.c (+19 -17)

@@ -2305,14 +2305,17 @@ struct ggml_tensor * ggml_repeat(
 struct ggml_tensor * ggml_repeat_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
+        struct ggml_tensor  * b,
+        bool                  adjacent) {
     GGML_ASSERT(ggml_can_repeat(b, a));

     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);

     result->op     = GGML_OP_REPEAT_BACK;
     result->src[0] = a;

+    result->op_params[0] = adjacent ? 1 : 0;
+
     return result;
 }

@@ -5299,7 +5302,7 @@ static void ggml_compute_backward(
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = grad;
                 if (!ggml_are_same_shape(src0, src1)) {
-                    tmp = ggml_repeat_back(ctx, tmp, src1);
+                    tmp = ggml_repeat_back(ctx, tmp, src1, false);
                 }
                 ggml_add_or_set(ctx, cgraph, isrc1, tmp);
             }
@@ -5339,12 +5342,12 @@
         } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
                 if (!ggml_are_same_shape(src0, src1)) {
-                    tmp = ggml_repeat_back(ctx, tmp, src1);
+                    tmp = ggml_repeat_back(ctx, tmp, src1, false);
                 }
                 ggml_add_or_set(ctx, cgraph, isrc1, tmp);
             }
@@ -5399,7 +5402,7 @@
         } break;
         case GGML_OP_REPEAT: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0, false));
             }
         } break;
         case GGML_OP_REPEAT_BACK: {
@@ -5431,21 +5434,18 @@
            // src1.shape  [n,p,qq,rr]

            if (src0_needs_grads) {
-                struct ggml_tensor * s1_tg =
+                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+                struct ggml_tensor * tmp =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                const int64_t qq = s1_tg->ne[2];
-                const int64_t rr = s1_tg->ne[3];
-                const int64_t q1 = src0->ne[2];
-                const int64_t r1 = src0->ne[3];
-                const bool ne2_broadcasted = qq > q1;
-                const bool ne3_broadcasted = rr > r1;
-                if (ne2_broadcasted || ne3_broadcasted) {
-                    // sum broadcast repetitions of s1_tg into shape of src0
-                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                if (!ggml_are_same_shape(tmp, src0)) {
+                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+                    tmp = ggml_repeat_back(ctx, tmp, src0, true);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5514,7 +5514,9 @@
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
             }
         } break;
         case GGML_OP_RESHAPE: {
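
Why the mul mat gradient needs the adjacent mode: ggml_mul_mat broadcasts dims
2 and 3 by integer division, so with r2 = ne12/ne02 (n_head/n_head_kv under
GQA) a block of r2 adjacent src1 slices all read the same src0 slice, and the
gradient w.r.t. src0 must sum over exactly that block; the strided
ggml_repeat_back used before this commit grouped the wrong slices together. A
small demo of the mapping (illustrative, not from the commit):

#include <stdio.h>

int main(void) {
    const int ne02 = 2;        // src0 slices along dim 2, e.g. n_head_kv
    const int ne12 = 6;        // src1 slices along dim 2, e.g. n_head
    const int r2   = ne12/ne02;
    for (int i12 = 0; i12 < ne12; i12++) {
        // the broadcast index mul_mat uses in the forward pass
        printf("src1 slice %d reads src0 slice %d\n", i12, i12/r2);
    }
    // slices {0,1,2} map to 0 and {3,4,5} map to 1: adjacent groups
    return 0;
}

The ggml_mul(grad, src1) swap in the MUL case follows the same convention:
ggml_mul broadcasts its second argument into its first, so the result now comes
out with grad's (= src0's) shape even when src1 is smaller. The grad paths can
be exercised with test-backend-ops in grad mode (something like
test-backend-ops grad -o MUL_MAT; exact flag spelling assumed).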
