@@ -746,36 +746,74 @@ static const std::string gpu_pipeline =
746
746
" func.func(convert-parallel-loops-to-gpu),"
747
747
// insert-gpu-allocs pass can have client-api = opencl or vulkan args
748
748
" func.func(insert-gpu-allocs{in-regions=1}),"
749
- // ** imex GPU passes
750
- // "drop-regions,"
751
- // "canonicalize,"
752
- // // "normalize-memrefs,"
753
- // // "gpu-decompose-memrefs,"
754
- // "func.func(lower-affine),"
755
- // "gpu-kernel-outlining,"
756
- // "canonicalize,"
757
- // "cse,"
758
- // // The following set-spirv-* passes can have client-api = opencl or
759
- // vulkan
760
- // // args
761
- // "set-spirv-capabilities{client-api=opencl},"
762
- // "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
763
- // "canonicalize,"
764
- // "fold-memref-alias-ops,"
765
- // "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
766
- // "spirv.module(spirv-lower-abi-attrs),"
767
- // "spirv.module(spirv-update-vce),"
768
- // // "func.func(llvm-request-c-wrappers),"
769
- // "serialize-spirv,"
770
- // "expand-strided-metadata,"
771
- // "lower-affine,"
772
- // "convert-gpu-to-gpux,"
773
- // "convert-func-to-llvm,"
774
- // "convert-math-to-llvm,"
775
- // "convert-gpux-to-llvm,"
776
- // "finalize-memref-to-llvm,"
777
- // "reconcile-unrealized-casts";
778
- // ** nv GPU passes
749
+ " drop-regions,"
750
+ " canonicalize,"
751
+ // "normalize-memrefs,"
752
+ // "gpu-decompose-memrefs,"
753
+ " func.func(lower-affine),"
754
+ " gpu-kernel-outlining,"
755
+ " canonicalize,"
756
+ " cse,"
757
+ // The following set-spirv-* passes can have client-api = opencl or vulkan
758
+ // args
759
+ " set-spirv-capabilities{client-api=opencl},"
760
+ " gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
761
+ " canonicalize,"
762
+ " fold-memref-alias-ops,"
763
+ " imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
764
+ " spirv.module(spirv-lower-abi-attrs),"
765
+ " spirv.module(spirv-update-vce),"
766
+ // "func.func(llvm-request-c-wrappers),"
767
+ " serialize-spirv,"
768
+ " expand-strided-metadata,"
769
+ " lower-affine,"
770
+ " convert-gpu-to-gpux,"
771
+ " convert-func-to-llvm,"
772
+ " convert-math-to-llvm,"
773
+ " convert-gpux-to-llvm,"
774
+ " finalize-memref-to-llvm,"
775
+ " reconcile-unrealized-casts" ;
776
+
777
+ static const std::string cuda_pipeline =
778
+ " add-gpu-regions,"
779
+ " canonicalize,"
780
+ " ndarray-dist,"
781
+ " func.func(dist-coalesce),"
782
+ " func.func(dist-infer-elementwise-cores),"
783
+ " convert-dist-to-standard,"
784
+ " canonicalize,"
785
+ " overlap-comm-and-compute,"
786
+ " add-comm-cache-keys,"
787
+ " lower-distruntime-to-idtr,"
788
+ " convert-ndarray-to-linalg,"
789
+ " canonicalize,"
790
+ " func.func(tosa-make-broadcastable),"
791
+ " func.func(tosa-to-linalg),"
792
+ " func.func(tosa-to-tensor),"
793
+ " canonicalize,"
794
+ " linalg-fuse-elementwise-ops,"
795
+ " arith-expand,"
796
+ " memref-expand,"
797
+ " arith-bufferize,"
798
+ " func-bufferize,"
799
+ " func.func(empty-tensor-to-alloc-tensor),"
800
+ " func.func(scf-bufferize),"
801
+ " func.func(tensor-bufferize),"
802
+ " func.func(bufferization-bufferize),"
803
+ " func.func(linalg-bufferize),"
804
+ " func.func(linalg-detensorize),"
805
+ " func.func(tensor-bufferize),"
806
+ " region-bufferize,"
807
+ " canonicalize,"
808
+ " func.func(finalizing-bufferize),"
809
+ " imex-remove-temporaries,"
810
+ " func.func(convert-linalg-to-parallel-loops),"
811
+ " func.func(scf-parallel-loop-fusion),"
812
+ // is add-outer-parallel-loop needed?
813
+ " func.func(imex-add-outer-parallel-loop),"
814
+ " func.func(gpu-map-parallel-loops),"
815
+ " func.func(convert-parallel-loops-to-gpu),"
816
+ " func.func(insert-gpu-allocs{in-regions=1}),"
779
817
" func.func(insert-gpu-copy),"
780
818
" drop-regions,"
781
819
" canonicalize,"
@@ -797,7 +835,9 @@ static const std::string gpu_pipeline =
797
835
798
836
const std::string _passes (get_text_env (" SHARPY_PASSES" ));
799
837
static const std::string &pass_pipeline =
800
- _passes != " " ? _passes : (useGPU () ? gpu_pipeline : cpu_pipeline);
838
+ _passes != " " ? _passes
839
+ : (useGPU () ? (useCUDA () ? cuda_pipeline : gpu_pipeline)
840
+ : cpu_pipeline);
801
841
802
842
JIT::JIT (const std::string &libidtr)
803
843
: _context (::mlir::MLIRContext::Threading::DISABLED), _pm (&_context),
@@ -849,23 +889,24 @@ JIT::JIT(const std::string &libidtr)
849
889
_crunnerlib = mlirRoot + " /lib/libmlir_c_runner_utils.so" ;
850
890
_runnerlib = mlirRoot + " /lib/libmlir_runner_utils.so" ;
851
891
if (!std::ifstream (_crunnerlib)) {
852
- throw std::runtime_error (" Cannot find libmlir_c_runner_utils.so " );
892
+ throw std::runtime_error (" Cannot find lib: " + _crunnerlib );
853
893
}
854
894
if (!std::ifstream (_runnerlib)) {
855
- throw std::runtime_error (" Cannot find libmlir_runner_utils.so " );
895
+ throw std::runtime_error (" Cannot find lib: " + _runnerlib );
856
896
}
857
897
858
898
if (useGPU ()) {
859
899
auto gpuxlibstr = get_text_env (" SHARPY_GPUX_SO" );
860
900
if (!gpuxlibstr.empty ()) {
861
901
_gpulib = std::string (gpuxlibstr);
862
902
} else {
863
- // auto imexRoot = get_text_env("IMEXROOT");
864
- // imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
865
- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
866
- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
867
- // for nv gpu
868
- _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
903
+ if (useCUDA ()) {
904
+ _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
905
+ } else {
906
+ auto imexRoot = get_text_env (" IMEXROOT" );
907
+ imexRoot = !imexRoot.empty () ? imexRoot : std::string (CMAKE_IMEX_ROOT);
908
+ _gpulib = imexRoot + " /lib/liblevel-zero-runtime.so" ;
909
+ }
869
910
if (!std::ifstream (_gpulib)) {
870
911
throw std::runtime_error (" Cannot find lib: " + _gpulib);
871
912
}
0 commit comments