@@ -746,36 +746,74 @@ static const std::string gpu_pipeline =
746746 "func.func(convert-parallel-loops-to-gpu),"
747747 // insert-gpu-allocs pass can have client-api = opencl or vulkan args
748748 "func.func(insert-gpu-allocs{in-regions=1}),"
749- // ** imex GPU passes
750- // "drop-regions,"
751- // "canonicalize,"
752- // // "normalize-memrefs,"
753- // // "gpu-decompose-memrefs,"
754- // "func.func(lower-affine),"
755- // "gpu-kernel-outlining,"
756- // "canonicalize,"
757- // "cse,"
758- // // The following set-spirv-* passes can have client-api = opencl or
759- // vulkan
760- // // args
761- // "set-spirv-capabilities{client-api=opencl},"
762- // "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
763- // "canonicalize,"
764- // "fold-memref-alias-ops,"
765- // "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
766- // "spirv.module(spirv-lower-abi-attrs),"
767- // "spirv.module(spirv-update-vce),"
768- // // "func.func(llvm-request-c-wrappers),"
769- // "serialize-spirv,"
770- // "expand-strided-metadata,"
771- // "lower-affine,"
772- // "convert-gpu-to-gpux,"
773- // "convert-func-to-llvm,"
774- // "convert-math-to-llvm,"
775- // "convert-gpux-to-llvm,"
776- // "finalize-memref-to-llvm,"
777- // "reconcile-unrealized-casts";
778- // ** nv GPU passes
749+ "drop-regions,"
750+ "canonicalize,"
751+ // "normalize-memrefs,"
752+ // "gpu-decompose-memrefs,"
753+ "func.func(lower-affine),"
754+ "gpu-kernel-outlining,"
755+ "canonicalize,"
756+ "cse,"
757+ // The following set-spirv-* passes can have client-api = opencl or vulkan
758+ // args
759+ "set-spirv-capabilities{client-api=opencl},"
760+ "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
761+ "canonicalize,"
762+ "fold-memref-alias-ops,"
763+ "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
764+ "spirv.module(spirv-lower-abi-attrs),"
765+ "spirv.module(spirv-update-vce),"
766+ // "func.func(llvm-request-c-wrappers),"
767+ "serialize-spirv,"
768+ "expand-strided-metadata,"
769+ "lower-affine,"
770+ "convert-gpu-to-gpux,"
771+ "convert-func-to-llvm,"
772+ "convert-math-to-llvm,"
773+ "convert-gpux-to-llvm,"
774+ "finalize-memref-to-llvm,"
775+ "reconcile-unrealized-casts";
776+
777+ static const std::string cuda_pipeline =
778+ "add-gpu-regions,"
779+ "canonicalize,"
780+ "ndarray-dist,"
781+ "func.func(dist-coalesce),"
782+ "func.func(dist-infer-elementwise-cores),"
783+ "convert-dist-to-standard,"
784+ "canonicalize,"
785+ "overlap-comm-and-compute,"
786+ "add-comm-cache-keys,"
787+ "lower-distruntime-to-idtr,"
788+ "convert-ndarray-to-linalg,"
789+ "canonicalize,"
790+ "func.func(tosa-make-broadcastable),"
791+ "func.func(tosa-to-linalg),"
792+ "func.func(tosa-to-tensor),"
793+ "canonicalize,"
794+ "linalg-fuse-elementwise-ops,"
795+ "arith-expand,"
796+ "memref-expand,"
797+ "arith-bufferize,"
798+ "func-bufferize,"
799+ "func.func(empty-tensor-to-alloc-tensor),"
800+ "func.func(scf-bufferize),"
801+ "func.func(tensor-bufferize),"
802+ "func.func(bufferization-bufferize),"
803+ "func.func(linalg-bufferize),"
804+ "func.func(linalg-detensorize),"
805+ "func.func(tensor-bufferize),"
806+ "region-bufferize,"
807+ "canonicalize,"
808+ "func.func(finalizing-bufferize),"
809+ "imex-remove-temporaries,"
810+ "func.func(convert-linalg-to-parallel-loops),"
811+ "func.func(scf-parallel-loop-fusion),"
812+ // is add-outer-parallel-loop needed?
813+ "func.func(imex-add-outer-parallel-loop),"
814+ "func.func(gpu-map-parallel-loops),"
815+ "func.func(convert-parallel-loops-to-gpu),"
816+ "func.func(insert-gpu-allocs{in-regions=1}),"
779817 "func.func(insert-gpu-copy),"
780818 "drop-regions,"
781819 "canonicalize,"
@@ -797,7 +835,9 @@ static const std::string gpu_pipeline =
797835
798836const std::string _passes(get_text_env("SHARPY_PASSES"));
799837static const std::string &pass_pipeline =
800- _passes != "" ? _passes : (useGPU() ? gpu_pipeline : cpu_pipeline);
838+ _passes != "" ? _passes
839+ : (useGPU() ? (useCUDA() ? cuda_pipeline : gpu_pipeline)
840+ : cpu_pipeline);
801841
802842JIT::JIT(const std::string &libidtr)
803843 : _context(::mlir::MLIRContext::Threading::DISABLED), _pm(&_context),
@@ -849,23 +889,24 @@ JIT::JIT(const std::string &libidtr)
849889 _crunnerlib = mlirRoot + "/lib/libmlir_c_runner_utils.so";
850890 _runnerlib = mlirRoot + "/lib/libmlir_runner_utils.so";
851891 if (!std::ifstream(_crunnerlib)) {
852- throw std::runtime_error("Cannot find libmlir_c_runner_utils.so" );
892+ throw std::runtime_error("Cannot find lib: " + _crunnerlib );
853893 }
854894 if (!std::ifstream(_runnerlib)) {
855- throw std::runtime_error("Cannot find libmlir_runner_utils.so" );
895+ throw std::runtime_error("Cannot find lib: " + _runnerlib );
856896 }
857897
858898 if (useGPU()) {
859899 auto gpuxlibstr = get_text_env("SHARPY_GPUX_SO");
860900 if (!gpuxlibstr.empty()) {
861901 _gpulib = std::string(gpuxlibstr);
862902 } else {
863- // auto imexRoot = get_text_env("IMEXROOT");
864- // imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
865- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
866- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
867- // for nv gpu
868- _gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
903+ if (useCUDA()) {
904+ _gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
905+ } else {
906+ auto imexRoot = get_text_env("IMEXROOT");
907+ imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
908+ _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
909+ }
869910 if (!std::ifstream(_gpulib)) {
870911 throw std::runtime_error("Cannot find lib: " + _gpulib);
871912 }
0 commit comments