
Commit b505131

ebeyabraham, eabraham-1, ggerganov, and slaren authored and committed
llama : add phi-2 + fix NeoX rope + ggml_mul_mat_set_prec (ggml-org#4490)
* phi2 implementation
* fix breaking change
* phi-2 : various fixes
* phi-2 : use layer norm eps
* py : whitespaces
* llama : fix meta KV override bug
* convert : phi don't add BOS token
* convert : revert "added_tokens_decoder" change
* phi-2 : scale Q instead of KQ for better precision
* ggml : fix NeoX rope to rotate just first n_dims
* cuda : less diff in the rope_neox kernel
* ggml : add ggml_mul_mat_set_prec (ggml-ci)
* Update ggml-cuda.cu (Co-authored-by: slaren <[email protected]>)
* Update ggml-cuda.cu (Co-authored-by: slaren <[email protected]>)
* cuda : ggml_cuda_op_mul_mat_cublas support F32 precision
* cuda : remove obsolete comment

Co-authored-by: Ebey Abraham <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: slaren <[email protected]>
1 parent a2616a1 · commit b505131

9 files changed: +463 −76

convert-hf-to-gguf.py

+22
@@ -182,6 +182,8 @@ def from_model_architecture(model_architecture):
             return QwenModel
         if model_architecture == "MixtralForCausalLM":
             return MixtralModel
+        if model_architecture == "PhiForCausalLM":
+            return Phi2Model
         return Model

     def _is_model_safetensors(self) -> bool:
@@ -221,6 +223,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.QWEN
         if arch == "MixtralForCausalLM":
             return gguf.MODEL_ARCH.LLAMA
+        if arch == "PhiForCausalLM":
+            return gguf.MODEL_ARCH.PHI2

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -980,6 +984,24 @@ def write_tensors(self):
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)

+
+class Phi2Model(Model):
+    def set_gguf_parameters(self):
+        block_count = self.hparams["n_layer"]
+
+        self.gguf_writer.add_name("Phi2")
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_add_bos_token(False)
+
+
 ###### CONVERSION LOGIC ######


ggml-cuda.cu

+81 −36
@@ -4998,7 +4998,16 @@ static __global__ void rope_neox(
     const int ib = col / n_dims;
     const int ic = col % n_dims;

-    const int i = row*ncols + ib*n_dims + ic/2;
+    if (ib > 0) {
+        const int i = row*ncols + ib*n_dims + ic;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;

     float cur_rot = inv_ndims * ic - ib;
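This kernel change rotates only the first n_dims elements of each row; columns past n_dims (ib > 0) are now copied through unchanged instead of being rotated again. As a reading aid, here is a minimal scalar sketch in plain C of the intended semantics. The position argument p, freq_base, and the closed-form theta are illustrative stand-ins for the kernel's theta_base/theta_scale machinery, not code from this commit:

```c
#include <math.h>

// Scalar sketch of the fixed NeoX RoPE for one row of length ne0:
// rotate the pair (x[ic/2], x[ic/2 + n_dims/2]) for ic < n_dims,
// pass the tail (ic >= n_dims) through untouched.
static void rope_neox_ref(const float * x, float * y,
                          int ne0, int n_dims, int p, float freq_base) {
    for (int ic = 0; ic < ne0; ic += 2) {
        if (ic < n_dims) {
            // theta = p * freq_base^(-ic/n_dims), the usual NeoX schedule
            const float theta = p * powf(freq_base, -(float) ic / n_dims);
            const float x0 = x[ic/2];
            const float x1 = x[ic/2 + n_dims/2];
            y[ic/2]            = x0*cosf(theta) - x1*sinf(theta);
            y[ic/2 + n_dims/2] = x0*sinf(theta) + x1*cosf(theta);
        } else {
            // before this fix, these elements were rotated as well
            y[ic + 0] = x[ic + 0];
            y[ic + 1] = x[ic + 1];
        }
    }
}
```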
@@ -7057,6 +7066,7 @@ inline void ggml_cuda_op_upscale(

     (void) src1;
     (void) dst;
+    (void) src1_dd;
 }

 inline void ggml_cuda_op_pad(
@@ -7073,6 +7083,7 @@ inline void ggml_cuda_op_pad(

     (void) src1;
     (void) dst;
+    (void) src1_dd;
 }

 inline void ggml_cuda_op_rms_norm(
@@ -7376,7 +7387,7 @@ inline void ggml_cuda_op_mul_mat_cublas(

     const int compute_capability = g_compute_capabilities[id];

-    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         half * src0_as_f16 = nullptr;
         size_t src0_as = 0;
@@ -8300,27 +8311,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
 }

 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
-        int ne12, int ne13,
-        int ne23,
-        int nb02, int nb03,
-        int nb12, int nb13,
-        int nb2, int nb3,
-        int r2, int r3) {
-    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+        int64_t ne12, int64_t ne13,
+        int64_t ne23,
+        size_t  nb02, size_t  nb03,
+        size_t  nb12, size_t  nb13,
+        size_t  nbd2, size_t  nbd3,
+        int64_t r2,   int64_t r3) {
+    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;

     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }

-    int i03 = i13 / r3;
-    int i02 = i12 / r2;
+    int64_t i03 = i13 / r3;
+    int64_t i02 = i12 / r2;

     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2   + i13*nbd3;
 }

 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8376,7 +8387,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    half * dst_f16 = nullptr;
+    char * dst_t   = nullptr;
+
+    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+
+    // dst strides
+    size_t nbd2 = dst->nb[2];
+    size_t nbd3 = dst->nb[3];
+
+    const half  alpha_f16 = 1.0f;
+    const half  beta_f16  = 0.0f;
+
+    const float alpha_f32 = 1.0f;
+    const float beta_f32  = 0.0f;
+
+    const void * alpha = &alpha_f16;
+    const void * beta  = &beta_f16;
+
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+        dst_t   = (char *) dst_f16;
+
+        nbd2 /= sizeof(float) / sizeof(half);
+        nbd3 /= sizeof(float) / sizeof(half);
+    } else {
+        dst_t = (char *) dst_ddf;
+
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_32F;
+
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    }

     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
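A note on the stride arithmetic in this branch: dst is an F32 tensor, so nb[2] and nb[3] are byte strides computed for 4-byte elements. On the GGML_PREC_DEFAULT path the result is first written to a temporary F16 buffer with the same element layout, so each byte stride is exactly half. A standalone illustration with made-up sizes, using uint16_t in place of CUDA's half:

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // hypothetical F32 dst: 8x8 matrices => nb[2] = 8*8*4 = 256 bytes
    size_t nbd2 = 8 * 8 * sizeof(float);

    // same element layout in F16: sizeof(float)/sizeof(uint16_t) == 2
    nbd2 /= sizeof(float) / sizeof(uint16_t);

    printf("%zu\n", nbd2); // prints 128, the F16 temporary's matrix stride
    return 0;
}
```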
@@ -8385,9 +8430,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;

-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
 #if 0
     // use cublasGemmEx
     {
@@ -8397,12 +8439,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 int i02 = i12 / r2;

                 CUBLAS_CHECK(
-                cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                     ne01, ne11, ne10,
-                    &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
-                                (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
-                    &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F,   ne01,
-                    CUBLAS_COMPUTE_16F,
+                    alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
+                           (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
+                    beta,  (      char *)       dst_t + i12*nbd2          + i13*nbd3,          cu_data_type, ne01,
+                    cu_compute_type,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
             }
         }
@@ -8414,11 +8456,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                            (const char *) src1_as_f16, CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F,   ne01,                dst->nb[2]/sizeof(float), // strideC
+                alpha, (const char *) src0_as_f16, CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                       (const char *) src1_as_f16, CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                beta,  (      char *)       dst_t, cu_data_type, ne01,                dst->nb[2]/sizeof(float), // strideC
                 ne12*ne13,
-                CUBLAS_COMPUTE_16F,
+                cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
@@ -8435,24 +8477,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_as_f16, src1_as_f16, dst_f16,
+                src0_as_f16, src1_as_f16, dst_t,
                 ptrs_src, ptrs_dst,
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
                 nb12, nb13,
-                dst->nb[2], dst->nb[3],
+                nbd2, nbd3,
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());

         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F,   nb01/sizeof(half),
-                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F,   nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F,   ne01,
+                alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F,   nb01/sizeof(half),
+                       (const void **) (ptrs_src + 1*ne23), CUDA_R_16F,   nb11/sizeof(float),
+                beta,  (      void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
                 ne23,
-                CUBLAS_COMPUTE_16F,
+                cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));

         if (ptrs_src_s != 0) {
@@ -8464,11 +8506,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     }
 #endif

-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+        ggml_cuda_pool_free(dst_f16, dst_as);
+    }

     ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
 }

 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

ggml-metal.metal

+11 −2
@@ -1702,8 +1702,9 @@ kernel void kernel_rope(
                 dst_data[1] = x0*sin_theta + x1*cos_theta;
             }
         } else {
-            for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+            for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
+                if (ic < n_dims) {
+                    const int64_t ib = 0;

                     // simplified from `(ib * n_dims + ic) * inv_ndims`
                     const float cur_rot = inv_ndims*ic - ib;
@@ -1722,6 +1723,14 @@ kernel void kernel_rope(

                     dst_data[0]        = x0*cos_theta - x1*sin_theta;
                     dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                } else {
+                    const int64_t i0 = ic;
+
+                    device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                    dst_data[0] = src[0];
+                    dst_data[1] = src[1];
                 }
             }
         }

ggml.c

+40 −6
@@ -4098,6 +4098,14 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }

+void ggml_mul_mat_set_prec(
+        struct ggml_tensor * a,
+        enum   ggml_prec     prec) {
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 0, prec_i32);
+}
+
 // ggml_mul_mat_id

 struct ggml_tensor * ggml_mul_mat_id(
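ggml_mul_mat_set_prec stores the requested precision in the result tensor's op_params, which is where the CUDA changes above read it back (the dst->op_params[0] == GGML_PREC_DEFAULT checks). A minimal usage sketch; ctx, k, and q are illustrative names rather than code from this diff:

```c
// Keep F32 precision for the attention scores, as phi-2 requires
// (cf. "phi-2 : scale Q instead of KQ for better precision" above).
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // default is GGML_PREC_DEFAULT
```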
@@ -9168,6 +9176,8 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));

+    GGML_ASSERT(eps > 0.0f);
+
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -9237,6 +9247,8 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));

+    GGML_ASSERT(eps > 0.0f);
+
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -11562,10 +11574,13 @@ static void ggml_compute_forward_rope_f32(
                 }
             } else {
                 // TODO: this might be wrong for ne0 != n_dims - need double check
-                // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                // it seems we have to rope just the first n_dims elements and do nothing with the rest
+                // ref:  https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
                 theta_base *= freq_scale;
-                for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                    for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                for (int64_t ic = 0; ic < ne0; ic += 2) {
+                    if (ic < n_dims) {
+                        const int64_t ib = 0;
+
                         // simplified from `(ib * n_dims + ic) * inv_ndims`
                         float cur_rot = inv_ndims * ic - ib;

@@ -11588,6 +11603,14 @@ static void ggml_compute_forward_rope_f32(

                         dst_data[0]        = x0*cos_theta - x1*sin_theta;
                         dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    } else {
+                        const int64_t i0 = ic;
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
                     }
                 }
             }
@@ -11715,10 +11738,13 @@ static void ggml_compute_forward_rope_f16(
                 }
             } else {
                 // TODO: this might be wrong for ne0 != n_dims - need double check
-                // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                // it seems we have to rope just the first n_dims elements and do nothing with the rest
+                // ref:  https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
                 theta_base *= freq_scale;
-                for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-                    for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                for (int64_t ic = 0; ic < ne0; ic += 2) {
+                    if (ic < n_dims) {
+                        const int64_t ib = 0;
+
                         // simplified from `(ib * n_dims + ic) * inv_ndims`
                         float cur_rot = inv_ndims * ic - ib;

@@ -11741,6 +11767,14 @@ static void ggml_compute_forward_rope_f16(

                         dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                         dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    } else {
+                        const int64_t i0 = ic;
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
                     }
                 }
             }
