Skip to content

Commit 9c8dcef

Browse files
CUDA: backwards pass for misc. ops, add tests (#11257)
* CUDA: backwards pass for misc. ops, add tests
* remove restrict from pointers
1 parent 681149c commit 9c8dcef

18 files changed

+934
-336
lines changed

ggml/include/ggml.h

+8-4
Original file line number | Diff line number | Diff line change
@@ -1384,16 +1384,20 @@ extern "C" {
13841384
float scale,
13851385
float max_bias);
13861386

1387-
GGML_API struct ggml_tensor * ggml_soft_max_back(
1387+
GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
13881388
struct ggml_context * ctx,
13891389
struct ggml_tensor * a,
1390-
struct ggml_tensor * b);
1390+
struct ggml_tensor * b,
1391+
float scale,
1392+
float max_bias);
13911393

13921394
// in-place, returns view(a)
1393-
GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
1395+
GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
13941396
struct ggml_context * ctx,
13951397
struct ggml_tensor * a,
1396-
struct ggml_tensor * b);
1398+
struct ggml_tensor * b,
1399+
float scale,
1400+
float max_bias);
13971401

13981402
// rotary position embedding
13991403
// if (mode & 1) - skip n_past elements (NOT SUPPORTED)

ggml/src/ggml-alloc.c

+5
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
3737
return true;
3838
}
3939

