Skip to content

Commit 8914edd

Browse files
committed
add SHARPY_USE_CUDA boolean to activate cuda pipeline
1 parent 71db3d7 commit 8914edd

File tree

2 files changed

+82
-39
lines changed

2 files changed

+82
-39
lines changed

src/include/sharpy/UtilsAndTypes.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,5 @@ inline bool useGPU() {
7070
auto device = get_text_env("SHARPY_DEVICE");
7171
return !(device.empty() || device == "host" || device == "cpu");
7272
}
73+
74+
// Reports the boolean obtained from get_bool_env for the "SHARPY_USE_CUDA"
// setting (presumably an environment variable, like SHARPY_DEVICE above).
inline bool useCUDA() {
  const auto cudaRequested = get_bool_env("SHARPY_USE_CUDA");
  return cudaRequested;
}

src/jit/mlir.cpp

+80-39
Original file line numberDiff line numberDiff line change
@@ -748,36 +748,74 @@ static const std::string gpu_pipeline =
748748
"func.func(convert-parallel-loops-to-gpu),"
749749
// insert-gpu-allocs pass can have client-api = opencl or vulkan args
750750
"func.func(insert-gpu-allocs{in-regions=1}),"
751-
// ** imex GPU passes
752-
// "drop-regions,"
753-
// "canonicalize,"
754-
// // "normalize-memrefs,"
755-
// // "gpu-decompose-memrefs,"
756-
// "func.func(lower-affine),"
757-
// "gpu-kernel-outlining,"
758-
// "canonicalize,"
759-
// "cse,"
760-
// // The following set-spirv-* passes can have client-api = opencl or
761-
// vulkan
762-
// // args
763-
// "set-spirv-capabilities{client-api=opencl},"
764-
// "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
765-
// "canonicalize,"
766-
// "fold-memref-alias-ops,"
767-
// "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
768-
// "spirv.module(spirv-lower-abi-attrs),"
769-
// "spirv.module(spirv-update-vce),"
770-
// // "func.func(llvm-request-c-wrappers),"
771-
// "serialize-spirv,"
772-
// "expand-strided-metadata,"
773-
// "lower-affine,"
774-
// "convert-gpu-to-gpux,"
775-
// "convert-func-to-llvm,"
776-
// "convert-math-to-llvm,"
777-
// "convert-gpux-to-llvm,"
778-
// "finalize-memref-to-llvm,"
779-
// "reconcile-unrealized-casts";
780-
// ** nv GPU passes
751+
"drop-regions,"
752+
"canonicalize,"
753+
// "normalize-memrefs,"
754+
// "gpu-decompose-memrefs,"
755+
"func.func(lower-affine),"
756+
"gpu-kernel-outlining,"
757+
"canonicalize,"
758+
"cse,"
759+
// The following set-spirv-* passes can have client-api = opencl or vulkan
760+
// args
761+
"set-spirv-capabilities{client-api=opencl},"
762+
"gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
763+
"canonicalize,"
764+
"fold-memref-alias-ops,"
765+
"imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
766+
"spirv.module(spirv-lower-abi-attrs),"
767+
"spirv.module(spirv-update-vce),"
768+
// "func.func(llvm-request-c-wrappers),"
769+
"serialize-spirv,"
770+
"expand-strided-metadata,"
771+
"lower-affine,"
772+
"convert-gpu-to-gpux,"
773+
"convert-func-to-llvm,"
774+
"convert-math-to-llvm,"
775+
"convert-gpux-to-llvm,"
776+
"finalize-memref-to-llvm,"
777+
"reconcile-unrealized-casts";
778+
779+
static const std::string cuda_pipeline =
780+
"add-gpu-regions,"
781+
"canonicalize,"
782+
"ndarray-dist,"
783+
"func.func(dist-coalesce),"
784+
"func.func(dist-infer-elementwise-cores),"
785+
"convert-dist-to-standard,"
786+
"canonicalize,"
787+
"overlap-comm-and-compute,"
788+
"add-comm-cache-keys,"
789+
"lower-distruntime-to-idtr,"
790+
"convert-ndarray-to-linalg,"
791+
"canonicalize,"
792+
"func.func(tosa-make-broadcastable),"
793+
"func.func(tosa-to-linalg),"
794+
"func.func(tosa-to-tensor),"
795+
"canonicalize,"
796+
"linalg-fuse-elementwise-ops,"
797+
"arith-expand,"
798+
"memref-expand,"
799+
"arith-bufferize,"
800+
"func-bufferize,"
801+
"func.func(empty-tensor-to-alloc-tensor),"
802+
"func.func(scf-bufferize),"
803+
"func.func(tensor-bufferize),"
804+
"func.func(bufferization-bufferize),"
805+
"func.func(linalg-bufferize),"
806+
"func.func(linalg-detensorize),"
807+
"func.func(tensor-bufferize),"
808+
"region-bufferize,"
809+
"canonicalize,"
810+
"func.func(finalizing-bufferize),"
811+
"imex-remove-temporaries,"
812+
"func.func(convert-linalg-to-parallel-loops),"
813+
"func.func(scf-parallel-loop-fusion),"
814+
// TODO: confirm whether the imex-add-outer-parallel-loop pass is needed here.
815+
"func.func(imex-add-outer-parallel-loop),"
816+
"func.func(gpu-map-parallel-loops),"
817+
"func.func(convert-parallel-loops-to-gpu),"
818+
"func.func(insert-gpu-allocs{in-regions=1}),"
781819
"func.func(insert-gpu-copy),"
782820
"drop-regions,"
783821
"canonicalize,"
@@ -799,7 +837,9 @@ static const std::string gpu_pipeline =
799837

