Skip to content

Commit eb75cd4

Browse files
committed
add SHARPY_USE_CUDA boolean to activate cuda pipeline
1 parent 2f2fbca commit eb75cd4

File tree

2 files changed

+82
-39
lines changed

2 files changed

+82
-39
lines changed

Diff for: src/include/sharpy/UtilsAndTypes.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,5 @@ inline bool useGPU() {
7070
auto device = get_text_env("SHARPY_DEVICE");
7171
return !(device.empty() || device == "host" || device == "cpu");
7272
}
73+
74+
// Tells callers whether the CUDA code path has been requested.
// The answer is whatever get_bool_env yields for the "SHARPY_USE_CUDA" key.
// NOTE(review): the accepted spellings and the result when the setting is
// absent are decided inside get_bool_env, which is defined elsewhere — confirm
// there before relying on a particular default.
inline bool useCUDA() {
  const bool cudaRequested = get_bool_env("SHARPY_USE_CUDA");
  return cudaRequested;
}

Diff for: src/jit/mlir.cpp

+80-39
Original file line numberDiff line numberDiff line change
@@ -746,36 +746,74 @@ static const std::string gpu_pipeline =
746746
"func.func(convert-parallel-loops-to-gpu),"
747747
// insert-gpu-allocs pass can have client-api = opencl or vulkan args
748748
"func.func(insert-gpu-allocs{in-regions=1}),"
749-
// ** imex GPU passes
750-
// "drop-regions,"
751-
// "canonicalize,"
752-
// // "normalize-memrefs,"
753-
// // "gpu-decompose-memrefs,"
754-
// "func.func(lower-affine),"
755-
// "gpu-kernel-outlining,"
756-
// "canonicalize,"
757-
// "cse,"
758-
// // The following set-spirv-* passes can have client-api = opencl or
759-
// vulkan
760-
// // args
761-
// "set-spirv-capabilities{client-api=opencl},"
762-
// "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
763-
// "canonicalize,"
764-
// "fold-memref-alias-ops,"
765-
// "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
766-
// "spirv.module(spirv-lower-abi-attrs),"
767-
// "spirv.module(spirv-update-vce),"
768-
// // "func.func(llvm-request-c-wrappers),"
769-
// "serialize-spirv,"
770-
// "expand-strided-metadata,"
771-
// "lower-affine,"
772-
// "convert-gpu-to-gpux,"
773-
// "convert-func-to-llvm,"
774-
// "convert-math-to-llvm,"
775-
// "convert-gpux-to-llvm,"
776-
// "finalize-memref-to-llvm,"
777-
// "reconcile-unrealized-casts";
778-
// ** nv GPU passes
749+
"drop-regions,"
750+
"canonicalize,"
751+
// "normalize-memrefs,"
752+
// "gpu-decompose-memrefs,"
753+
"func.func(lower-affine),"
754+
"gpu-kernel-outlining,"
755+
"canonicalize,"
756+
"cse,"
757+
// The following set-spirv-* passes can have client-api = opencl or vulkan
758+
// args
759+
"set-spirv-capabilities{client-api=opencl},"
760+
"gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
761+
"canonicalize,"
762+
"fold-memref-alias-ops,"
763+
"imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
764+
"spirv.module(spirv-lower-abi-attrs),"
765+
"spirv.module(spirv-update-vce),"
766+
// "func.func(llvm-request-c-wrappers),"
767+
"serialize-spirv,"
768+
"expand-strided-metadata,"
769+
"lower-affine,"
770+
"convert-gpu-to-gpux,"
771+
"convert-func-to-llvm,"
772+
"convert-math-to-llvm,"
773+
"convert-gpux-to-llvm,"
774+
"finalize-memref-to-llvm,"
775+
"reconcile-unrealized-casts";
776+
777+
static const std::string cuda_pipeline =
778+
"add-gpu-regions,"
779+
"canonicalize,"
780+
"ndarray-dist,"
781+
"func.func(dist-coalesce),"
782+
"func.func(dist-infer-elementwise-cores),"
783+
"convert-dist-to-standard,"
784+
"canonicalize,"
785+
"overlap-comm-and-compute,"
786+
"add-comm-cache-keys,"
787+
"lower-distruntime-to-idtr,"
788+
"convert-ndarray-to-linalg,"
789+
"canonicalize,"
790+
"func.func(tosa-make-broadcastable),"
791+
"func.func(tosa-to-linalg),"
792+
"func.func(tosa-to-tensor),"
793+
"canonicalize,"
794+
"linalg-fuse-elementwise-ops,"
795+
"arith-expand,"
796+
"memref-expand,"
797+
"arith-bufferize,"
798+
"func-bufferize,"
799+
"func.func(empty-tensor-to-alloc-tensor),"
800+
"func.func(scf-bufferize),"
801+
"func.func(tensor-bufferize),"
802+
"func.func(bufferization-bufferize),"
803+
"func.func(linalg-bufferize),"
804+
"func.func(linalg-detensorize),"
805+
"func.func(tensor-bufferize),"
806+
"region-bufferize,"
807+
"canonicalize,"
808+
"func.func(finalizing-bufferize),"
809+
"imex-remove-temporaries,"
810+
"func.func(convert-linalg-to-parallel-loops),"
811+
"func.func(scf-parallel-loop-fusion),"
812+
// is add-outer-parallel-loop needed?
813+
"func.func(imex-add-outer-parallel-loop),"
814+
"func.func(gpu-map-parallel-loops),"
815+
"func.func(convert-parallel-loops-to-gpu),"
816+
"func.func(insert-gpu-allocs{in-regions=1}),"
779817
"func.func(insert-gpu-copy),"
780818
"drop-regions,"
781819
"canonicalize,"
@@ -797,7 +835,9 @@ static const std::string gpu_pipeline =
797835

