Skip to content

Commit ae1f211

Browse files
authored
cuda : refactor into multiple files (ggml-org#6269)
1 parent ad3a050 commit ae1f211

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

59 files changed

+9254
-9087
lines changed

.clang-tidy

+1
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ Checks: >
1212
-readability-implicit-bool-conversion,
1313
-readability-magic-numbers,
1414
-readability-uppercase-literal-suffix,
15+
-readability-simplify-boolean-expr,
1516
clang-analyzer-*,
1617
-clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,
1718
performance-*,

CMakeLists.txt

+7-3
Original file line number | Diff line number | Diff line change
@@ -369,7 +369,9 @@ if (LLAMA_CUBLAS)
369369
enable_language(CUDA)
370370

371371
set(GGML_HEADERS_CUDA ggml-cuda.h)
372-
set(GGML_SOURCES_CUDA ggml-cuda.cu)
372+
373+
file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu")
374+
list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
373375

374376
add_compile_definitions(GGML_USE_CUBLAS)
375377
if (LLAMA_CUDA_FORCE_DMMV)
@@ -519,7 +521,9 @@ if (LLAMA_HIPBLAS)
519521
message(STATUS "HIP and hipBLAS found")
520522

521523
set(GGML_HEADERS_ROCM ggml-cuda.h)
522-
set(GGML_SOURCES_ROCM ggml-cuda.cu)
524+
525+
file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu")
526+
list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
523527

524528
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
525529

@@ -543,7 +547,7 @@ if (LLAMA_HIPBLAS)
543547
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
544548
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
545549

546-
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
550+
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
547551

548552
if (LLAMA_STATIC)
549553
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")

Makefile

+20-3
Original file line number | Diff line number | Diff line change
@@ -398,6 +398,7 @@ ifdef LLAMA_CUBLAS
398398
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
399399
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
400400
OBJS += ggml-cuda.o
401+
OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
401402
MK_NVCCFLAGS += -use_fast_math
402403
ifdef LLAMA_FATAL_WARNINGS
403404
MK_NVCCFLAGS += -Werror all-warnings
@@ -458,12 +459,23 @@ endif # LLAMA_CUDA_NO_PEER_COPY
458459
ifdef LLAMA_CUDA_CCBIN
459460
MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
460461
endif
461-
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml-common.h
462+
462463
ifdef JETSON_EOL_MODULE_DETECT
464+
define NVCC_COMPILE
463465
$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
466+
endef # NVCC_COMPILE
464467
else
468+
define NVCC_COMPILE
465469
$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
470+
endef # NVCC_COMPILE
466471
endif # JETSON_EOL_MODULE_DETECT
472+
473+
ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
474+
$(NVCC_COMPILE)
475+
476+
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
477+
$(NVCC_COMPILE)
478+
467479
endif # LLAMA_CUBLAS
468480

469481
ifdef LLAMA_CLBLAST
@@ -510,7 +522,6 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
510522
endif # LLAMA_VULKAN
511523

512524
ifdef LLAMA_HIPBLAS
513-
514525
ifeq ($(wildcard /opt/rocm),)
515526
ROCM_PATH ?= /usr
516527
GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
@@ -539,8 +550,13 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
539550
HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
540551
endif # LLAMA_CUDA_NO_PEER_COPY
541552
OBJS += ggml-cuda.o
542-
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
553+
OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
554+
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
555+
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
556+
557+
ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
543558
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
559+
544560
endif # LLAMA_HIPBLAS
545561

546562
ifdef LLAMA_METAL
@@ -687,6 +703,7 @@ libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
687703

688704
clean:
689705
rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
706+
rm -vrf ggml-cuda/*.o
690707
find examples pocs -type f -name "*.o" -delete
691708

692709
#

0 commit comments

Comments
 (0)