@@ -50,6 +50,44 @@ def _find_cuda_home():
         print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
     return cuda_home
 
+def _find_rocm_home():
+    '''Finds the ROCm install path.'''
+    # Guess #1
+    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
+    if rocm_home is None:
+        # Guess #2
+        try:
+            hipcc = subprocess.check_output(
+                ['which', 'hipcc']).decode().rstrip('\r\n')
+            # this will be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
+            rocm_home = os.path.dirname(os.path.dirname(hipcc))
+            if os.path.basename(rocm_home) == 'hip':
+                rocm_home = os.path.dirname(rocm_home)
+        except Exception:
+            # Guess #3
+            rocm_home = '/opt/rocm'
+            if not os.path.exists(rocm_home):
+                rocm_home = None
+    if rocm_home and torch.version.hip is None:
+        print("No ROCm runtime is found, using ROCM_HOME='{}'".format(rocm_home))
+    return rocm_home
+
+
+def _join_rocm_home(*paths):
+    '''
78+ Joins paths with ROCM_HOME, or raises an error if it ROCM_HOME is not set.
79+
80+ This is basically a lazy way of raising an error for missing $ROCM_HOME
81+ only once we need to get any ROCm-specific path.
82+ '''
83+ if ROCM_HOME is None :
84+ raise EnvironmentError ('ROCM_HOME environment variable is not set. '
85+ 'Please set it to your ROCm install root.' )
86+ elif IS_WINDOWS :
87+ raise EnvironmentError ('Building PyTorch extensions using '
88+ 'ROCm and Windows is not supported.' )
89+ return os .path .join (ROCM_HOME , * paths )
90+
5391
5492MINIMUM_GCC_VERSION = (4 , 9 , 0 )
5593MINIMUM_MSVC_VERSION = (19 , 0 , 24215 )
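
For readers who want to check what this three-guess lookup would pick up on their machine, here is a standalone sketch of the same logic. It uses `shutil.which` instead of shelling out to `which`, so it approximates the code above rather than reproducing the exact code path:

import os
import shutil

def guess_rocm_home():
    # Guess #1: explicit environment variables win.
    home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
    if home is None:
        # Guess #2: walk up from a hipcc binary found on PATH; it lives at
        # <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc.
        hipcc = shutil.which('hipcc')
        if hipcc is not None:
            home = os.path.dirname(os.path.dirname(hipcc))
            if os.path.basename(home) == 'hip':
                home = os.path.dirname(home)
        # Guess #3: fall back to the conventional install prefix.
        elif os.path.exists('/opt/rocm'):
            home = '/opt/rocm'
    return home

print(guess_rocm_home())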
@@ -85,6 +123,9 @@ def _find_cuda_home():
 
 !! WARNING !!
 '''
+ROCM_HOME = _find_rocm_home()
+MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
+IS_HIP_EXTENSION = ROCM_HOME is not None and torch.version.hip is not None
 CUDA_HOME = _find_cuda_home()
 CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
 # PyTorch releases have the version pattern major.minor.patch, whereas when
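
The gate for every HIP branch below is IS_HIP_EXTENSION: it requires both a discovered ROCm install and a HIP-enabled torch (torch.version.hip is not None). Assuming a PyTorch build that includes this patch, the three new module-level constants can be inspected directly:

import torch
from torch.utils import cpp_extension

print('ROCM_HOME:', cpp_extension.ROCM_HOME)        # None on machines without ROCm
print('torch.version.hip:', torch.version.hip)      # None on CUDA/CPU builds
print('HIP extension build?', cpp_extension.IS_HIP_EXTENSION)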
@@ -101,6 +142,14 @@ def _find_cuda_home():
     '--expt-relaxed-constexpr'
 ]
 
+COMMON_HIPCC_FLAGS = [
+    '-fPIC',
+    '-D__HIP_PLATFORM_HCC__=1',
+    '-DCUDA_HAS_FP16=1',
+    '-D__HIP_NO_HALF_OPERATORS__=1',
+    '-D__HIP_NO_HALF_CONVERSIONS__=1',
+]
+
 JIT_EXTENSION_VERSIONER = ExtensionVersioner()
 
 
@@ -243,12 +292,15 @@ def __init__(self, *args, **kwargs):
         super(BuildExtension, self).__init__(*args, **kwargs)
         self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)
 
-        self.use_ninja = kwargs.get('use_ninja', True)
+        self.use_ninja = kwargs.get('use_ninja', False if IS_HIP_EXTENSION else True)
         if self.use_ninja:
             # Test if we can use ninja. Fallback otherwise.
             msg = ('Attempted to use ninja as the BuildExtension backend but '
                    '{}. Falling back to using the slow distutils backend.')
-            if not _is_ninja_available():
+            if IS_HIP_EXTENSION:
+                warnings.warn(msg.format('HIP extensions are not supported yet for ninja.'))
+                self.use_ninja = False
+            elif not _is_ninja_available():
                 warnings.warn(msg.format('we could not find ninja.'))
                 self.use_ninja = False
 
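
A standalone sketch of the fallback decision above, using shutil.which('ninja') as a stand-in for the private _is_ninja_available() helper: ninja is disabled up front for HIP builds, and otherwise kept only when the binary can actually be found.

import shutil
import warnings

def resolve_use_ninja(requested, is_hip_extension):
    msg = ('Attempted to use ninja as the BuildExtension backend but '
           '{}. Falling back to using the slow distutils backend.')
    if not requested:
        return False
    if is_hip_extension:
        warnings.warn(msg.format('HIP extensions are not supported yet for ninja'))
        return False
    if shutil.which('ninja') is None:  # stand-in for _is_ninja_available()
        warnings.warn(msg.format('we could not find ninja'))
        return False
    return True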
@@ -259,8 +311,8 @@ def build_extensions(self):
             self._define_torch_extension_name(extension)
             self._add_gnu_cpp_abi_flag(extension)
 
-        # Register .cu and .cuh as valid source extensions.
-        self.compiler.src_extensions += ['.cu', '.cuh']
+        # Register .cu, .cuh and .hip as valid source extensions.
+        self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
         # Save the original _compile method for later.
         if self.compiler.compiler_type == 'msvc':
             self.compiler._cpp_extensions += ['.cu', '.cuh']
@@ -289,15 +341,20 @@ def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
             try:
                 original_compiler = self.compiler.compiler_so
                 if _is_cuda_file(src):
-                    nvcc = _join_cuda_home('bin', 'nvcc')
+                    nvcc = _join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')
                     if not isinstance(nvcc, list):
                         nvcc = [nvcc]
                     self.compiler.set_executable('compiler_so', nvcc)
                     if isinstance(cflags, dict):
                         cflags = cflags['nvcc']
-                    cflags = unix_cuda_flags(cflags)
+                    if IS_HIP_EXTENSION:
+                        cflags = cflags + _get_rocm_arch_flags(cflags)
+                    else:
+                        cflags = unix_cuda_flags(cflags)
                 elif isinstance(cflags, dict):
                     cflags = cflags['cxx']
+                if IS_HIP_EXTENSION:
+                    cflags = cflags + COMMON_HIPCC_FLAGS
                 append_std14_if_no_std_present(cflags)
 
                 original_compile(obj, src, ext, cc_args, cflags, pp_opts)
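
A condensed, hand-rolled view of the flag routing above (not the real hook): device sources switch the compiler to hipcc on ROCm or nvcc on CUDA and take the 'nvcc' flag set plus arch flags; host sources take the 'cxx' set; and on ROCm every file, host or device, additionally receives the COMMON_HIPCC_FLAGS defines. Flag lists are abridged here.

def route_flags(src, cflags, is_hip):
    compiler = None  # None means: keep the original host compiler
    if src.endswith(('.cu', '.cuh', '.hip')):
        compiler = 'hipcc' if is_hip else 'nvcc'
        flags = list(cflags['nvcc'] if isinstance(cflags, dict) else cflags)
        # ROCm appends _get_rocm_arch_flags(); CUDA runs unix_cuda_flags().
        flags += ['--amdgpu-target=gfx900', '-fno-gpu-rdc'] if is_hip else []
    elif isinstance(cflags, dict):
        flags = list(cflags['cxx'])
    else:
        flags = list(cflags)
    if is_hip:
        # Abridged COMMON_HIPCC_FLAGS; applied to host and device files alike.
        flags += ['-fPIC', '-D__HIP_PLATFORM_HCC__=1']
    return compiler, flags

print(route_flags('op.hip', {'cxx': ['-O2'], 'nvcc': ['-O3']}, is_hip=True))
# ('hipcc', ['-O3', '--amdgpu-target=gfx900', '-fno-gpu-rdc', '-fPIC', '-D__HIP_PLATFORM_HCC__=1'])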
@@ -649,13 +706,17 @@ def CUDAExtension(name, sources, *args, **kwargs):
     kwargs['library_dirs'] = library_dirs
 
     libraries = kwargs.get('libraries', [])
-    libraries.append('cudart')
     libraries.append('c10')
-    libraries.append('c10_cuda')
     libraries.append('torch')
     libraries.append('torch_cpu')
-    libraries.append('torch_cuda')
     libraries.append('torch_python')
+    if IS_HIP_EXTENSION:
+        libraries.append('c10_hip')
+        libraries.append('torch_hip')
+    else:
+        libraries.append('cudart')
+        libraries.append('c10_cuda')
+        libraries.append('torch_cuda')
     kwargs['libraries'] = libraries
 
     include_dirs = kwargs.get('include_dirs', [])
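
The point of branching inside CUDAExtension is that setup scripts stay backend-agnostic: on a ROCm machine the same call now links c10_hip/torch_hip instead of cudart/c10_cuda/torch_cuda, with no change on the user's side. Standard usage, with placeholder module and file names:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='my_op',  # placeholder project name
    ext_modules=[
        CUDAExtension(
            name='my_op',
            sources=['my_op.cpp', 'my_op_kernel.cu'],  # placeholder sources
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)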
@@ -689,7 +750,12 @@ def include_paths(cuda=False):
         os.path.join(lib_include, 'TH'),
         os.path.join(lib_include, 'THC')
     ]
-    if cuda:
+    if cuda and IS_HIP_EXTENSION:
+        paths.append(os.path.join(lib_include, 'THH'))
+        paths.append(_join_rocm_home('include'))
+        if MIOPEN_HOME is not None:
+            paths.append(os.path.join(MIOPEN_HOME, 'include'))
+    elif cuda:
         cuda_home_include = _join_cuda_home('include')
         # if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
         # but gcc doesn't like having /usr/include passed explicitly
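
On a ROCm machine the header search path now gains the hipified THH tree, <ROCM_HOME>/include, and MIOpen's headers when MIOPEN_HOME was resolved. This can be verified on a build that includes the patch:

from torch.utils.cpp_extension import include_paths

for p in include_paths(cuda=True):
    print(p)
# On ROCm, expect entries like .../torch/include/THH and /opt/rocm/include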
@@ -718,7 +784,10 @@ def library_paths(cuda=False):
     lib_path = os.path.join(torch_path, 'lib')
     paths.append(lib_path)
 
-    if cuda:
+    if cuda and IS_HIP_EXTENSION:
+        lib_dir = 'lib'
+        paths.append(_join_rocm_home(lib_dir))
+    elif cuda:
         if IS_WINDOWS:
             lib_dir = 'lib/x64'
         else:
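
The linker side mirrors this: ROCm contributes a single extra search path, <ROCM_HOME>/lib, while CUDA keeps its per-platform lib64 / lib/x64 logic. Again checkable on a patched build:

from torch.utils.cpp_extension import library_paths

print(library_paths(cuda=True))
# e.g. ['.../torch/lib', '/opt/rocm/lib'] on a ROCm machine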
@@ -1251,6 +1320,22 @@ def _get_cuda_arch_flags(cflags=None):
     return list(set(flags))
 
 
+def _get_rocm_arch_flags(cflags=None):
+    # If cflags is given, there may already be user-provided arch flags in it
+    # (from `extra_compile_args`)
+    if cflags is not None:
+        for flag in cflags:
+            if 'amdgpu-target' in flag:
+                return ['-fno-gpu-rdc']
+    return [
+        '--amdgpu-target=gfx803',
+        '--amdgpu-target=gfx900',
+        '--amdgpu-target=gfx906',
+        '--amdgpu-target=gfx908',
+        '-fno-gpu-rdc'
+    ]
+
+
 def _get_build_directory(name, verbose):
     root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
     if root_extensions_directory is None:
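
The helper defaults to a fixed set of gfx targets; any user-supplied --amdgpu-target in extra_compile_args suppresses those defaults, leaving only -fno-gpu-rdc. Behavior on a patched build (note this is a private helper, not public API):

from torch.utils.cpp_extension import _get_rocm_arch_flags

print(_get_rocm_arch_flags())
# ['--amdgpu-target=gfx803', ..., '--amdgpu-target=gfx908', '-fno-gpu-rdc']
print(_get_rocm_arch_flags(['--amdgpu-target=gfx906']))
# ['-fno-gpu-rdc']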
@@ -1567,4 +1652,7 @@ def _join_cuda_home(*paths):
 
 
 def _is_cuda_file(path):
-    return os.path.splitext(path)[1] in ['.cu', '.cuh']
+    valid_ext = ['.cu', '.cuh']
+    if IS_HIP_EXTENSION:
+        valid_ext.append('.hip')
+    return os.path.splitext(path)[1] in valid_ext
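
With this final change, .hip files are classified as device sources on ROCm builds and routed through hipcc exactly like .cu files. A quick check against a patched build (again a private helper):

from torch.utils.cpp_extension import _is_cuda_file

print(_is_cuda_file('saxpy.cu'))   # True on both backends
print(_is_cuda_file('saxpy.hip'))  # True only when IS_HIP_EXTENSION
print(_is_cuda_file('saxpy.cpp'))  # False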