798836
const std::string _passes(get_text_env("SHARPY_PASSES"));
799837
static const std::string &pass_pipeline =
800-
_passes != "" ? _passes : (useGPU() ? gpu_pipeline : cpu_pipeline);
838+
_passes != "" ? _passes
839+
: (useGPU() ? (useCUDA() ? cuda_pipeline : gpu_pipeline)
840+
: cpu_pipeline);
801841

802842
JIT::JIT(const std::string &libidtr)
803843
: _context(::mlir::MLIRContext::Threading::DISABLED), _pm(&_context),
@@ -849,23 +889,24 @@ JIT::JIT(const std::string &libidtr)
849889
_crunnerlib = mlirRoot + "/lib/libmlir_c_runner_utils.so";
850890
_runnerlib = mlirRoot + "/lib/libmlir_runner_utils.so";
851891
if (!std::ifstream(_crunnerlib)) {
852-
throw std::runtime_error("Cannot find libmlir_c_runner_utils.so");
892+
throw std::runtime_error("Cannot find lib: " + _crunnerlib);
853893
}
854894
if (!std::ifstream(_runnerlib)) {
855-
throw std::runtime_error("Cannot find libmlir_runner_utils.so");
895+
throw std::runtime_error("Cannot find lib: " + _runnerlib);
856896
}
857897

858898
if (useGPU()) {
859899
auto gpuxlibstr = get_text_env("SHARPY_GPUX_SO");
860900
if (!gpuxlibstr.empty()) {
861901
_gpulib = std::string(gpuxlibstr);
862902
} else {
863-
// auto imexRoot = get_text_env("IMEXROOT");
864-
// imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
865-
// _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
866-
// _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
867-
// for nv gpu
868-
_gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
903+
if (useCUDA()) {
904+
_gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
905+
} else {
906+
auto imexRoot = get_text_env("IMEXROOT");
907+
imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
908+
_gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
909+
}
869910
if (!std::ifstream(_gpulib)) {
870911
throw std::runtime_error("Cannot find lib: " + _gpulib);
871912
}

0 commit comments

Comments
 (0)