@@ -50,6 +50,44 @@ def _find_cuda_home():
        print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
    return cuda_home

+def _find_rocm_home():
+    '''Finds the ROCm install path.'''
+    # Guess #1
+    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
+    if rocm_home is None:
+        # Guess #2
+        try:
+            hipcc = subprocess.check_output(
+                ['which', 'hipcc']).decode().rstrip('\r\n')
+            # this will be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
+            rocm_home = os.path.dirname(os.path.dirname(hipcc))
+            if os.path.basename(rocm_home) == 'hip':
+                rocm_home = os.path.dirname(rocm_home)
+        except Exception:
+            # Guess #3
+            rocm_home = '/opt/rocm'
+            if not os.path.exists(rocm_home):
+                rocm_home = None
+    if rocm_home and torch.version.hip is None:
+        print("No ROCm runtime is found, using ROCM_HOME='{}'".format(rocm_home))
+    return rocm_home
+
+
+def _join_rocm_home(*paths):
+    '''
+    Joins paths with ROCM_HOME, or raises an error if ROCM_HOME is not set.
+
+    This is basically a lazy way of raising an error for missing $ROCM_HOME
+    only once we need to get any ROCm-specific path.
+    '''
+    if ROCM_HOME is None:
+        raise EnvironmentError('ROCM_HOME environment variable is not set. '
+                               'Please set it to your ROCm install root.')
+    elif IS_WINDOWS:
+        raise EnvironmentError('Building PyTorch extensions using '
+                               'ROCm and Windows is not supported.')
+    return os.path.join(ROCM_HOME, *paths)
+

MINIMUM_GCC_VERSION = (4, 9, 0)
MINIMUM_MSVC_VERSION = (19, 0, 24215)
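The lookup order is: explicit environment variable, then the location of hipcc on PATH, then the conventional /opt/rocm prefix. A minimal sketch of how the guesses resolve (the install root shown is hypothetical):

    import os
    import subprocess

    # Guess #1: an explicit environment variable wins outright.
    os.environ['ROCM_HOME'] = '/opt/rocm-3.5.0'   # hypothetical install root

    # Guess #2: otherwise locate hipcc and strip the trailing 'hip/bin' or
    # 'bin' components: /opt/rocm/hip/bin/hipcc -> /opt/rocm
    try:
        hipcc = subprocess.check_output(['which', 'hipcc']).decode().rstrip('\r\n')
        print(os.path.dirname(os.path.dirname(hipcc)))
    except subprocess.CalledProcessError:
        pass  # Guess #3: fall back to /opt/rocm if it exists, else None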
@@ -85,6 +123,9 @@ def _find_cuda_home():

!! WARNING !!
'''
+ROCM_HOME = _find_rocm_home()
+MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
+IS_HIP_EXTENSION = ROCM_HOME is not None and torch.version.hip is not None
CUDA_HOME = _find_cuda_home()
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
# PyTorch releases have the version pattern major.minor.patch, whereas when
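IS_HIP_EXTENSION requires both halves: a ROCm install found on disk and a PyTorch that was itself built for HIP. A quick check of the second half (values are illustrative):

    import torch

    # On a ROCm build torch.version.hip is a version string such as '3.5.0'
    # and torch.version.cuda is None; on a CUDA build it is the reverse.
    print(torch.version.hip, torch.version.cuda)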
@@ -101,6 +142,14 @@ def _find_cuda_home():
    '--expt-relaxed-constexpr'
]

+COMMON_HIPCC_FLAGS = [
+    '-fPIC',
+    '-D__HIP_PLATFORM_HCC__=1',
+    '-DCUDA_HAS_FP16=1',
+    '-D__HIP_NO_HALF_OPERATORS__=1',
+    '-D__HIP_NO_HALF_CONVERSIONS__=1',
+]
+
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
@@ -243,12 +292,15 @@ def __init__(self, *args, **kwargs):
        super(BuildExtension, self).__init__(*args, **kwargs)
        self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)

-        self.use_ninja = kwargs.get('use_ninja', True)
+        self.use_ninja = kwargs.get('use_ninja', not IS_HIP_EXTENSION)
        if self.use_ninja:
            # Test if we can use ninja. Fallback otherwise.
            msg = ('Attempted to use ninja as the BuildExtension backend but '
                   '{}. Falling back to using the slow distutils backend.')
-            if not _is_ninja_available():
+            if IS_HIP_EXTENSION:
+                warnings.warn(msg.format('HIP extensions are not yet supported with ninja'))
+                self.use_ninja = False
+            elif not _is_ninja_available():
                warnings.warn(msg.format('we could not find ninja.'))
                self.use_ninja = False
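This only changes the default; a setup.py can still opt in or out explicitly. A minimal sketch, assuming the BuildExtension.with_options helper and a hypothetical extension named my_ext:

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name='my_ext',
        ext_modules=[CUDAExtension('my_ext', ['my_ext.cpp', 'my_kernels.cu'])],
        # Force the distutils path; on HIP builds this patch already
        # defaults use_ninja to False.
        cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False)},
    )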
@@ -259,8 +311,8 @@ def build_extensions(self):
            self._define_torch_extension_name(extension)
            self._add_gnu_cpp_abi_flag(extension)

-        # Register .cu and .cuh as valid source extensions.
-        self.compiler.src_extensions += ['.cu', '.cuh']
+        # Register .cu, .cuh and .hip as valid source extensions.
+        self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
        # Save the original _compile method for later.
        if self.compiler.compiler_type == 'msvc':
            self.compiler._cpp_extensions += ['.cu', '.cuh']
@@ -289,15 +341,20 @@ def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
            try:
                original_compiler = self.compiler.compiler_so
                if _is_cuda_file(src):
-                    nvcc = _join_cuda_home('bin', 'nvcc')
+                    nvcc = _join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')
                    if not isinstance(nvcc, list):
                        nvcc = [nvcc]
                    self.compiler.set_executable('compiler_so', nvcc)
                    if isinstance(cflags, dict):
                        cflags = cflags['nvcc']
-                    cflags = unix_cuda_flags(cflags)
+                    if IS_HIP_EXTENSION:
+                        cflags = cflags + _get_rocm_arch_flags(cflags)
+                    else:
+                        cflags = unix_cuda_flags(cflags)
                elif isinstance(cflags, dict):
                    cflags = cflags['cxx']
+                if IS_HIP_EXTENSION:
+                    cflags = cflags + COMMON_HIPCC_FLAGS
                append_std14_if_no_std_present(cflags)

                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
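The cflags handled here come from extra_compile_args; the 'cxx'/'nvcc' keys are the existing API, and on a HIP build the 'nvcc' list is simply handed to hipcc instead of nvcc. A sketch with hypothetical flags:

    from torch.utils.cpp_extension import CUDAExtension

    ext = CUDAExtension(
        'my_ext',
        ['my_ext.cpp', 'my_kernels.cu'],
        extra_compile_args={
            'cxx': ['-O3'],    # host flags; on HIP, COMMON_HIPCC_FLAGS are appended
            'nvcc': ['-O3'],   # device flags; on HIP, ROCm arch flags are appended
        },
    )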
@@ -649,13 +706,17 @@ def CUDAExtension(name, sources, *args, **kwargs):
    kwargs['library_dirs'] = library_dirs

    libraries = kwargs.get('libraries', [])
-    libraries.append('cudart')
    libraries.append('c10')
-    libraries.append('c10_cuda')
    libraries.append('torch')
    libraries.append('torch_cpu')
-    libraries.append('torch_cuda')
    libraries.append('torch_python')
+    if IS_HIP_EXTENSION:
+        libraries.append('c10_hip')
+        libraries.append('torch_hip')
+    else:
+        libraries.append('cudart')
+        libraries.append('c10_cuda')
+        libraries.append('torch_cuda')
    kwargs['libraries'] = libraries

    include_dirs = kwargs.get('include_dirs', [])
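To make the branch concrete, the effective link libraries differ roughly as follows (a restatement of the lists built above, not additional API):

    common = ['c10', 'torch', 'torch_cpu', 'torch_python']
    hip_libs = common + ['c10_hip', 'torch_hip']                  # IS_HIP_EXTENSION
    cuda_libs = common + ['cudart', 'c10_cuda', 'torch_cuda']     # CUDA build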
@@ -689,7 +750,12 @@ def include_paths(cuda=False):
        os.path.join(lib_include, 'TH'),
        os.path.join(lib_include, 'THC')
    ]
-    if cuda:
+    if cuda and IS_HIP_EXTENSION:
+        paths.append(os.path.join(lib_include, 'THH'))
+        paths.append(_join_rocm_home('include'))
+        if MIOPEN_HOME is not None:
+            paths.append(os.path.join(MIOPEN_HOME, 'include'))
+    elif cuda:
        cuda_home_include = _join_cuda_home('include')
        # if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
        # but gcc doesn't like having /usr/include passed explicitly
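Assuming a stock /opt/rocm prefix with MIOpen installed under it, the HIP branch yields roughly the following (paths are illustrative):

    from torch.utils.cpp_extension import include_paths

    # Expect something like [..., '<torch>/include/TH', '<torch>/include/THC',
    # '<torch>/include/THH', '/opt/rocm/include', '/opt/rocm/miopen/include']
    print(include_paths(cuda=True))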
@@ -718,7 +784,10 @@ def library_paths(cuda=False):
    lib_path = os.path.join(torch_path, 'lib')
    paths.append(lib_path)

-    if cuda:
+    if cuda and IS_HIP_EXTENSION:
+        lib_dir = 'lib'
+        paths.append(_join_rocm_home(lib_dir))
+    elif cuda:
        if IS_WINDOWS:
            lib_dir = 'lib/x64'
        else:
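Correspondingly for the linker search path, assuming ROCM_HOME=/opt/rocm:

    from torch.utils.cpp_extension import library_paths

    # Expect something like ['<torch>/lib', '/opt/rocm/lib']
    print(library_paths(cuda=True))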
@@ -1251,6 +1320,22 @@ def _get_cuda_arch_flags(cflags=None):
    return list(set(flags))


+def _get_rocm_arch_flags(cflags=None):
+    # If cflags is given, there may already be user-provided arch flags in it
+    # (from `extra_compile_args`)
+    if cflags is not None:
+        for flag in cflags:
+            if 'amdgpu-target' in flag:
+                return ['-fno-gpu-rdc']
+    return [
+        '--amdgpu-target=gfx803',
+        '--amdgpu-target=gfx900',
+        '--amdgpu-target=gfx906',
+        '--amdgpu-target=gfx908',
+        '-fno-gpu-rdc'
+    ]
+
+
def _get_build_directory(name, verbose):
    root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
    if root_extensions_directory is None:
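The early return means one user-supplied --amdgpu-target flag suppresses the whole default target list. A standalone restatement of the behavior:

    def get_rocm_arch_flags(cflags=None):
        # Mirrors _get_rocm_arch_flags above, for illustration only.
        if cflags is not None and any('amdgpu-target' in f for f in cflags):
            return ['-fno-gpu-rdc']
        targets = ['--amdgpu-target=gfx{}'.format(g) for g in (803, 900, 906, 908)]
        return targets + ['-fno-gpu-rdc']

    assert get_rocm_arch_flags(['--amdgpu-target=gfx906']) == ['-fno-gpu-rdc']
    assert '--amdgpu-target=gfx900' in get_rocm_arch_flags()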
@@ -1567,4 +1652,7 @@ def _join_cuda_home(*paths):


def _is_cuda_file(path):
-    return os.path.splitext(path)[1] in ['.cu', '.cuh']
+    valid_ext = ['.cu', '.cuh']
+    if IS_HIP_EXTENSION:
+        valid_ext.append('.hip')
+    return os.path.splitext(path)[1] in valid_ext
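A standalone restatement of the new predicate (the hip argument stands in for the module-level IS_HIP_EXTENSION):

    import os

    def is_device_file(path, hip=False):
        valid_ext = ['.cu', '.cuh'] + (['.hip'] if hip else [])
        return os.path.splitext(path)[1] in valid_ext

    assert is_device_file('kernels.hip', hip=True)
    assert not is_device_file('kernels.hip', hip=False)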