Skip to content

Commit 7dc0831

Browse files
committed
Add SHARPY_USE_CUDA boolean to activate the CUDA pipeline
1 parent c1903b1 commit 7dc0831

File tree

2 files changed

+82
-39
lines changed

2 files changed

+82
-39
lines changed

src/include/sharpy/UtilsAndTypes.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,5 @@ inline bool useGPU() {
7070
auto device = get_text_env("SHARPY_DEVICE");
7171
return !(device.empty() || device == "host" || device == "cpu");
7272
}
73+
74+
/// Reports whether the "SHARPY_USE_CUDA" switch is turned on.
///
/// The result is whatever get_bool_env("SHARPY_USE_CUDA") yields, converted
/// to bool; no other condition is checked here.
/// NOTE(review): get_bool_env is defined elsewhere — presumably it reads an
/// environment variable, like get_text_env used by useGPU(); confirm.
///
/// @return true when get_bool_env reports the switch as set, false otherwise.
inline bool useCUDA() {
  const bool cudaRequested = get_bool_env("SHARPY_USE_CUDA");
  return cudaRequested;
}

src/jit/mlir.cpp

+80-39
Original file line numberDiff line numberDiff line change
@@ -757,36 +757,74 @@ static const std::string gpu_pipeline =
757757
"func.func(convert-parallel-loops-to-gpu),"
758758
// insert-gpu-allocs pass can have client-api = opencl or vulkan args
759759
"func.func(insert-gpu-allocs{in-regions=1}),"
760-
// ** imex GPU passes
761-
// "drop-regions,"
762-
// "canonicalize,"
763-
// // "normalize-memrefs,"
764-
// // "gpu-decompose-memrefs,"
765-
// "func.func(lower-affine),"
766-
// "gpu-kernel-outlining,"
767-
// "canonicalize,"
768-
// "cse,"
769-
// // The following set-spirv-* passes can have client-api = opencl or
770-
// vulkan
771-
// // args
772-
// "set-spirv-capabilities{client-api=opencl},"
773-
// "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
774-
// "canonicalize,"
775-
// "fold-memref-alias-ops,"
776-
// "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
777-
// "spirv.module(spirv-lower-abi-attrs),"
778-
// "spirv.module(spirv-update-vce),"
779-
// // "func.func(llvm-request-c-wrappers),"
780-
// "serialize-spirv,"
781-
// "expand-strided-metadata,"
782-
// "lower-affine,"
783-
// "convert-gpu-to-gpux,"
784-
// "convert-func-to-llvm,"
785-
// "convert-math-to-llvm,"
786-
// "convert-gpux-to-llvm,"
787-
// "finalize-memref-to-llvm,"
788-
// "reconcile-unrealized-casts";
789-
// ** nv GPU passes
760+
"drop-regions,"
761+
"canonicalize,"
762+
// "normalize-memrefs,"
763+
// "gpu-decompose-memrefs,"
764+
"func.func(lower-affine),"
765+
"gpu-kernel-outlining,"
766+
"canonicalize,"
767+
"cse,"
768+
// The following set-spirv-* passes can have client-api = opencl or vulkan
769+
// args
770+
"set-spirv-capabilities{client-api=opencl},"
771+
"gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
772+
"canonicalize,"
773+
"fold-memref-alias-ops,"
774+
"imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
775+
"spirv.module(spirv-lower-abi-attrs),"
776+
"spirv.module(spirv-update-vce),"
777+
// "func.func(llvm-request-c-wrappers),"
778+
"serialize-spirv,"
779+
"expand-strided-metadata,"
780+
"lower-affine,"
781+
"convert-gpu-to-gpux,"
782+
"convert-func-to-llvm,"
783+
"convert-math-to-llvm,"
784+
"convert-gpux-to-llvm,"
785+
"finalize-memref-to-llvm,"
786+
"reconcile-unrealized-casts";
787+
788+
static const std::string cuda_pipeline =
789+
"add-gpu-regions,"
790+
"canonicalize,"
791+
"ndarray-dist,"
792+
"func.func(dist-coalesce),"
793+
"func.func(dist-infer-elementwise-cores),"
794+
"convert-dist-to-standard,"
795+
"canonicalize,"
796+
"overlap-comm-and-compute,"
797+
"add-comm-cache-keys,"
798+
"lower-distruntime-to-idtr,"
799+
"convert-ndarray-to-linalg,"
800+
"canonicalize,"
801+
"func.func(tosa-make-broadcastable),"
802+
"func.func(tosa-to-linalg),"
803+
"func.func(tosa-to-tensor),"
804+
"canonicalize,"
805+
"linalg-fuse-elementwise-ops,"
806+
"arith-expand,"
807+
"memref-expand,"
808+
"arith-bufferize,"
809+
"func-bufferize,"
810+
"func.func(empty-tensor-to-alloc-tensor),"
811+
"func.func(scf-bufferize),"
812+
"func.func(tensor-bufferize),"
813+
"func.func(bufferization-bufferize),"
814+
"func.func(linalg-bufferize),"
815+
"func.func(linalg-detensorize),"
816+
"func.func(tensor-bufferize),"
817+
"region-bufferize,"
818+
"canonicalize,"
819+
"func.func(finalizing-bufferize),"
820+
"imex-remove-temporaries,"
821+
"func.func(convert-linalg-to-parallel-loops),"
822+
"func.func(scf-parallel-loop-fusion),"
823+
// is add-outer-parallel-loop needed?
824+
"func.func(imex-add-outer-parallel-loop),"
825+
"func.func(gpu-map-parallel-loops),"
826+
"func.func(convert-parallel-loops-to-gpu),"
827+
"func.func(insert-gpu-allocs{in-regions=1}),"
790828
"func.func(insert-gpu-copy),"
791829
"drop-regions,"
792830
"canonicalize,"
@@ -808,7 +846,9 @@ static const std::string gpu_pipeline =
808846

