Skip to content

Commit ed84374

Browse files
committed
merge main; Add torch q8 linear
1 parent 6c4ed59 commit ed84374

27 files changed

+1561
-114
lines changed

Diff for: ktransformers/ktransformers_ext/CMakeLists.txt

+39
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ endif()
3232
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
3333
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
3434
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
35+
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
3536

3637
# Architecture specific
3738
# TODO: probably these flags need to be tweaked on some architectures
@@ -201,6 +202,31 @@ endif()
201202
# message(STATUS "Can't found CUDA lib")
202203
# endif()
203204

205+
# --- Vendor SDK install-path detection --------------------------------------
# Resolve ROCM_PATH / MUSA_PATH from the environment, falling back to the
# conventional install locations.  Each probe is guarded by its backend
# option so a CPU-only (or CUDA-only) configure does not inject stray SDK
# directories into CMAKE_PREFIX_PATH / CMAKE_MODULE_PATH.
if (KTRANSFORMERS_USE_ROCM)
    if (NOT EXISTS $ENV{ROCM_PATH})
        if (NOT EXISTS /opt/rocm)
            set(ROCM_PATH /usr)
        else()
            set(ROCM_PATH /opt/rocm)
        endif()
    else()
        set(ROCM_PATH $ENV{ROCM_PATH})
    endif()
    # Let find_package(HIP) locate hip-config.cmake inside the ROCm tree.
    list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
    list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
endif()

if (KTRANSFORMERS_USE_MUSA)
    if (NOT EXISTS $ENV{MUSA_PATH})
        if (NOT EXISTS /opt/musa)
            set(MUSA_PATH /usr/local/musa)
        else()
            set(MUSA_PATH /opt/musa)
        endif()
    else()
        set(MUSA_PATH $ENV{MUSA_PATH})
    endif()
    # MUSA ships its CMake modules under <sdk>/cmake.
    # NOTE(review): the same MUSA_PATH probe is repeated inside the
    # KTRANSFORMERS_USE_MUSA branch further down — consider deduplicating.
    list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
endif()
229+
204230
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
205231
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
206232

@@ -218,6 +244,14 @@ elseif (UNIX)
218244
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
219245
endif()
220246

247+
if (KTRANSFORMERS_USE_ROCM)
    # REQUIRED makes find_package fatal on failure, so reaching the next
    # line guarantees HIP was found — the old `if(HIP_FOUND)` guard was dead.
    find_package(HIP REQUIRED)
    include_directories("${HIP_INCLUDE_DIRS}")
    add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
endif()
254+
221255
if (KTRANSFORMERS_USE_MUSA)
222256
if (NOT EXISTS $ENV{MUSA_PATH})
223257
if (NOT EXISTS /opt/musa)
@@ -258,6 +292,11 @@ elseif(UNIX)
258292
endif()
259293
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
260294
endif()
295+
if (KTRANSFORMERS_USE_ROCM)
    # Scope the define to this target: a directory-level
    # add_compile_definitions() issued after the target was created may no
    # longer be picked up by ${PROJECT_NAME}.
    target_compile_definitions(${PROJECT_NAME} PRIVATE USE_HIP=1)
    # NOTE(review): find_package(HIP) exports the hip::host imported target;
    # linking that would be preferable to a hard-coded .so path — confirm and
    # switch once verified on the supported ROCm versions.
    target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
    message(STATUS "Building for HIP")
endif()
261300
if(KTRANSFORMERS_USE_MUSA)
262301
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
263302
endif()

Diff for: ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h

+80-76
Original file line numberDiff line numberDiff line change
@@ -7,79 +7,83 @@
77
* @LastEditTime : 2024-08-07 09:47:43
88
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
99
**/
10-
#ifndef CPUINFER_CPUINFER_H
11-
#define CPUINFER_CPUINFER_H
12-
13-
#include <atomic>
14-
#include <condition_variable>
15-
#include <functional>
16-
#include <mutex>
17-
#include <queue>
18-
#include <thread>
19-
#include <vector>
20-
#ifdef KTRANSFORMERS_USE_CUDA
21-
#include "vendors/cuda.h"
22-
#elif KTRANSFORMERS_USE_MUSA
23-
#include "vendors/musa.h"
24-
#endif
25-
26-
#include "backend.h"
27-
#include "task_queue.h"
28-
29-
#include "llama.cpp/ggml-impl.h"
30-
31-
class CPUInfer {
32-
public:
33-
CPUInfer(int thread_num) {
34-
backend_ = new Backend(thread_num - 1);
35-
task_queue_ = new TaskQueue();
36-
for (int i = 0; i < (1 << 16); ++i) {
37-
ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
38-
}
39-
}
40-
41-
~CPUInfer() {
42-
delete backend_;
43-
delete task_queue_;
44-
}
45-
46-
template <typename Func, typename Obj, typename... Args>
47-
void enqueue(Func f, Obj* obj, Args... args) {
48-
task_queue_->enqueue([=]() {
49-
std::invoke(f, *obj, args..., backend_);
50-
});
51-
}
52-
53-
void submit(std::pair<intptr_t, intptr_t> params) {
54-
void (*func)(void*) = (void (*)(void*))params.first;
55-
void* args = (void*)params.second;
56-
*((CPUInfer**)args) = this;
57-
func(args);
58-
}
59-
60-
void sync() {
61-
task_queue_->sync();
62-
}
63-
64-
void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
65-
void (*func)(void*) = (void (*)(void*))params.first;
66-
void* args = (void*)params.second;
67-
*((CPUInfer**)args) = this;
68-
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
69-
}
70-
71-
static void sync_(void* cpu_infer_ptr) {
72-
CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
73-
cpuinfer->sync();
74-
}
75-
76-
void sync_with_cuda_stream(intptr_t user_cuda_stream) {
77-
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
78-
}
79-
80-
public:
81-
Backend* backend_;
82-
TaskQueue* task_queue_;
83-
};
84-
85-
#endif
10+
#ifndef CPUINFER_CPUINFER_H
#define CPUINFER_CPUINFER_H

