7
7
* @LastEditTime : 2024-08-07 09:47:43
8
8
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
9
9
**/
10
- #ifndef CPUINFER_CPUINFER_H
11
- #define CPUINFER_CPUINFER_H
12
-
13
- #include < atomic>
14
- #include < condition_variable>
15
- #include < functional>
16
- #include < mutex>
17
- #include < queue>
18
- #include < thread>
19
- #include < vector>
20
- #ifdef KTRANSFORMERS_USE_CUDA
21
- #include " vendors/cuda.h"
22
- #elif KTRANSFORMERS_USE_MUSA
23
- #include " vendors/musa.h"
24
- #endif
25
-
26
- #include " backend.h"
27
- #include " task_queue.h"
28
-
29
- #include " llama.cpp/ggml-impl.h"
30
-
31
- class CPUInfer {
32
- public:
33
- CPUInfer (int thread_num) {
34
- backend_ = new Backend (thread_num - 1 );
35
- task_queue_ = new TaskQueue ();
36
- for (int i = 0 ; i < (1 << 16 ); ++i) {
37
- ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32 (i);
38
- }
39
- }
40
-
41
- ~CPUInfer () {
42
- delete backend_;
43
- delete task_queue_;
44
- }
45
-
46
- template <typename Func, typename Obj, typename ... Args>
47
- void enqueue (Func f, Obj* obj, Args... args) {
48
- task_queue_->enqueue ([=]() {
49
- std::invoke (f, *obj, args..., backend_);
50
- });
51
- }
52
-
53
- void submit (std::pair<intptr_t , intptr_t > params) {
54
- void (*func)(void *) = (void (*)(void *))params.first ;
55
- void * args = (void *)params.second ;
56
- *((CPUInfer**)args) = this ;
57
- func (args);
58
- }
59
-
60
- void sync () {
61
- task_queue_->sync ();
62
- }
63
-
64
- void submit_with_cuda_stream (intptr_t user_cuda_stream, std::pair<intptr_t , intptr_t > params) {
65
- void (*func)(void *) = (void (*)(void *))params.first ;
66
- void * args = (void *)params.second ;
67
- *((CPUInfer**)args) = this ;
68
- cudaLaunchHostFunc ((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
69
- }
70
-
71
- static void sync_ (void * cpu_infer_ptr) {
72
- CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
73
- cpuinfer->sync ();
74
- }
75
-
76
- void sync_with_cuda_stream (intptr_t user_cuda_stream) {
77
- cudaLaunchHostFunc ((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void *)this );
78
- }
79
-
80
- public:
81
- Backend* backend_;
82
- TaskQueue* task_queue_;
83
- };
84
-
85
- #endif
10
+ #ifndef CPUINFER_CPUINFER_H
11
+ #define CPUINFER_CPUINFER_H
12
+
13
+ #include < atomic>
14
+ #include < condition_variable>
15
+ #include < functional>
16
+ #include < mutex>
17
+ #include < queue>
18
+ #include < thread>
19
+ #include < vector>
20
+ #ifdef KTRANSFORMERS_USE_CUDA
21
+ #include " vendors/cuda.h"
22
+ #elif KTRANSFORMERS_USE_MUSA
23
+ #include " vendors/musa.h"
24
+ #elif KTRANSFORMERS_USE_ROCM
25
+ #define __HIP_PLATFORM_AMD__
26
+ #include " vendors/hip.h"
27
+ #endif
28
+
29
+ #include " backend.h"
30
+ #include " task_queue.h"
31
+ #include " ../vendors/vendor.h"
32
+
33
+ #include " llama.cpp/ggml-impl.h"
34
+
35
+ class CPUInfer {
36
+ public:
37
+ CPUInfer (int thread_num) {
38
+ backend_ = new Backend (thread_num - 1 );
39
+ task_queue_ = new TaskQueue ();
40
+ for (int i = 0 ; i < (1 << 16 ); ++i) {
41
+ ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32 (i);
42
+ }
43
+ }
44
+
45
+ ~CPUInfer () {
46
+ delete backend_;
47
+ delete task_queue_;
48
+ }
49
+
50
+ template <typename Func, typename Obj, typename ... Args>
51
+ void enqueue (Func f, Obj* obj, Args... args) {
52
+ task_queue_->enqueue ([=]() {
53
+ std::invoke (f, *obj, args..., backend_);
54
+ });
55
+ }
56
+
57
+ void submit (std::pair<intptr_t , intptr_t > params) {
58
+ void (*func)(void *) = (void (*)(void *))params.first ;
59
+ void * args = (void *)params.second ;
60
+ *((CPUInfer**)args) = this ;
61
+ func (args);
62
+ }
63
+
64
+ void sync () {
65
+ task_queue_->sync ();
66
+ }
67
+
68
+ void submit_with_cuda_stream (intptr_t user_cuda_stream, std::pair<intptr_t , intptr_t > params) {
69
+ void (*func)(void *) = (void (*)(void *))params.first ;
70
+ void * args = (void *)params.second ;
71
+ *((CPUInfer**)args) = this ;
72
+ cudaLaunchHostFunc ((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
73
+ }
74
+
75
+ static void sync_ (void * cpu_infer_ptr) {
76
+ CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
77
+ cpuinfer->sync ();
78
+ }
79
+
80
+ void sync_with_cuda_stream (intptr_t user_cuda_stream) {
81
+ cudaLaunchHostFunc ((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void *)this );
82
+ }
83
+
84
+ public:
85
+ Backend* backend_;
86
+ TaskQueue* task_queue_;
87
+ };
88
+
89
+ #endif
0 commit comments