@@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * grad = dst->src[1];
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
 
     assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-    assert(ggml_are_same_shape(src0, grad));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_silu_backward_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
                 (float *) ((char *) grad->data + i1*(grad->nb[1])));
 
 #ifndef NDEBUG
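Note: the reordering above swaps which operand carries the incoming gradient. After the change, `dst->src[0]` is the gradient and `dst->src[1]` is the forward-pass input, matching the src0-carries-gradients convention used by the other backward ops in this commit. For orientation, a minimal sketch of the per-element math `ggml_vec_silu_backward_f32` is expected to apply (the helper name `silu_backward_1` is illustrative, not part of ggml):

```c
#include <math.h>

// SiLU forward: y = x*sigmoid(x); hence dy/dx = s*(1 + x*(1 - s)) with s = sigmoid(x).
// Sketch only: x is the forward input (src1 here), dy the incoming gradient (grad).
static inline float silu_backward_1(float x, float dy) {
    const float s = 1.0f/(1.0f + expf(-x));
    return dy*s*(1.0f + x*(1.0f - s));
}
```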
@@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
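Both assertions are relaxed from strictly positive to non-negative: with eps == 0 the result is still well defined as long as the row's variance (for norm) or mean square (for rms_norm) is nonzero. A reference sketch of the rms_norm row computation, to make the role of eps concrete (one contiguous f32 row assumed; the helper name is illustrative):

```c
#include <math.h>

// y[i] = x[i] / sqrtf(mean(x^2) + eps); finite for eps == 0 unless the row is all zeros.
static void rms_norm_row(int n, float * y, const float * x, float eps) {
    float sum_xx = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum_xx += x[i]*x[i];
    }
    const float scale = 1.0f/sqrtf(sum_xx/n + eps);
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*scale;
    }
}
```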
@@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 const int64_t i12 = i02;
                 const int64_t i13 = i03;
 
-                const float * x  = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 
                 ggml_float sum_xx  = 0.0;
                 ggml_float sum_xdz = 0.0;
@@ -7066,23 +7067,23 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 {
                     // z = rms_norm(x)
                     //
-                    // rms_norm(src0) =
+                    // rms_norm(src1) =
                     //     scale(
-                    //         src0,
+                    //         src1,
                     //         div(
                     //             1,
                     //             sqrt(
                     //                 add(
                     //                     scale(
                     //                         sum(
                     //                             sqr(
-                    //                                 src0)),
+                    //                                 src1)),
                     //                         (1.0/N)),
                     //                     eps))));
 
                     // postorder:
                     // ## op    args         grad
-                    // 00 param src0         grad[#00]
+                    // 00 param src1         grad[#00]
                     // 01 const 1
                     // 02 sqr   (#00)        grad[#02]
                     // 03 sum   (#02)        grad[#03]
@@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 // dx := scale(dx, rrms)
                 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
                 ggml_vec_cpy_f32  (ne00, dx, x);
                 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
                 ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
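The x/dz swap mirrors the operand reordering: src0 now holds the incoming gradients dz and src1 the forward input x. The added comment states the closed form that the vector calls implement. A self-contained sketch of the same per-row computation, under the definitions sum_eps = sum_xx + n*eps and mean_eps = sum_xx/n + eps used in this function (the helper name is illustrative):

```c
#include <math.h>

// dx[i] = (dz[i] - x[i]*sum_xdz/sum_eps) * rrms, with rrms = 1/sqrtf(mean_eps).
static void rms_norm_back_row(int n, float * dx, const float * x, const float * dz, float eps) {
    float sum_xx = 0.0f, sum_xdz = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum_xx  += x[i]*x[i];
        sum_xdz += x[i]*dz[i];
    }
    const float sum_eps = sum_xx + n*eps;           // == n*(mean(x^2) + eps)
    const float rrms    = 1.0f/sqrtf(sum_xx/n + eps);
    for (int i = 0; i < n; ++i) {
        dx[i] = (dz[i] - x[i]*sum_xdz/sum_eps)*rrms;
    }
}
```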
@@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
     const int64_t blck_1 = 16;
 
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
     for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
         const int64_t bir1 = MIN(bir + blck_1, ir1);
         for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32(
                 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
                 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-                const int64_t i02 = i2;
-                const int64_t i03 = i3;
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
 
                 //const int64_t i10 = i1;
                 const int64_t i12 = i2;
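Instead of requiring dst and src0 to match exactly in dims 2 and 3, out_prod now only requires dst's extents to be integer multiples of src0's, and maps each dst index back to a src0 index by integer division. This is the broadcast pattern grouped-query attention needs, where several query heads share one key/value head. A toy illustration of the mapping (the head counts are made up for the example):

```c
#include <stdio.h>

int main(void) {
    const long ne2  = 8;        // dst heads (e.g. query heads)
    const long ne02 = 2;        // src0 heads (e.g. KV heads)
    const long dps2 = ne2/ne02; // "dst per src0" == 4
    for (long i2 = 0; i2 < ne2; ++i2) {
        // dst heads 0..3 read src0 head 0, dst heads 4..7 read src0 head 1
        printf("dst head %ld -> src0 head %ld\n", i2, i2/dps2);
    }
    return 0;
}
```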
@@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max(
 }
 
 
-// ggml_compute_forward_soft_max_back
+// ggml_compute_forward_soft_max_ext_back
 
-static void ggml_compute_forward_soft_max_back_f32(
+static void ggml_compute_forward_soft_max_ext_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
@@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
         float dot_y_dy = 0;
-        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        ggml_vec_cpy_f32 (nc, dx, dy);
-        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32 (nc, dx, dx, y);
+        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32  (nc, dx, dy);
+        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32  (nc, dx, dx, y);
+        ggml_vec_scale_f32(nc, dx, scale);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
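The rename to `_ext_back` reflects that this is now the backward of `soft_max_ext`, which scales its input before the softmax: for y = softmax(scale*x), the chain rule gives dx = scale * y * (dy - dot(y, dy)) per element, which is exactly what the five vector calls compute. The ALiBi `max_bias` path has no backward here, hence the assertion that it is zero. A reference sketch of the per-row formula (the helper name is illustrative):

```c
// dx[i] = scale * y[i] * (dy[i] - dot(y, dy)), for y = softmax(scale*x).
static void soft_max_ext_back_row(int n, float * dx, const float * y, const float * dy, float scale) {
    float dot_y_dy = 0.0f;
    for (int i = 0; i < n; ++i) {
        dot_y_dy += y[i]*dy[i];
    }
    for (int i = 0; i < n; ++i) {
        dx[i] = scale*y[i]*(dy[i] - dot_y_dy);
    }
}
```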
@@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back(
     }
 }
 
-static void ggml_compute_forward_soft_max_back(
+static void ggml_compute_forward_soft_max_ext_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_back_f32(params, dst);
+                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
             } break;
         default:
             {
@@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32(
     const int64_t IH = is_2D ? ne1 : 1;
     const int64_t IW = ne0;
 
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
 
-    const int64_t OH = is_2D ? ne12 : 1;
-    const int64_t OW = ne11;
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
 
     int ofs0 = is_2D ? nb3 : nb2;
     int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32(
                                     continue;
                                 }
 
-                                const float * const src_data = (const float *) src1->data
+                                const float * const grad_in = (const float *) src0->data
                                     + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                             }
                         }
                         float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
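The bug fixed here was reading the kernel and output extents from the wrong operands: in this backward op, src0 holds the gradients of the im2col output and src1 the convolution kernel, so KH/KW must come from src1->ne and OH/OW from src0->ne, and the accumulation must read from src0 (renamed grad_in to make that explicit). For reference, the flattened index the loop above computes, as a sketch (the helper is illustrative, assuming the [N, OH, OW, IC*KH*KW] layout that the index arithmetic implies):

```c
#include <stdint.h>

// Offset of element (in, ioh, iow, iic, ikh, ikw) in a gradient tensor laid out
// as [N, OH, OW, IC*KH*KW], matching "(in*OH*OW + ioh*OW + iow)*(IC*KH*KW)" above.
static inline int64_t grad_in_offset(int64_t in, int64_t ioh, int64_t iow,
                                     int64_t iic, int64_t ikh, int64_t ikw,
                                     int64_t OH, int64_t OW,
                                     int64_t IC, int64_t KH, int64_t KW) {
    return (in*OH*OW + ioh*OW + iow)*(IC*KH*KW) + iic*(KH*KW) + ikh*KW + ikw;
}
```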
@@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * opt0 = dst->src[2];
+    const struct ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
 
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
 
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
     // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = ggml_nrows(src0f);
 
     // rows per thread
     const int64_t dr = (nr + nth - 1)/nth;
@@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
-        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
+        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
 
 #ifndef NDEBUG
         for (int64_t i = 0; i < nc; ++i) {
@@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
         ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
-        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
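The renaming makes the operand roles explicit: `grad` is the scalar gradient flowing into the loss (src[0]), while `src0f`/`src1f` are the logits and labels of the forward pass. The updated comment gives the closed form; a self-contained sketch of one row (the helper name is illustrative, labels assumed to be probabilities summing to 1):

```c
#include <math.h>

// dx[i] = (softmax(x)[i] - t[i]) * d/nr, where d is the incoming scalar gradient.
static void cross_entropy_loss_back_row(int n, float * dx, const float * x,
                                        const float * t, float d_by_nr) {
    float max = x[0];
    for (int i = 1; i < n; ++i) {
        if (x[i] > max) max = x[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        dx[i] = expf(x[i] - max); // numerically stable softmax numerator
        sum  += dx[i];
    }
    for (int i = 0; i < n; ++i) {
        dx[i] = (dx[i]/sum - t[i])*d_by_nr;
    }
}
```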
@@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
-                ggml_compute_forward_soft_max_back(params, tensor);
+                ggml_compute_forward_soft_max_ext_back(params, tensor);
             } break;
         case GGML_OP_ROPE:
             {