@@ -224,48 +224,47 @@ class WeightKBlockNInteger {
224
224
int rawnk_scale = utils::updiv (K, stor->mBlockSize );
225
225
int nk_scale = utils::updiv (stor->mKPad , stor->mBlockSize );
226
226
parallel::Scheduler2D _para ({threading->num_threads (), 1 , nk_scale, 1 , 1 });
227
- if (stor->SDtype () == BTLA_DTYPE::F32) { // fp32 to fp32 direct copy
227
+ if (stor->SDtype () == BTLA_DTYPE::BF16 || stor-> SDtype () == BTLA_DTYPE::F16 || stor-> SDtype () == BTLA_DTYPE::F32) {
228
228
threading->parallel_for ([&](int tidx) {
229
229
parallel::ThreadProblem2D thdp{tidx};
230
230
_para.getIndex (thdp);
231
231
if (thdp.valid ) {
232
- for (int i = thdp.loc [1 ]; i < thdp.loc [1 ] + thdp.size [1 ]; i++) {
233
- if (i < rawnk_scale) {
234
- if (scales != nullptr )
235
- std::memcpy (stor->template SPtr <float >() + i * stor->mNPad , scales + i * N, N * sizeof (scales[0 ]));
236
- if (zero_points != nullptr )
237
- std::memcpy (stor->template ZPtr <int8_t >() + i * stor->mNPad , zero_points + i * N,
238
- N * sizeof (zero_points[0 ]));
239
- } else {
240
- if (scales != nullptr )
241
- std::memset (stor->template SPtr <float >() + i * stor->mNPad , 0 , stor->mNPad * sizeof (float ));
242
- if (zero_points != nullptr )
243
- std::memset (stor->template ZPtr <int8_t >() + i * stor->mNPad , 0 , stor->mNPad * sizeof (zero_points[0 ]));
232
+ int rows = thdp.loc [1 ] + thdp.size [1 ] <= rawnk_scale ? thdp.size [1 ] : rawnk_scale - thdp.loc [1 ];
233
+ if (scales) {
234
+ if (stor->SDtype () == BTLA_DTYPE::BF16) {
235
+ kernel::wrapper::Memcpy2DFp32TPadding<utils::bf16 >::forward_auto (
236
+ scales + thdp.loc [1 ] * N, stor->template SPtr <utils::bf16 >() + thdp.loc [1 ] * stor->mNPad , rows, N,
237
+ N * sizeof (scales[0 ]), stor->mNPad * sizeof (utils::bf16 ), true );
238
+ } else if (stor->SDtype () == BTLA_DTYPE::F32) {
239
+ kernel::wrapper::Memcpy2DPadding::forward (
240
+ scales + thdp.loc [1 ] * N, stor->template SPtr <float >() + thdp.loc [1 ] * stor->mNPad , rows,
241
+ N * sizeof (float ), N * sizeof (scales[0 ]), stor->mNPad * sizeof (float ), true );
242
+ } else if (stor->SDtype () == BTLA_DTYPE::F16) {
243
+ kernel::wrapper::Memcpy2DFp32TPadding<utils::fp16>::forward_auto (
244
+ scales + thdp.loc [1 ] * N, stor->template SPtr <utils::fp16>() + thdp.loc [1 ] * stor->mNPad , rows, N,
245
+ N * sizeof (scales[0 ]), stor->mNPad * sizeof (utils::fp16), true );
244
246
}
245
- }
246
- }
247
- });
248
- } else if (stor->SDtype () == BTLA_DTYPE::BF16) {
249
- threading->parallel_for ([&](int tidx) {
250
- parallel::ThreadProblem2D thdp{tidx};
251
- _para.getIndex (thdp);
252
- if (thdp.valid ) {
253
- for (int i = thdp.loc [1 ]; i < thdp.loc [1 ] + thdp.size [1 ]; i++) {
254
- if (i < rawnk_scale) {
255
- if (scales != nullptr ) {
256
- for (size_t j = 0 ; j < N; j++) {
257
- stor->template SPtr <utils::bf16 >()[j + i * stor->mNPad ] = static_cast <utils::bf16 >(scales[i * N + j]);
258
- }
259
- }
260
- if (zero_points != nullptr ) {
261
- std::memcpy (stor->template ZPtr <int8_t >() + i * stor->mNPad , zero_points + i * N,
262
- N * sizeof (zero_points[0 ]));
247
+ if (rows < thdp.size [1 ]) {
248
+ auto sb = bestla::utils::bestla_dtype_bytes (stor->SDtype ());
249
+ if (sb == 2 ) {
250
+ std::memset (stor->template SPtr <utils::fp16>() + (thdp.loc [1 ] + rows) * stor->mNPad , 0 ,
251
+ sb * (thdp.size [1 ] - rows) * stor->mNPad );
252
+ } else if (sb == 4 ) {
253
+ std::memset (stor->template SPtr <float >() + (thdp.loc [1 ] + rows) * stor->mNPad , 0 ,
254
+ sb * (thdp.size [1 ] - rows) * stor->mNPad );
255
+ } else {
256
+ assert (0 );
263
257
}
264
- } else {
265
- if (scales != nullptr )
266
- std::memset (stor->template SPtr <utils::bf16 >() + i * stor->mNPad , 0 , stor->mNPad * sizeof (utils::bf16 ));
267
- if (zero_points != nullptr )
268
- std::memset (stor->template ZPtr <int8_t >() + i * stor->mNPad , 0 , stor->mNPad * sizeof (zero_points[0 ]));
258
+ }
259
+ }
260
+ if (zero_points) {
261
+ kernel::wrapper::Memcpy2DPadding::forward (
262
+ zero_points + thdp.loc [1 ] * N, stor->template ZPtr <int8_t >() + thdp.loc [1 ] * stor->mNPad , rows,
263
+ N * sizeof (zero_points[0 ]), N * sizeof (zero_points[0 ]), sizeof (int8_t ) * stor->mNPad , true );
264
+
265
+ if (rows < thdp.size [1 ]) {
266
+ std::memset (stor->template ZPtr <int8_t >() + (thdp.loc [1 ] + rows) * stor->mNPad , 0 ,
267
+ sizeof (int8_t ) * (thdp.size [1 ] - rows) * stor->mNPad );
269
268
}
270
269
}
271
270
}
@@ -334,84 +333,24 @@ class WeightKBlockNInteger {
334
333
utils::afree (countptr);
335
334
}
336
335
337
- AUTOCALL void setTransposeQuantCorrection (const int N, const int K, const int8_t * zero_points , const float * scales ,
336
+ AUTOCALL void setTransposeQuantCorrection (const int N, const int K, const int8_t * zero_pointsT , const float * scalesT ,
338
337
StorageWeight* stor, parallel::IThreading* threading) {
339
338
int rawnk_scale = utils::updiv (K, stor->mBlockSize );
340
- int nk_scale = utils::updiv (stor->mKPad , stor->mBlockSize );
341
- parallel::Scheduler2D _para ({threading->num_threads (), 1 , nk_scale, 1 , 1 });
342
- if (stor->SDtype () == BTLA_DTYPE::F32) { // fp32 to fp32 direct copy
343
- threading->parallel_for ([&](int tidx) {
344
- parallel::ThreadProblem2D thdp{tidx};
345
- _para.getIndex (thdp);
346
- if (thdp.valid ) {
347
- if (scales) {
348
- for (int i = thdp.loc [1 ]; i < thdp.loc [1 ] + thdp.size [1 ]; i++) {
349
- if (i < rawnk_scale) {
350
- for (int j = 0 ; j < N; j++) {
351
- stor->template SPtr <float >()[i * stor->mNPad + j] = scales[j * rawnk_scale + i];
352
- }
353
- } else {
354
- std::memset (stor->template SPtr <float >() + i * stor->mNPad , 0 , stor->mNPad * sizeof (float ));
355
- }
356
- }
357
- }
358
- }
359
- });
360
- } else if (stor->SDtype () == BTLA_DTYPE::BF16) {
361
- threading->parallel_for ([&](int tidx) {
362
- parallel::ThreadProblem2D thdp{tidx};
363
- _para.getIndex (thdp);
364
- if (thdp.valid ) {
365
- if (scales) {
366
- for (int i = thdp.loc [1 ]; i < thdp.loc [1 ] + thdp.size [1 ]; i++) {
367
- if (i < rawnk_scale) {
368
- for (int j = 0 ; j < N; j++) {
369
- stor->template SPtr <utils::bf16 >()[i * stor->mNPad + j] = utils::bf16 (scales[j * rawnk_scale + i]);
370
- }
371
- } else {
372
- std::memset (stor->template SPtr <utils::bf16 >() + i * stor->mNPad , 0 , stor->mNPad * sizeof (utils::bf16 ));
373
- }
374
- }
375
- }
376
- }
377
- });
378
- } else if (stor->SDtype () == BTLA_DTYPE::F8_E8M0) {
379
- threading->parallel_for ([&](int tidx) {
380
- parallel::ThreadProblem2D thdp{tidx};
381
- _para.getIndex (thdp);
382
- if (thdp.valid ) {
383
- if (scales) {
384
- for (int i = thdp.loc [1 ]; i < thdp.loc [1 ] + thdp.size [1 ]; i++) {
385
- if (i < rawnk_scale) {
386
- for (int j = 0 ; j < N; j++) {
387
- stor->template SPtr <utils::f8 >()[i * stor->mNPad + j] = static_cast <int >(scales[j * rawnk_scale + i]);
388
- }
389
- } else {
390
- std::memset (stor->template SPtr <utils::f8 >() + i * stor->mNPad , 0 , stor->mNPad * sizeof (utils::f8 ));
391
- }
392
- }
393
- }
394
- }
395
- });
396
- } else {
397
- assert (0 );
339
+ auto scales = scalesT ? utils::amalloc<float >(rawnk_scale * N) : nullptr ;
340
+ auto zero_points = zero_pointsT ? utils::amalloc<int8_t >(rawnk_scale * N) : nullptr ;
341
+ if (scales) {
342
+ transposeWeight<float >(N, rawnk_scale, scalesT, rawnk_scale, scales, N, threading);
343
+ }
344
+ if (zero_points) {
345
+ transposeWeight<int8_t >(N, rawnk_scale, zero_pointsT, rawnk_scale, zero_points, N, threading);
346
+ }
347
+ setQuantCorrection (N, K, zero_points, scales, stor, threading);
348
+ if (scales) {
349
+ utils::afree (scales);
350
+ }
351
+ if (zero_points) {
352
+ utils::afree (zero_points);
398
353
}
399
- if (stor->IsAsym () && zero_points)
400
- threading->parallel_for ([&](int tidx) {
401
- parallel::ThreadProblem2D thdp{tidx};
402
- _para.getIndex (thdp);
403
- if (thdp.valid ) {
404
- for (int i = thdp.loc [1 ]; i < thdp.loc [1 ] + thdp.size [1 ]; i++) {
405
- if (i < rawnk_scale) {
406
- for (int j = 0 ; j < N; j++) {
407
- stor->template ZPtr <int8_t >()[i * stor->mNPad + j] = zero_points[j * rawnk_scale + i];
408
- }
409
- } else {
410
- std::memset (stor->template ZPtr <int8_t >() + i * stor->mNPad , 0 , stor->mNPad * sizeof (zero_points[0 ]));
411
- }
412
- }
413
- }
414
- });
415
354
}
416
355
417
356
AUTOCALL void packQWeight (const int N, const int K, const int8_t * B, const int ldb, const float * scales,
@@ -445,6 +384,7 @@ class WeightKBlockNInteger {
445
384
auto blks_padding2 = utils::padto (blks, 2 );
446
385
auto tmpscales = tmp;
447
386
auto tmpzeropoints = reinterpret_cast <int8_t *>(tmpscales + N * blks);
387
+ assert (isasym == (zero_points != nullptr ));
448
388
if (scales) {
449
389
for (size_t i = 0 ; i < N * blks; i += 1 ) {
450
390
tmpscales[i] = scales[i];
@@ -640,6 +580,7 @@ class WeightKBlockNInteger {
640
580
}
641
581
});
642
582
}
583
+
643
584
AUTOCALL void compressWeight (const int N, const int K, const int8_t * B, const int ldb, int8_t * dstptr,
644
585
BTLA_DTYPE qtype, parallel::IThreading* threading) {
645
586
if (qtype == BTLA_DTYPE::S7_CLIP) return compressBit7Weight (N, K, B, dstptr, qtype, threading);
@@ -726,6 +667,13 @@ class WeightKBlockNInteger {
726
667
utils::updiv (k_size, wptr->mBlockSize ), n_size, wptr->CStep () * 2 , n_size * 4 , false );
727
668
*dststep = n_size;
728
669
}
670
+ if (wptr->SDtype () == BTLA_DTYPE::F16) {
671
+ auto aptr = wptr->template SPtr <utils::fp16>();
672
+ kernel::wrapper::Memcpy2DFp16CvtFp32::forward<ISA_T>(
673
+ aptr + k_offset / wptr->mBlockSize * wptr->CStep () + n_offset, *dstptr,
674
+ utils::updiv (k_size, wptr->mBlockSize ), n_size, wptr->CStep () * 2 , n_size * 4 , false );
675
+ *dststep = n_size;
676
+ }
729
677
if (wptr->SDtype () == BTLA_DTYPE::DQ8_BNB) {
730
678
auto aptr = wptr->template SPtr <uint8_t >();
731
679
auto internal_k_offset = k_offset / wptr->mBlockSize ;
0 commit comments