Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 9652017

Browse files
authored
[BesTLA] Support fp16 for compute_dtype and scale_dtype (#292)
* split functions by isa * compiled * remove all intrinsics code from mha * add header * update target for ICX * add diagnostic back * update header * type conversion * fix * remove function for gcc11 * add fp16 conversion for avx2 * use one template function instead * add scale_dtype=fp16 * add comp_fp16 UT for low bits * fix warning * support f16 for quant api * add template for comp_fp16 * remove avx512_bf16 templates * fix gcc version
1 parent 4f645cf commit 9652017

27 files changed: +5060 −4398 lines changed

.github/workflows/unit-test-bestla.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ on:
1313
description: 'compiler_version'
1414
required: false
1515
type: string
16-
default: '12.1.0'
16+
default: '13.2.0'
1717

1818
# If there is a new commit, the previous jobs will be canceled
1919
concurrency:
2020
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
2121
cancel-in-progress: true
2222

2323
env:
24-
INPUT_COMPILER_VERSION: ${{ inputs.compiler_version || '12.1.0' }}
24+
INPUT_COMPILER_VERSION: ${{ inputs.compiler_version || '13.2.0' }}
2525
WORKING_DIR: ${{ github.workspace }}
2626
CONTAINER_NAME: "utTest"
2727

bestla/bestla/bestla.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ enum class BTLA_ISA : uint8_t {
2828
AVX512F,
2929
AVX512BW,
3030
AVX512_VNNI,
31+
AVX512_BF16,
32+
AVX512_FP16,
3133
AMX_BF16,
3234
AMX_INT8,
33-
AVX512_FP16,
34-
AVX512_BF16,
3535
AMX_FP16,
3636
ISA_COUNT,
3737
};

bestla/bestla/bestla_parallel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#if BTLA_OPENMP
2020
#include <omp.h>
2121
#endif
22+
#include <immintrin.h>
2223
#include "bestla_utils.h"
2324
#include "bestla_device.h"
2425

bestla/bestla/bestla_prologue_b.h

Lines changed: 59 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -224,48 +224,47 @@ class WeightKBlockNInteger {
224224
int rawnk_scale = utils::updiv(K, stor->mBlockSize);
225225
int nk_scale = utils::updiv(stor->mKPad, stor->mBlockSize);
226226
parallel::Scheduler2D _para({threading->num_threads(), 1, nk_scale, 1, 1});
227-
if (stor->SDtype() == BTLA_DTYPE::F32) { // fp32 to fp32 direct copy
227+
if (stor->SDtype() == BTLA_DTYPE::BF16 || stor->SDtype() == BTLA_DTYPE::F16 || stor->SDtype() == BTLA_DTYPE::F32) {
228228
threading->parallel_for([&](int tidx) {
229229
parallel::ThreadProblem2D thdp{tidx};
230230
_para.getIndex(thdp);
231231
if (thdp.valid) {
232-
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
233-
if (i < rawnk_scale) {
234-
if (scales != nullptr)
235-
std::memcpy(stor->template SPtr<float>() + i * stor->mNPad, scales + i * N, N * sizeof(scales[0]));
236-
if (zero_points != nullptr)
237-
std::memcpy(stor->template ZPtr<int8_t>() + i * stor->mNPad, zero_points + i * N,
238-
N * sizeof(zero_points[0]));
239-
} else {
240-
if (scales != nullptr)
241-
std::memset(stor->template SPtr<float>() + i * stor->mNPad, 0, stor->mNPad * sizeof(float));
242-
if (zero_points != nullptr)
243-
std::memset(stor->template ZPtr<int8_t>() + i * stor->mNPad, 0, stor->mNPad * sizeof(zero_points[0]));
232+
int rows = thdp.loc[1] + thdp.size[1] <= rawnk_scale ? thdp.size[1] : rawnk_scale - thdp.loc[1];
233+
if (scales) {
234+
if (stor->SDtype() == BTLA_DTYPE::BF16) {
235+
kernel::wrapper::Memcpy2DFp32TPadding<utils::bf16>::forward_auto(
236+
scales + thdp.loc[1] * N, stor->template SPtr<utils::bf16>() + thdp.loc[1] * stor->mNPad, rows, N,
237+
N * sizeof(scales[0]), stor->mNPad * sizeof(utils::bf16), true);
238+
} else if (stor->SDtype() == BTLA_DTYPE::F32) {
239+
kernel::wrapper::Memcpy2DPadding::forward(
240+
scales + thdp.loc[1] * N, stor->template SPtr<float>() + thdp.loc[1] * stor->mNPad, rows,
241+
N * sizeof(float), N * sizeof(scales[0]), stor->mNPad * sizeof(float), true);
242+
} else if (stor->SDtype() == BTLA_DTYPE::F16) {
243+
kernel::wrapper::Memcpy2DFp32TPadding<utils::fp16>::forward_auto(
244+
scales + thdp.loc[1] * N, stor->template SPtr<utils::fp16>() + thdp.loc[1] * stor->mNPad, rows, N,
245+
N * sizeof(scales[0]), stor->mNPad * sizeof(utils::fp16), true);
244246
}
245-
}
246-
}
247-
});
248-
} else if (stor->SDtype() == BTLA_DTYPE::BF16) {
249-
threading->parallel_for([&](int tidx) {
250-
parallel::ThreadProblem2D thdp{tidx};
251-
_para.getIndex(thdp);
252-
if (thdp.valid) {
253-
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
254-
if (i < rawnk_scale) {
255-
if (scales != nullptr) {
256-
for (size_t j = 0; j < N; j++) {
257-
stor->template SPtr<utils::bf16>()[j + i * stor->mNPad] = static_cast<utils::bf16>(scales[i * N + j]);
258-
}
259-
}
260-
if (zero_points != nullptr) {
261-
std::memcpy(stor->template ZPtr<int8_t>() + i * stor->mNPad, zero_points + i * N,
262-
N * sizeof(zero_points[0]));
247+
if (rows < thdp.size[1]) {
248+
auto sb = bestla::utils::bestla_dtype_bytes(stor->SDtype());
249+
if (sb == 2) {
250+
std::memset(stor->template SPtr<utils::fp16>() + (thdp.loc[1] + rows) * stor->mNPad, 0,
251+
sb * (thdp.size[1] - rows) * stor->mNPad);
252+
} else if (sb == 4) {
253+
std::memset(stor->template SPtr<float>() + (thdp.loc[1] + rows) * stor->mNPad, 0,
254+
sb * (thdp.size[1] - rows) * stor->mNPad);
255+
} else {
256+
assert(0);
263257
}
264-
} else {
265-
if (scales != nullptr)
266-
std::memset(stor->template SPtr<utils::bf16>() + i * stor->mNPad, 0, stor->mNPad * sizeof(utils::bf16));
267-
if (zero_points != nullptr)
268-
std::memset(stor->template ZPtr<int8_t>() + i * stor->mNPad, 0, stor->mNPad * sizeof(zero_points[0]));
258+
}
259+
}
260+
if (zero_points) {
261+
kernel::wrapper::Memcpy2DPadding::forward(
262+
zero_points + thdp.loc[1] * N, stor->template ZPtr<int8_t>() + thdp.loc[1] * stor->mNPad, rows,
263+
N * sizeof(zero_points[0]), N * sizeof(zero_points[0]), sizeof(int8_t) * stor->mNPad, true);
264+
265+
if (rows < thdp.size[1]) {
266+
std::memset(stor->template ZPtr<int8_t>() + (thdp.loc[1] + rows) * stor->mNPad, 0,
267+
sizeof(int8_t) * (thdp.size[1] - rows) * stor->mNPad);
269268
}
270269
}
271270
}
@@ -334,84 +333,24 @@ class WeightKBlockNInteger {
334333
utils::afree(countptr);
335334
}
336335

337-
AUTOCALL void setTransposeQuantCorrection(const int N, const int K, const int8_t* zero_points, const float* scales,
336+
AUTOCALL void setTransposeQuantCorrection(const int N, const int K, const int8_t* zero_pointsT, const float* scalesT,
338337
StorageWeight* stor, parallel::IThreading* threading) {
339338
int rawnk_scale = utils::updiv(K, stor->mBlockSize);
340-
int nk_scale = utils::updiv(stor->mKPad, stor->mBlockSize);
341-
parallel::Scheduler2D _para({threading->num_threads(), 1, nk_scale, 1, 1});
342-
if (stor->SDtype() == BTLA_DTYPE::F32) { // fp32 to fp32 direct copy
343-
threading->parallel_for([&](int tidx) {
344-
parallel::ThreadProblem2D thdp{tidx};
345-
_para.getIndex(thdp);
346-
if (thdp.valid) {
347-
if (scales) {
348-
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
349-
if (i < rawnk_scale) {
350-
for (int j = 0; j < N; j++) {
351-
stor->template SPtr<float>()[i * stor->mNPad + j] = scales[j * rawnk_scale + i];
352-
}
353-
} else {
354-
std::memset(stor->template SPtr<float>() + i * stor->mNPad, 0, stor->mNPad * sizeof(float));
355-
}
356-
}
357-
}
358-
}
359-
});
360-
} else if (stor->SDtype() == BTLA_DTYPE::BF16) {
361-
threading->parallel_for([&](int tidx) {
362-
parallel::ThreadProblem2D thdp{tidx};
363-
_para.getIndex(thdp);
364-
if (thdp.valid) {
365-
if (scales) {
366-
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
367-
if (i < rawnk_scale) {
368-
for (int j = 0; j < N; j++) {
369-
stor->template SPtr<utils::bf16>()[i * stor->mNPad + j] = utils::bf16(scales[j * rawnk_scale + i]);
370-
}
371-
} else {
372-
std::memset(stor->template SPtr<utils::bf16>() + i * stor->mNPad, 0, stor->mNPad * sizeof(utils::bf16));
373-
}
374-
}
375-
}
376-
}
377-
});
378-
} else if (stor->SDtype() == BTLA_DTYPE::F8_E8M0) {
379-
threading->parallel_for([&](int tidx) {
380-
parallel::ThreadProblem2D thdp{tidx};
381-
_para.getIndex(thdp);
382-
if (thdp.valid) {
383-
if (scales) {
384-
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
385-
if (i < rawnk_scale) {
386-
for (int j = 0; j < N; j++) {
387-
stor->template SPtr<utils::f8>()[i * stor->mNPad + j] = static_cast<int>(scales[j * rawnk_scale + i]);
388-
}
389-
} else {
390-
std::memset(stor->template SPtr<utils::f8>() + i * stor->mNPad, 0, stor->mNPad * sizeof(utils::f8));
391-
}
392-
}
393-
}
394-
}
395-
});
396-
} else {
397-
assert(0);
339+
auto scales = scalesT ? utils::amalloc<float>(rawnk_scale * N) : nullptr;
340+
auto zero_points = zero_pointsT ? utils::amalloc<int8_t>(rawnk_scale * N) : nullptr;
341+
if (scales) {
342+
transposeWeight<float>(N, rawnk_scale, scalesT, rawnk_scale, scales, N, threading);
343+
}
344+
if (zero_points) {
345+
transposeWeight<int8_t>(N, rawnk_scale, zero_pointsT, rawnk_scale, zero_points, N, threading);
346+
}
347+
setQuantCorrection(N, K, zero_points, scales, stor, threading);
348+
if (scales) {
349+
utils::afree(scales);
350+
}
351+
if (zero_points) {
352+
utils::afree(zero_points);
398353
}
399-
if (stor->IsAsym() && zero_points)
400-
threading->parallel_for([&](int tidx) {
401-
parallel::ThreadProblem2D thdp{tidx};
402-
_para.getIndex(thdp);
403-
if (thdp.valid) {
404-
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
405-
if (i < rawnk_scale) {
406-
for (int j = 0; j < N; j++) {
407-
stor->template ZPtr<int8_t>()[i * stor->mNPad + j] = zero_points[j * rawnk_scale + i];
408-
}
409-
} else {
410-
std::memset(stor->template ZPtr<int8_t>() + i * stor->mNPad, 0, stor->mNPad * sizeof(zero_points[0]));
411-
}
412-
}
413-
}
414-
});
415354
}
416355

417356
AUTOCALL void packQWeight(const int N, const int K, const int8_t* B, const int ldb, const float* scales,
@@ -445,6 +384,7 @@ class WeightKBlockNInteger {
445384
auto blks_padding2 = utils::padto(blks, 2);
446385
auto tmpscales = tmp;
447386
auto tmpzeropoints = reinterpret_cast<int8_t*>(tmpscales + N * blks);
387+
assert(isasym == (zero_points != nullptr));
448388
if (scales) {
449389
for (size_t i = 0; i < N * blks; i += 1) {
450390
tmpscales[i] = scales[i];
@@ -640,6 +580,7 @@ class WeightKBlockNInteger {
640580
}
641581
});
642582
}
583+
643584
AUTOCALL void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr,
644585
BTLA_DTYPE qtype, parallel::IThreading* threading) {
645586
if (qtype == BTLA_DTYPE::S7_CLIP) return compressBit7Weight(N, K, B, dstptr, qtype, threading);
@@ -726,6 +667,13 @@ class WeightKBlockNInteger {
726667
utils::updiv(k_size, wptr->mBlockSize), n_size, wptr->CStep() * 2, n_size * 4, false);
727668
*dststep = n_size;
728669
}
670+
if (wptr->SDtype() == BTLA_DTYPE::F16) {
671+
auto aptr = wptr->template SPtr<utils::fp16>();
672+
kernel::wrapper::Memcpy2DFp16CvtFp32::forward<ISA_T>(
673+
aptr + k_offset / wptr->mBlockSize * wptr->CStep() + n_offset, *dstptr,
674+
utils::updiv(k_size, wptr->mBlockSize), n_size, wptr->CStep() * 2, n_size * 4, false);
675+
*dststep = n_size;
676+
}
729677
if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) {
730678
auto aptr = wptr->template SPtr<uint8_t>();
731679
auto internal_k_offset = k_offset / wptr->mBlockSize;

bestla/bestla/bestla_utils.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#define BTLA_OPENMP 0
2121
#endif
2222

23+
#define FP32_BF16_FAST 0
24+
2325
#if BTLA_OPENMP
2426
#include <omp.h>
2527
#endif
@@ -83,8 +85,6 @@
8385
// runtime auto-dispatch ISA, not time critical functions
8486
#define AUTOCALL static
8587

86-
#include <immintrin.h>
87-
8888
namespace bestla {
8989
namespace utils {
9090

@@ -388,6 +388,8 @@ inline constexpr size_t bestla_dtype_bits(const BTLA_DTYPE t) {
388388
return bestla_dtype_get_mask_val(t, BTLA_DTYPE::EleBitsMask, BTLA_DTYPE::EleBitsShift);
389389
}
390390

391+
inline constexpr size_t bestla_dtype_bytes(const BTLA_DTYPE t) { return bestla_dtype_bits(t) >> 3; }
392+
391393
inline constexpr size_t bestla_dtype_type(const BTLA_DTYPE t) {
392394
return bestla_dtype_get_mask_val(t, BTLA_DTYPE::TypeMask, BTLA_DTYPE::TypeShift);
393395
}
@@ -464,9 +466,11 @@ class isa_base {
464466
static bool constexpr avx2 = ISA_T >= BTLA_ISA::AVX2;
465467
static bool constexpr avx512f = ISA_T >= BTLA_ISA::AVX512F;
466468
static bool constexpr avx512_vnni = ISA_T >= BTLA_ISA::AVX512_VNNI;
469+
static bool constexpr avx512_bf16 = ISA_T >= BTLA_ISA::AVX512_BF16;
467470
static bool constexpr avx512_fp16 = ISA_T >= BTLA_ISA::AVX512_FP16;
468471
static bool constexpr amx_bf16 = ISA_T >= BTLA_ISA::AMX_BF16;
469472
static bool constexpr amx_int8 = ISA_T >= BTLA_ISA::AMX_INT8;
473+
static bool constexpr amx_fp16 = ISA_T >= BTLA_ISA::AMX_FP16;
470474
};
471475

472476
static inline int padto_le(int src, int padding) { return src / padding * padding; }

bestla/bestla/bestla_wrapper.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ class LauncherBase {
353353
_param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP;
354354
if constexpr (support()) {
355355
impl &= _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F32 ||
356+
_param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F16 ||
356357
_param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::BF16;
357358
}
358359

@@ -451,6 +452,17 @@ class LauncherBase {
451452
if (m == 7) gemv_kblock<utils::bf16, 7>(_param, _config);
452453
if (m == 8) gemv_kblock<utils::bf16, 8>(_param, _config);
453454
}
455+
} else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F16) {
456+
if (m == 1) gemv_kblock<utils::fp16, 1>(_param, _config);
457+
if (m == 2) gemv_kblock<utils::fp16, 2>(_param, _config);
458+
if (m == 3) gemv_kblock<utils::fp16, 3>(_param, _config);
459+
if (m == 4) gemv_kblock<utils::fp16, 4>(_param, _config);
460+
if constexpr (Reg32) {
461+
if (m == 5) gemv_kblock<utils::fp16, 5>(_param, _config);
462+
if (m == 6) gemv_kblock<utils::fp16, 6>(_param, _config);
463+
if (m == 7) gemv_kblock<utils::fp16, 7>(_param, _config);
464+
if (m == 8) gemv_kblock<utils::fp16, 8>(_param, _config);
465+
}
454466
}
455467
}
456468
}
@@ -622,6 +634,7 @@ class LauncherIntKBlock {
622634
_param.paramB.packedW->mDType == BTLA_DTYPE::S1_CLIP ||
623635
_param.paramB.packedW->mDType == BTLA_DTYPE::S2_CLIP;
624636
impl &= _param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F32 ||
637+
_param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::F16 ||
625638
_param.paramB.packedW->mCorrection.mScaT == BTLA_DTYPE::BF16;
626639
impl &= _param.problem.dims[1] <= MaxGemvM;
627640
return impl;
@@ -699,6 +712,17 @@ class LauncherIntKBlock {
699712
if (m == 7) gemv_kblock<utils::bf16, 7>(_param, _config);
700713
if (m == 8) gemv_kblock<utils::bf16, 8>(_param, _config);
701714
}
715+
} else if (_param.paramB.packedW->SDtype() == BTLA_DTYPE::F16) {
716+
if (m == 1) gemv_kblock<utils::fp16, 1>(_param, _config);
717+
if (m == 2) gemv_kblock<utils::fp16, 2>(_param, _config);
718+
if (m == 3) gemv_kblock<utils::fp16, 3>(_param, _config);
719+
if (m == 4) gemv_kblock<utils::fp16, 4>(_param, _config);
720+
if constexpr (Reg32) {
721+
if (m == 5) gemv_kblock<utils::fp16, 5>(_param, _config);
722+
if (m == 6) gemv_kblock<utils::fp16, 6>(_param, _config);
723+
if (m == 7) gemv_kblock<utils::fp16, 7>(_param, _config);
724+
if (m == 8) gemv_kblock<utils::fp16, 8>(_param, _config);
725+
}
702726
}
703727
}
704728
}

0 commit comments

Comments (0)