#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
// Per-vendor GPU runtime shim; at most one backend macro is defined at
// configure time.  Use defined(...) so the test does not depend on the
// macro expanding to a truthy value.
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif defined(KTRANSFORMERS_USE_MUSA)
#include "vendors/musa.h"
#elif defined(KTRANSFORMERS_USE_ROCM)
// Must be defined before any HIP header so the AMD platform is selected.
#define __HIP_PLATFORM_AMD__
#include "vendors/hip.h"
#endif

#include "backend.h"
#include "task_queue.h"
#include "../vendors/vendor.h"

#include "llama.cpp/ggml-impl.h"

/**
 * Host-side inference driver.
 *
 * Owns a worker-thread Backend and a serial TaskQueue, and can bridge
 * CPU-side tasks onto a GPU stream via host-function callbacks
 * (cudaLaunchHostFunc, or its vendor alias from the shim headers).
 */
class CPUInfer {
public:
    CPUInfer(int thread_num) {
        // The calling thread acts as one worker itself, hence thread_num - 1.
        backend_ = new Backend(thread_num - 1);
        task_queue_ = new TaskQueue();
        // Populate ggml's global fp16 -> fp32 lookup table for all 2^16
        // half-precision bit patterns.
        for (int i = 0; i < (1 << 16); ++i) {
            ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);
        }
    }

    ~CPUInfer() {
        delete backend_;
        delete task_queue_;
    }

    // backend_/task_queue_ are owning raw pointers: copying a CPUInfer
    // would double-delete them in the two destructors, so copying (and
    // with it moving) is forbidden.
    CPUInfer(const CPUInfer&) = delete;
    CPUInfer& operator=(const CPUInfer&) = delete;

    /**
     * Queue `obj->f(args..., backend_)` for asynchronous execution on the
     * task queue.  Arguments are captured by value.
     */
    template <typename Func, typename Obj, typename... Args>
    void enqueue(Func f, Obj* obj, Args... args) {
        task_queue_->enqueue([=]() {
            std::invoke(f, *obj, args..., backend_);
        });
    }

    /**
     * Run a packed (function pointer, args struct) pair synchronously.
     * The first pointer-sized slot of the args struct is overwritten with
     * `this` so the callee can reach back into this CPUInfer.
     */
    void submit(std::pair<intptr_t, intptr_t> params) {
        void (*func)(void*) = (void (*)(void*))params.first;
        void* args = (void*)params.second;
        *((CPUInfer**)args) = this;
        func(args);
    }

    // Block until every task enqueued so far has completed.
    void sync() {
        task_queue_->sync();
    }

    /**
     * Like submit(), but deferred: the call fires when the given GPU
     * stream reaches this point, via a host-function launch on the stream.
     */
    void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {
        void (*func)(void*) = (void (*)(void*))params.first;
        void* args = (void*)params.second;
        *((CPUInfer**)args) = this;
        cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);
    }

    // Trampoline so sync() can be launched as a stream host function.
    static void sync_(void* cpu_infer_ptr) {
        CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;
        cpuinfer->sync();
    }

    // Enqueue a sync() on the given GPU stream.
    void sync_with_cuda_stream(intptr_t user_cuda_stream) {
        cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);
    }

public:
    Backend* backend_;      // owned; worker-thread pool
    TaskQueue* task_queue_; // owned; serial task queue
};

#endif
Diff for: ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
11
// CUDA vendor shim: pulls in the CUDA runtime/driver, cuBLAS and the
// half/bfloat16 types used by the cpu_backend GPU bridge.
#pragma once

#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Compatibility shims for CUDA toolkits older than 11.2, which predate the
// renamed virtual-memory device attribute and the cublasComputeType_t API:
// alias the newer identifiers to their pre-11.2 equivalents.
// NOTE(review): appears to mirror llama.cpp's ggml CUDA compat defines —
// confirm against the vendored llama.cpp version when updating.
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020

0 commit comments

Comments
 (0)