Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit e0e65bd

Browse files
authored
[BesTLA] Simplify the templates (#274)
* remove ISA from prologue_a * fix assert condition * add AUTOCALL * remove ISA from prologueb * remove gemmcore instance. remove runtime ISA * compile with GCC * apply refactor to all kernels * remove warning * compile with gcc, add linux UT preset * clang-format * fix class name of amx * fix UT case * fix clang-tidy * support DQ for NFloat * fix warning * clang-format
1 parent 315df3a commit e0e65bd

26 files changed

+1510
-1472
lines changed

CMakePresets.json

+12
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@
3939
"NS_USE_OMP": "OFF"
4040
}
4141
},
42+
{
43+
"name": "linux-release-ut-thread",
44+
"displayName": "Linux Release Thread Pool for UTs",
45+
"description": "Release",
46+
"inherits": "linux-debug",
47+
"cacheVariables": {
48+
"CMAKE_BUILD_TYPE": "Release",
49+
"NS_USE_OMP": "OFF",
50+
"BTLA_UT_ALL": "ON",
51+
"BTLA_UT_BENCHMARK": "ON"
52+
}
53+
},
4254
{
4355
"name": "windows-base",
4456
"description": "Target Windows with the Visual Studio development environment.",

bestla/bestla/bestla_device.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@ class CpuDevice {
233233
inline bool AVX512_BF16() { return mHasAVX512_BF16; }
234234
inline bool AVX512_FP16() { return mHasAVX512_FP16; }
235235
inline float* const getPE() { return PE; }
236-
inline size_t getPcoreNum() { return P_core.size(); }
237-
inline size_t getEcoreNum() { return E_core.size(); }
238-
inline size_t getSMTcoreNum() { return SMT_core.size(); }
236+
inline int getPcoreNum() { return static_cast<int>(P_core.size()); }
237+
inline int getEcoreNum() { return static_cast<int>(E_core.size()); }
238+
inline int getSMTcoreNum() { return static_cast<int>(SMT_core.size()); }
239239
inline int* getPCores() { return P_core.data(); }
240240
inline int* getECores() { return E_core.data(); }
241241
inline int* getSMTCores() { return SMT_core.data(); }
@@ -467,15 +467,15 @@ class CpuDevice {
467467
bool isClient() { return mClient; }
468468

469469
protected:
470-
uint32_t L2Cache, L1Cache, L3Cache;
470+
uint32_t L2Cache = 0, L1Cache = 0, L3Cache = 0;
471471
bool mHybrid = false, mClient = false;
472-
bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512BW,
473-
mHasAVX512_BF16, mHasAVX512_FP16;
474-
int numcores;
475-
int numthreads;
472+
bool mHasAVX2 = false, mHasAVX_VNNI = false, mHasAVX = false, mHasAVX512_VNNI = false, mHasAMX_INT8 = false,
473+
mHasAMX_BF16 = false, mHasAVX512F = false, mHasAVX512BW, mHasAVX512_BF16 = false, mHasAVX512_FP16 = false;
474+
int numcores = 0;
475+
int numthreads = 0;
476476
std::vector<int> P_core, E_core, SMT_core;
477-
uint32_t E_L2Cache, E_L1Cache;
478-
float PE[int(BTLA_ISA::ISA_COUNT)];
477+
uint32_t E_L2Cache = 0, E_L1Cache = 0;
478+
float PE[int(BTLA_ISA::ISA_COUNT)] = {1.f};
479479
};
480480

481481
#define GetCPUDevice() auto _cd = bestla::device::CpuDevice::getInstance();

bestla/bestla/bestla_gemm.h

+138-101
Large diffs are not rendered by default.

bestla/bestla/bestla_parallel.h

+19-12
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class IThreading {
6868
virtual std::pair<float, float> get_PEtime() const { return {0.0f, 0.0f}; };
6969

7070
protected:
71-
int mThreadNum;
72-
const bool isSupportPE;
71+
int mThreadNum = 0;
72+
const bool isSupportPE = false;
7373
};
7474

7575
#if BTLA_OPENMP
@@ -107,7 +107,14 @@ class OMPThreading : public IThreading {
107107
class StdThreading : public IThreading {
108108
public:
109109
using Timer_T = utils::timer<utils::microseconds>;
110-
explicit StdThreading() : IThreading(true) { cr = nullptr; }
110+
explicit StdThreading() : IThreading(true) {
111+
cr = nullptr;
112+
memset(func_, 0, sizeof(func_));
113+
memset(flag, 0, sizeof(flag));
114+
stop = true;
115+
time_per_p = -1.f;
116+
time_per_e = -1.f;
117+
}
111118

112119
void parallel_for(const thread_func& func) override {
113120
time_per_p = 0;
@@ -117,7 +124,7 @@ class StdThreading : public IThreading {
117124
running.store(mThreadNum - 1);
118125
for (int i = 0; i < 10; i++) flag[i].store(mThreadNum);
119126
if (cr->mHybrid) {
120-
int time_p = 0, time_e = 0;
127+
int64_t time_p = 0, time_e = 0;
121128

122129
for (size_t i = 0; i < mThreadNum - 1; i++) func_[i] = &func;
123130
thread_time[0] = 0;
@@ -135,8 +142,8 @@ class StdThreading : public IThreading {
135142
time_e += thread_time[i];
136143
else
137144
time_p += thread_time[i];
138-
time_per_p = (time_p) / (1.0 * (mThreadNum - cr->E_core_num));
139-
time_per_e = (time_e) / (1.0 * cr->E_core_num);
145+
time_per_p = (time_p) / (1.0f * (mThreadNum - cr->E_core_num));
146+
time_per_e = (time_e) / (1.0f * cr->E_core_num);
140147
// printf("%d %d %f %f\n", time_p, time_e, time_per_p, time_per_e);
141148
} else {
142149
for (size_t i = 0; i < mThreadNum - 1; i++) {
@@ -810,7 +817,7 @@ class SchedulerDispatcher<Scheduler2D> {
810817
} // namespace gemm
811818

812819
template <class Parallel_T, class Launch_T>
813-
void GemmRun(Launch_T& launcher, const typename Launch_T::Param& args, parallel::IThreading* th) {
820+
void GemmRun(const typename Launch_T::Param& args, parallel::IThreading* th) {
814821
gemm::SchedulerDispatcher<Parallel_T> para(th, args.problem);
815822
static bool flag = false;
816823
if (flag) {
@@ -822,16 +829,16 @@ void GemmRun(Launch_T& launcher, const typename Launch_T::Param& args, parallel:
822829
typename Parallel_T::ThreadProblem thdp{tidx};
823830
para.getIndex(thdp);
824831
if (thdp.valid) {
825-
launcher.run(args, thdp);
832+
Launch_T::run(args, thdp);
826833
}
827834
});
828835
}
829836

830837
template <class Parallel_T, class Launch_T>
831-
void GemmRunWithA(Launch_T& launcher, const typename Launch_T::Param& args, parallel::IThreading* th) {
838+
void GemmRunWithA(const typename Launch_T::Param& args, parallel::IThreading* th) {
832839
gemm::SchedulerDispatcher<Parallel_T> para(th, args.problem);
833840
using AParall = typename Launch_T::PrologueA::Parallel;
834-
AParall apara = launcher.mProA.createParallel(th->num_threads(), args.problem);
841+
AParall apara = Launch_T::PrologueA::createParallel(th->num_threads(), args.problem);
835842
static bool flag = false;
836843
if (flag) {
837844
printf("%s\n", __FUNCTION__);
@@ -842,13 +849,13 @@ void GemmRunWithA(Launch_T& launcher, const typename Launch_T::Param& args, para
842849
typename AParall::ThreadProblem thdpA{tidx};
843850
apara.getIndex(thdpA);
844851
if (thdpA.valid) {
845-
launcher.mProA.run(args.paramA, thdpA);
852+
Launch_T::PrologueA::run(args.paramA, thdpA);
846853
}
847854
th->sync(tidx);
848855
typename Parallel_T::ThreadProblem thdp{tidx};
849856
para.getIndex(thdp);
850857
if (thdp.valid) {
851-
launcher.run(args, thdp);
858+
Launch_T::run(args, thdp);
852859
}
853860
});
854861
}

0 commit comments

Comments
 (0)