800838
const std::string _passes(get_text_env("SHARPY_PASSES"));
801839
static const std::string &pass_pipeline =
802-
_passes != "" ? _passes : (useGPU() ? gpu_pipeline : cpu_pipeline);
840+
_passes != "" ? _passes
841+
: (useGPU() ? (useCUDA() ? cuda_pipeline : gpu_pipeline)
842+
: cpu_pipeline);
803843

804844
JIT::JIT(const std::string &libidtr)
805845
: _context(::mlir::MLIRContext::Threading::DISABLED), _pm(&_context),
@@ -851,23 +891,24 @@ JIT::JIT(const std::string &libidtr)
851891
_crunnerlib = mlirRoot + "/lib/libmlir_c_runner_utils.so";
852892
_runnerlib = mlirRoot + "/lib/libmlir_runner_utils.so";
853893
if (!std::ifstream(_crunnerlib)) {
854-
throw std::runtime_error("Cannot find libmlir_c_runner_utils.so");
894+
throw std::runtime_error("Cannot find lib: " + _crunnerlib);
855895
}
856896
if (!std::ifstream(_runnerlib)) {
857-
throw std::runtime_error("Cannot find libmlir_runner_utils.so");
897+
throw std::runtime_error("Cannot find lib: " + _runnerlib);
858898
}
859899

860900
if (useGPU()) {
861901
auto gpuxlibstr = get_text_env("SHARPY_GPUX_SO");
862902
if (!gpuxlibstr.empty()) {
863903
_gpulib = std::string(gpuxlibstr);
864904
} else {
865-
// auto imexRoot = get_text_env("IMEXROOT");
866-
// imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
867-
// _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
868-
// _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
869-
// for nv gpu
870-
_gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
905+
if (useCUDA()) {
906+
_gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
907+
} else {
908+
auto imexRoot = get_text_env("IMEXROOT");
909+
imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
910+
_gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
911+
}
871912
if (!std::ifstream(_gpulib)) {
872913
throw std::runtime_error("Cannot find lib: " + _gpulib);
873914
}

0 commit comments

Comments
 (0)