@@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * grad = dst->src[1];
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
 
     assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-    assert(ggml_are_same_shape(src0, grad));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_silu_backward_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
                 (float *) ((char *) grad->data + i1*(grad->nb[1])));
 
 #ifndef NDEBUG
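
Note: the hunk above only swaps which operand holds the incoming gradient and which holds the forward input; the arithmetic inside ggml_vec_silu_backward_f32 is unchanged. As a hedged scalar sketch of the rule it applies (a reference loop, not ggml's vectorized code; silu_backward_ref is a hypothetical name):

    #include <math.h>

    // For y = silu(x) = x*sigmoid(x), dy/dx = s*(1 + x*(1 - s)) with s = sigmoid(x),
    // so the input gradient is dx = grad * dy/dx.
    static void silu_backward_ref(int n, float * dx, const float * x, const float * grad) {
        for (int i = 0; i < n; ++i) {
            const float s = 1.0f/(1.0f + expf(-x[i]));
            dx[i] = grad[i]*s*(1.0f + x[i]*(1.0f - s));
        }
    }
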
@@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 const int64_t i12 = i02;
                 const int64_t i13 = i03;
 
-                const float * x  = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 
                 ggml_float sum_xx  = 0.0;
                 ggml_float sum_xdz = 0.0;
@@ -7066,23 +7067,23 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 {
                     // z = rms_norm(x)
                     //
-                    // rms_norm(src0) =
+                    // rms_norm(src1) =
                     //     scale(
-                    //         src0,
+                    //         src1,
                     //         div(
                     //             1,
                     //             sqrt(
                     //                 add(
                     //                     scale(
                     //                         sum(
                     //                             sqr(
-                    //                                 src0)),
+                    //                                 src1)),
                     //                         (1.0/N)),
                     //                     eps))));
 
                     // postorder:
                     // ## op      args         grad
-                    // 00 param   src0         grad[#00]
+                    // 00 param   src1         grad[#00]
                     // 01 const   1
                     // 02 sqr     (#00)        grad[#02]
                     // 03 sum     (#02)        grad[#03]
@@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 // dx := scale(dx, rrms)
                 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
                 ggml_vec_cpy_f32  (ne00, dx, x);
                 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
                 ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
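
Note: putting the vector ops together, the row update matches the new inline comment. A hedged self-contained reference (rms_norm_back_row_ref is a hypothetical name; double accumulators stand in for ggml_float):

    #include <math.h>
    #include <stdint.h>

    // dx[i] = (dz[i] - x[i]*sum_xdz/sum_eps) / sqrt(mean_eps), where
    // sum_xx = sum(x*x), sum_xdz = sum(x*dz),
    // sum_eps = sum_xx + n*eps, mean_eps = sum_xx/n + eps.
    static void rms_norm_back_row_ref(int64_t n, float * dx,
            const float * x, const float * dz, float eps) {
        double sum_xx = 0.0, sum_xdz = 0.0;
        for (int64_t i = 0; i < n; ++i) {
            sum_xx  += (double) x[i]*x[i];
            sum_xdz += (double) x[i]*dz[i];
        }
        const double sum_eps  = sum_xx + (double) n*eps;
        const double mean_eps = sum_xx/(double) n + eps;
        for (int64_t i = 0; i < n; ++i) {
            dx[i] = (float) ((dz[i] - x[i]*sum_xdz/sum_eps) / sqrt(mean_eps));
        }
    }
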
@@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
     const int64_t blck_1 = 16;
 
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
     for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
         const int64_t bir1 = MIN(bir + blck_1, ir1);
         for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32(
                 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
                 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-                const int64_t i02 = i2;
-                const int64_t i03 = i3;
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
 
                 //const int64_t i10 = i1;
                 const int64_t i12 = i2;
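
Note: the relaxed asserts plus the dps2/dps3 division broadcast src0 over the dst head dimension, which is what grouped-query attention needs. A hedged example with illustrative sizes (not taken from the patch; gqa_broadcast_example is a hypothetical name):

    #include <assert.h>
    #include <stdint.h>

    // 32 query heads sharing 8 KV heads => dps2 = 4, so dst heads 0..3 read
    // src0 head 0, heads 4..7 read head 1, ..., heads 28..31 read head 7.
    static void gqa_broadcast_example(void) {
        const int64_t ne2 = 32, ne02 = 8;   // dst heads, src0 heads
        assert(ne2 % ne02 == 0);            // same precondition as the new asserts
        const int64_t dps2 = ne2/ne02;      // dst heads per src0 head
        assert( 3/dps2 == 0);
        assert( 4/dps2 == 1);
        assert(31/dps2 == 7);
    }
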
@@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
-// ggml_compute_forward_soft_max_back
+// ggml_compute_forward_soft_max_ext_back
 
-static void ggml_compute_forward_soft_max_back_f32(
+static void ggml_compute_forward_soft_max_ext_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
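
Note: the two op_params slots read here presumably mirror what the forward op stores, something like the sketch below (an assumption about the layout, not confirmed by this diff); the new assert documents that only max_bias == 0.0f, i.e. no ALiBi bias, is supported by this backward path.

    // Assumed op_params layout, as written by the forward soft_max_ext op:
    //   ((const float *) dst->op_params)[0] = scale;    // multiplied into the logits
    //   ((const float *) dst->op_params)[1] = max_bias; // ALiBi; backward requires 0.0f
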
@@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
         float dot_y_dy = 0;
-        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        ggml_vec_cpy_f32 (nc, dx, dy);
-        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32 (nc, dx, dx, y);
+        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32  (nc, dx, dy);
+        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32  (nc, dx, dx, y);
+        ggml_vec_scale_f32(nc, dx, scale);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
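
Note: the added ggml_vec_scale_f32 pushes the forward scale factor through the chain rule, so the vec ops above assemble dx = scale * y * (dy - dot(y, dy)). A hedged scalar sketch (soft_max_back_ref is a hypothetical name; y is assumed to be softmax(scale*x) saved by the forward pass, dy the incoming gradient):

    // Scalar reference for the softmax Jacobian-vector product computed above.
    static void soft_max_back_ref(int n, float * dx,
            const float * y, const float * dy, float scale) {
        float dot_y_dy = 0.0f;
        for (int i = 0; i < n; ++i) {
            dot_y_dy += y[i]*dy[i];
        }
        for (int i = 0; i < n; ++i) {
            dx[i] = scale*y[i]*(dy[i] - dot_y_dy);
        }
    }
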
@@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back_f32(
     }
 }
 
-static void ggml_compute_forward_soft_max_back(
+static void ggml_compute_forward_soft_max_ext_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_back_f32(params, dst);
+                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
             } break;
         default:
             {
@@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32(
     const int64_t IH = is_2D ? ne1 : 1;
     const int64_t IW = ne0;
 
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
 
-    const int64_t OH = is_2D ? ne12 : 1;
-    const int64_t OW = ne11;
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
 
     int ofs0 = is_2D ? nb3 : nb2;
     int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32(
                                     continue;
                                 }
 
-                                const float * const src_data = (const float *) src1->data
+                                const float * const grad_in = (const float *) src0->data
                                     + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                             }
                         }
                         float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
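
Note: the dimension fixes read kernel extents from src1 and output extents from src0. A sketch of the tensor layout this assumes for the 2D case (ne[0] being the fastest dimension, as elsewhere in ggml; the exact roles are inferred from the new src comments above, not stated in the diff):

    // src0 = grad of im2col output: ne00 = IC*KH*KW, ne01 = OW, ne02 = OH, ne03 = N
    // src1 = convolution kernel:    ne10 = KW,       ne11 = KH, ne12 = IC, ne13 = OC
    // dst  = grad of im2col input:  ne0  = IW,       ne1  = IH
    // Hence KW/KH must come from ne10/ne11 and OW/OH from ne01/ne02,
    // rather than from the wrong operand as before.
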
@@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * opt0 = dst->src[2];
+    const struct ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
 
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
 
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
     // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = ggml_nrows(src0f);
 
     // rows per thread
     const int64_t dr = (nr + nth - 1)/nth;
@@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
-        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+              float * ds0 = (float       *)((      char *) dst->data   + i1*dst->nb[1]);
+        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
 
 #ifndef NDEBUG
         for (int64_t i = 0; i < nc; ++i) {
@@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
         ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
-        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
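
Note: combining the steps above, each row's gradient is grad(src0f) = (softmax(src0f) - src1f) * grad/nr, as the updated comment states. A hedged scalar reference with the usual max-subtraction for numerical stability (cross_entropy_loss_back_row_ref is a hypothetical name):

    #include <math.h>
    #include <stdint.h>

    // Per-row gradient of cross-entropy w.r.t. the logits s0, with s1 the
    // target probabilities and d_by_nr = grad/nr.
    static void cross_entropy_loss_back_row_ref(int64_t nc, float * ds0,
            const float * s0, const float * s1, float d_by_nr) {
        float max = -INFINITY;
        for (int64_t i = 0; i < nc; ++i) {
            max = s0[i] > max ? s0[i] : max;
        }
        float sum = 0.0f;
        for (int64_t i = 0; i < nc; ++i) {
            ds0[i] = expf(s0[i] - max);
            sum   += ds0[i];
        }
        for (int64_t i = 0; i < nc; ++i) {
            ds0[i] = (ds0[i]/sum - s1[i])*d_by_nr;
        }
    }
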
@@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
-                ggml_compute_forward_soft_max_back(params, tensor);
+                ggml_compute_forward_soft_max_ext_back(params, tensor);
             } break;
         case GGML_OP_ROPE:
             {