40+
// ops that return true for this function must not use restrict pointers for their backend implementations
4041
static bool ggml_op_can_inplace(enum ggml_op op) {
4142
switch (op) {
4243
case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
5253
case GGML_OP_LOG:
5354
case GGML_OP_UNARY:
5455
case GGML_OP_ROPE:
56+
case GGML_OP_ROPE_BACK:
57+
case GGML_OP_SILU_BACK:
5558
case GGML_OP_RMS_NORM:
59+
case GGML_OP_RMS_NORM_BACK:
5660
case GGML_OP_SOFT_MAX:
61+
case GGML_OP_SOFT_MAX_BACK:
5762
return true;
5863

5964
default:

ggml/src/ggml-cpu/ggml-cpu.c

+75-58
Original file line number | Diff line number | Diff line change
@@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32(
66916691
const struct ggml_compute_params * params,
66926692
struct ggml_tensor * dst) {
66936693

6694-
const struct ggml_tensor * src0 = dst->src[0];
6695-
const struct ggml_tensor * grad = dst->src[1];
6694+
const struct ggml_tensor * grad = dst->src[0];
6695+
const struct ggml_tensor * src1 = dst->src[1];
66966696

66976697
assert(ggml_is_contiguous_1(grad));
6698-
assert(ggml_is_contiguous_1(src0));
6698+
assert(ggml_is_contiguous_1(src1));
66996699
assert(ggml_is_contiguous_1(dst));
6700-
assert(ggml_are_same_shape(src0, dst));
6701-
assert(ggml_are_same_shape(src0, grad));
6700+
assert(ggml_are_same_shape(src1, dst));
6701+
assert(ggml_are_same_shape(src1, grad));
67026702

67036703
const int ith = params->ith;
67046704
const int nth = params->nth;
67056705

6706-
const int nc = src0->ne[0];
6707-
const int nr = ggml_nrows(src0);
6706+
const int nc = src1->ne[0];
6707+
const int nr = ggml_nrows(src1);
67086708

67096709
// rows per thread
67106710
const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32(
67166716
for (int i1 = ir0; i1 < ir1; i1++) {
67176717
ggml_vec_silu_backward_f32(nc,
67186718
(float *) ((char *) dst->data + i1*( dst->nb[1])),
6719-
(float *) ((char *) src0->data + i1*(src0->nb[1])),
6719+
(float *) ((char *) src1->data + i1*(src1->nb[1])),
67206720
(float *) ((char *) grad->data + i1*(grad->nb[1])));
67216721

67226722
#ifndef NDEBUG
@@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32(
68956895
float eps;
68966896
memcpy(&eps, dst->op_params, sizeof(float));
68976897

6898-
GGML_ASSERT(eps > 0.0f);
6898+
GGML_ASSERT(eps >= 0.0f);
68996899

69006900
// TODO: optimize
69016901
for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32(
69666966
float eps;
69676967
memcpy(&eps, dst->op_params, sizeof(float));
69686968

6969-
GGML_ASSERT(eps > 0.0f);
6969+
GGML_ASSERT(eps >= 0.0f);
69706970

69716971
// TODO: optimize
69726972
for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
70187018
const struct ggml_compute_params * params,
70197019
struct ggml_tensor * dst) {
70207020

7021-
const struct ggml_tensor * src0 = dst->src[0];
7022-
const struct ggml_tensor * src1 = dst->src[1];
7021+
const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
7022+
const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
70237023

70247024
GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
70257025

70267026
GGML_ASSERT(src0->nb[0] == sizeof(float));
7027+
GGML_ASSERT(src1->nb[0] == sizeof(float));
70277028

70287029
const int ith = params->ith;
70297030
const int nth = params->nth;
@@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
70427043
const int64_t i12 = i02;
70437044
const int64_t i13 = i03;
70447045

7045-
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7046-
const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
7046+
const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7047+
const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
70477048

70487049
ggml_float sum_xx = 0.0;
70497050
ggml_float sum_xdz = 0.0;
@@ -7066,23 +7067,23 @@ static void ggml_compute_forward_rms_norm_back_f32(
70667067
{
70677068
// z = rms_norm(x)
70687069
//
7069-
// rms_norm(src0) =
7070+
// rms_norm(src1) =
70707071
// scale(
7071-
// src0,
7072+
// src1,
70727073
// div(
70737074
// 1,
70747075
// sqrt(
70757076
// add(
70767077
// scale(
70777078
// sum(
70787079
// sqr(
7079-
// src0)),
7080+
// src1)),
70807081
// (1.0/N)),
70817082
// eps))));
70827083

70837084
// postorder:
70847085
// ## op args grad
7085-
// 00 param src0 grad[#00]
7086+
// 00 param src1 grad[#00]
70867087
// 01 const 1
70877088
// 02 sqr (#00) grad[#02]
70887089
// 03 sum (#02) grad[#03]
@@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
71597160
// dx := scale(dx, rrms)
71607161
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
71617162

7163+
// dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
71627164
ggml_vec_cpy_f32 (ne00, dx, x);
71637165
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
71647166
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32(
77507752
const int ith = params->ith;
77517753
const int nth = params->nth;
77527754

7753-
GGML_ASSERT(ne0 == ne00);
7754-
GGML_ASSERT(ne1 == ne10);
7755-
GGML_ASSERT(ne2 == ne02);
7756-
GGML_ASSERT(ne02 == ne12);
7757-
GGML_ASSERT(ne3 == ne13);
7758-
GGML_ASSERT(ne03 == ne13);
7755+
GGML_ASSERT(ne0 == ne00);
7756+
GGML_ASSERT(ne1 == ne10);
7757+
GGML_ASSERT(ne2 == ne12);
7758+
GGML_ASSERT(ne3 == ne13);
7759+
7760+
GGML_ASSERT(ne2 % ne02 == 0);
7761+
GGML_ASSERT(ne3 % ne03 == 0);
77597762

77607763
// we don't support permuted src0 or src1
77617764
GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32(
77977800
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
77987801
const int64_t blck_1 = 16;
77997802

7803+
// dps == dst per src0, used for group query attention
7804+
const int64_t dps2 = ne2 / ne02;
7805+
const int64_t dps3 = ne3 / ne03;
7806+
78007807
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
78017808
const int64_t bir1 = MIN(bir + blck_1, ir1);
78027809
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32(
78077814
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
78087815
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
78097816

7810-
const int64_t i02 = i2;
7811-
const int64_t i03 = i3;
7817+
const int64_t i02 = i2 / dps2;
7818+
const int64_t i03 = i3 / dps3;
78127819

78137820
//const int64_t i10 = i1;
78147821
const int64_t i12 = i2;
@@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max(
89068913
}
89078914

89088915

8909-
// ggml_compute_forward_soft_max_back
8916+
// ggml_compute_forward_soft_max_ext_back
89108917

8911-
static void ggml_compute_forward_soft_max_back_f32(
8918+
static void ggml_compute_forward_soft_max_ext_back_f32(
89128919
const struct ggml_compute_params * params,
89138920
struct ggml_tensor * dst) {
89148921

@@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32(
89218928
GGML_ASSERT(ggml_are_same_shape(src0, dst));
89228929
GGML_ASSERT(ggml_are_same_shape(src1, dst));
89238930

8931+
float scale = 1.0f;
8932+
float max_bias = 0.0f;
8933+
8934+
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8935+
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8936+
8937+
GGML_ASSERT(max_bias == 0.0f);
8938+
89248939
// TODO: handle transposed/permuted matrices
89258940

89268941
const int ith = params->ith;
@@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32(
89698984

89708985
// linear runtime, no additional memory
89718986
float dot_y_dy = 0;
8972-
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
8973-
ggml_vec_cpy_f32 (nc, dx, dy);
8974-
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
8975-
ggml_vec_mul_f32 (nc, dx, dx, y);
8987+
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
8988+
ggml_vec_cpy_f32 (nc, dx, dy);
8989+
ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
8990+
ggml_vec_mul_f32 (nc, dx, dx, y);
8991+
ggml_vec_scale_f32(nc, dx, scale);
89768992

89778993
#ifndef NDEBUG
89788994
for (int i = 0; i < nc; ++i) {
@@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back_f32(
89838999
}
89849000
}
89859001

8986-
static void ggml_compute_forward_soft_max_back(
9002+
static void ggml_compute_forward_soft_max_ext_back(
89879003
const struct ggml_compute_params * params,
89889004
struct ggml_tensor * dst) {
89899005

@@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back(
89929008
switch (src0->type) {
89939009
case GGML_TYPE_F32:
89949010
{
8995-
ggml_compute_forward_soft_max_back_f32(params, dst);
9011+
ggml_compute_forward_soft_max_ext_back_f32(params, dst);
89969012
} break;
89979013
default:
89989014
{
@@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32(
998510001
const struct ggml_compute_params * params,
998610002
struct ggml_tensor * dst) {
998710003

9988-
const struct ggml_tensor * src0 = dst->src[0];
9989-
const struct ggml_tensor * src1 = dst->src[1];
10004+
const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
10005+
const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
999010006

10007+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
999110008
GGML_ASSERT(src1->type == GGML_TYPE_F32);
999210009
GGML_ASSERT( dst->type == GGML_TYPE_F32);
999310010

@@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32(
1000910026
const int64_t IH = is_2D ? ne1 : 1;
1001010027
const int64_t IW = ne0;
1001110028

10012-
const int64_t KH = is_2D ? ne01 : 1;
10013-
const int64_t KW = ne00;
10029+
const int64_t KH = is_2D ? ne11 : 1;
10030+
const int64_t KW = ne10;
1001410031

10015-
const int64_t OH = is_2D ? ne12 : 1;
10016-
const int64_t OW = ne11;
10032+
const int64_t OH = is_2D ? ne02 : 1;
10033+
const int64_t OW = ne01;
1001710034

1001810035
int ofs0 = is_2D ? nb3 : nb2;
1001910036
int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32(
1005910076
continue;
1006010077
}
1006110078

10062-
const float * const src_data = (const float *) src1->data
10079+
const float * const grad_in = (const float *) src0->data
1006310080
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
10064-
grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
10081+
grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
1006510082
}
1006610083
}
1006710084
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
1248412501
const struct ggml_compute_params * params,
1248512502
struct ggml_tensor * dst) {
1248612503

12487-
const struct ggml_tensor * src0 = dst->src[0];
12488-
const struct ggml_tensor * src1 = dst->src[1];
12489-
const struct ggml_tensor * opt0 = dst->src[2];
12504+
const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
12505+
const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
12506+
const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
1249012507

1249112508
GGML_ASSERT(ggml_is_contiguous(dst));
12492-
GGML_ASSERT(ggml_is_contiguous(src0));
12493-
GGML_ASSERT(ggml_is_contiguous(src1));
12494-
GGML_ASSERT(ggml_is_contiguous(opt0));
12495-
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
12509+
GGML_ASSERT(ggml_is_contiguous(src0f));
12510+
GGML_ASSERT(ggml_is_contiguous(src1f));
12511+
GGML_ASSERT(ggml_is_contiguous(grad));
12512+
GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
1249612513

1249712514
const int64_t ith = params->ith;
1249812515
const int64_t nth = params->nth;
1249912516

1250012517
// TODO: handle transposed/permuted matrices
12501-
const int64_t nc = src0->ne[0];
12502-
const int64_t nr = ggml_nrows(src0);
12518+
const int64_t nc = src0f->ne[0];
12519+
const int64_t nr = ggml_nrows(src0f);
1250312520

1250412521
// rows per thread
1250512522
const int64_t dr = (nr + nth - 1)/nth;
@@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
1250812525
const int64_t ir0 = dr*ith;
1250912526
const int64_t ir1 = MIN(ir0 + dr, nr);
1251012527

12511-
const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
12528+
const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
1251212529

1251312530
for (int64_t i1 = ir0; i1 < ir1; i1++) {
12514-
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
12515-
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
12516-
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
12531+
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
12532+
const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
12533+
const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
1251712534

1251812535
#ifndef NDEBUG
1251912536
for (int64_t i = 0; i < nc; ++i) {
@@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
1252612543
// soft_max
1252712544
float max = -INFINITY;
1252812545
ggml_vec_max_f32(nc, &max, s0);
12529-
ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
12546+
const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
1253012547
assert(sum > 0.0);
1253112548
ggml_vec_scale_f32(nc, ds0, 1.0/sum);
1253212549

12533-
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
12550+
// grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
1253412551
ggml_vec_sub_f32(nc, ds0, ds0, s1);
1253512552
ggml_vec_scale_f32(nc, ds0, d_by_nr);
1253612553

@@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1282712844
} break;
1282812845
case GGML_OP_SOFT_MAX_BACK:
1282912846
{
12830-
ggml_compute_forward_soft_max_back(params, tensor);
12847+
ggml_compute_forward_soft_max_ext_back(params, tensor);
1283112848
} break;
1283212849
case GGML_OP_ROPE:
1283312850
{

ggml/src/ggml-cpu/ggml-cpu.cpp

+10
Original file line number | Diff line number | Diff line change
@@ -403,6 +403,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
403403
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
404404
case GGML_OP_MUL_MAT:
405405
return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
406+
case GGML_OP_SOFT_MAX_BACK: {
407+
if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
408+
return false;
409+
}
410+
float max_bias = 0.0f;
411+
412+
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
413+
414+
return max_bias == 0.0f;
415+
}
406416
case GGML_OP_IM2COL_BACK:
407417
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
408418
case GGML_OP_OUT_PROD:

0 commit comments

Comments (0)