809847
const std::string _passes(get_text_env("SHARPY_PASSES"));
810848
static const std::string &pass_pipeline =
811-
_passes != "" ? _passes : (useGPU() ? gpu_pipeline : cpu_pipeline);
849+
_passes != "" ? _passes
850+
: (useGPU() ? (useCUDA() ? cuda_pipeline : gpu_pipeline)
851+
: cpu_pipeline);
812852

813853
JIT::JIT(const std::string &libidtr)
814854
: _context(::mlir::MLIRContext::Threading::DISABLED), _pm(&_context),
@@ -860,23 +900,24 @@ JIT::JIT(const std::string &libidtr)
860900
_crunnerlib = mlirRoot + "/lib/libmlir_c_runner_utils.so";
861901
_runnerlib = mlirRoot + "/lib/libmlir_runner_utils.so";
862902
if (!std::ifstream(_crunnerlib)) {
863-
throw std::runtime_error("Cannot find libmlir_c_runner_utils.so");
903+
throw std::runtime_error("Cannot find lib: " + _crunnerlib);
864904
}
865905
if (!std::ifstream(_runnerlib)) {
866-
throw std::runtime_error("Cannot find libmlir_runner_utils.so");
906+
throw std::runtime_error("Cannot find lib: " + _runnerlib);
867907
}
868908

869909
if (useGPU()) {
870910
auto gpuxlibstr = get_text_env("SHARPY_GPUX_SO");
871911
if (!gpuxlibstr.empty()) {
872912
_gpulib = std::string(gpuxlibstr);
873913
} else {
874-
// auto imexRoot = get_text_env("IMEXROOT");
875-
// imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
876-
// _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
877-
// _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
878-
// for nv gpu
879-
_gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
914+
if (useCUDA()) {
915+
_gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
916+
} else {
917+
auto imexRoot = get_text_env("IMEXROOT");
918+
imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
919+
_gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
920+
}
880921
if (!std::ifstream(_gpulib)) {
881922
throw std::runtime_error("Cannot find lib: " + _gpulib);
882923
}

0 commit comments

Comments
 (0)