Skip to content

Commit 616beb1

Browse files
ashish authored and facebook-github-bot committed
[ROCm] Added support for pytorch extensions to use HIP (pytorch#32669)
Summary: This pull request has changes for: 1. Enabling a torch module with HIP code to be compiled by cpp_extensions.py 2. Fixes for hipify module to be able to be used by a torch extension cc: ezyang iotamudelta jeffdaily Pull Request resolved: pytorch#32669 Differential Revision: D20033893 Pulled By: zou3519 fbshipit-source-id: fd6ddc8cdcd3930f41008636bb2bc9dd26cdb008
1 parent ca8e025 commit 616beb1

File tree

5 files changed

+145
-35
lines changed

5 files changed

+145
-35
lines changed

test/cpp_extensions/setup.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from setuptools import setup
55
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
6-
from torch.utils.cpp_extension import CUDA_HOME
6+
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
77

88
if sys.platform == 'win32':
99
vc_version = os.getenv('VCToolsVersion', '')
@@ -35,6 +35,22 @@
3535
extra_compile_args={'cxx': CXX_FLAGS,
3636
'nvcc': ['-O2']})
3737
ext_modules.append(extension)
38+
elif torch.cuda.is_available() and ROCM_HOME is not None:
39+
from torch.utils.hipify import hipify_python
40+
this_dir = os.path.dirname(os.path.abspath(__file__))
41+
hipify_python.hipify(
42+
project_directory=this_dir,
43+
output_directory=this_dir,
44+
includes="./*",
45+
show_detailed=True,
46+
is_pytorch_extension=True,)
47+
extension = CUDAExtension(
48+
'torch_test_cpp_extension.cuda', [
49+
'cuda_extension.cpp',
50+
'hip/hip_extension_kernel.hip',
51+
'hip/hip_extension_kernel2.hip',
52+
])
53+
ext_modules.append(extension)
3854

3955
setup(
4056
name='torch_test_cpp_extension',

test/run_test.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
TESTS = [
2323
'test_autograd',
2424
'test_complex',
25-
'test_cpp_extensions_aot',
2625
'test_cpp_extensions_aot_no_ninja',
26+
'test_cpp_extensions_aot_ninja',
2727
'test_cpp_extensions_jit',
2828
'distributed/test_c10d',
2929
'distributed/test_c10d_spawn',
@@ -93,8 +93,7 @@
9393
]
9494

9595
ROCM_BLACKLIST = [
96-
'test_cpp_extensions_aot',
97-
'test_cpp_extensions_aot_no_ninja',
96+
'test_cpp_extensions_aot_ninja',
9897
'test_cpp_extensions_jit',
9998
'test_multiprocessing',
10099
'distributed/rpc/test_rpc_spawn',
@@ -129,7 +128,7 @@
129128
Ninja (https://ninja-build.org) is required for some of the C++ extensions
130129
tests, but it could not be found. Install ninja with `pip install ninja`
131130
or `conda install ninja`. Alternatively, disable said tests with
132-
`run_test.py --exclude test_cpp_extensions_aot test_cpp_extensions_jit`.
131+
`run_test.py --exclude test_cpp_extensions_aot_ninja test_cpp_extensions_jit`.
133132
"""
134133

135134

@@ -199,8 +198,8 @@ def _test_cpp_extensions_aot(executable, test_module, test_directory, options, u
199198
os.environ['PYTHONPATH'] = python_path
200199

201200

202-
def test_cpp_extensions_aot(executable, test_module, test_directory, options):
203-
return _test_cpp_extensions_aot(executable, test_module, test_directory,
201+
def test_cpp_extensions_aot_ninja(executable, test_module, test_directory, options):
202+
return _test_cpp_extensions_aot(executable, 'test_cpp_extensions_aot', test_directory,
204203
options, use_ninja=True)
205204

206205

@@ -261,8 +260,8 @@ def test_distributed(executable, test_module, test_directory, options):
261260

262261
CUSTOM_HANDLERS = {
263262
'test_cuda_primary_ctx': test_cuda_primary_ctx,
264-
'test_cpp_extensions_aot': test_cpp_extensions_aot,
265263
'test_cpp_extensions_aot_no_ninja': test_cpp_extensions_aot_no_ninja,
264+
'test_cpp_extensions_aot_ninja': test_cpp_extensions_aot_ninja,
266265
'distributed/test_distributed': test_distributed,
267266
}
268267

@@ -430,8 +429,8 @@ def get_selected_tests(options):
430429
if sys.platform == 'win32' and not options.ignore_win_blacklist:
431430
target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH')
432431
if target_arch != 'x64':
433-
WINDOWS_BLACKLIST.append('cpp_extensions_aot')
434432
WINDOWS_BLACKLIST.append('cpp_extensions_aot_no_ninja')
433+
WINDOWS_BLACKLIST.append('cpp_extensions_aot_ninja')
435434
WINDOWS_BLACKLIST.append('cpp_extensions_jit')
436435
WINDOWS_BLACKLIST.append('jit')
437436
WINDOWS_BLACKLIST.append('jit_fuser')

test/test_cpp_extensions_aot.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414
except ImportError:
1515
raise RuntimeError(
1616
"test_cpp_extensions_aot.py cannot be invoked directly. Run "
17-
"`python run_test.py -i test_cpp_extensions_aot` instead."
17+
"`python run_test.py -i test_cpp_extensions_aot_ninja` instead."
1818
)
1919

2020

2121
class TestCppExtensionAOT(common.TestCase):
2222
"""Tests ahead-of-time cpp extensions
2323
24-
NOTE: run_test.py's test_cpp_extensions_aot_no_ninja target
25-
also runs this test case, but with ninja disabled. If you are debugging
24+
NOTE: run_test.py's test_cpp_extensions_aot_ninja target
25+
also runs this test case, but with ninja enabled. If you are debugging
2626
a test failure here from the CI, check the logs for which target
27-
(test_cpp_extensions_aot vs test_cpp_extensions_aot_no_ninja)
27+
(test_cpp_extensions_aot_no_ninja vs test_cpp_extensions_aot_ninja)
2828
failed.
2929
"""
3030

torch/utils/cpp_extension.py

+100-12
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,44 @@ def _find_cuda_home():
5050
print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
5151
return cuda_home
5252

53+
def _find_rocm_home():
54+
'''Finds the ROCm install path.'''
55+
# Guess #1
56+
rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
57+
if rocm_home is None:
58+
# Guess #2
59+
try:
60+
hipcc = subprocess.check_output(
61+
['which', 'hipcc']).decode().rstrip('\r\n')
62+
# this will be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
63+
rocm_home = os.path.dirname(os.path.dirname(hipcc))
64+
if os.path.basename(rocm_home) == 'hip':
65+
rocm_home = os.path.dirname(rocm_home)
66+
except Exception:
67+
# Guess #3
68+
rocm_home = '/opt/rocm'
69+
if not os.path.exists(rocm_home):
70+
rocm_home = None
71+
if rocm_home and torch.version.hip is None:
72+
print("No ROCm runtime is found, using ROCM_HOME='{}'".format(rocm_home))
73+
return rocm_home
74+
75+
76+
def _join_rocm_home(*paths):
77+
'''
78+
Joins paths with ROCM_HOME, or raises an error if ROCM_HOME is not set.
79+
80+
This is basically a lazy way of raising an error for missing $ROCM_HOME
81+
only once we need to get any ROCm-specific path.
82+
'''
83+
if ROCM_HOME is None:
84+
raise EnvironmentError('ROCM_HOME environment variable is not set. '
85+
'Please set it to your ROCm install root.')
86+
elif IS_WINDOWS:
87+
raise EnvironmentError('Building PyTorch extensions using '
88+
'ROCm and Windows is not supported.')
89+
return os.path.join(ROCM_HOME, *paths)
90+
5391

5492
MINIMUM_GCC_VERSION = (4, 9, 0)
5593
MINIMUM_MSVC_VERSION = (19, 0, 24215)
@@ -85,6 +123,9 @@ def _find_cuda_home():
85123
86124
!! WARNING !!
87125
'''
126+
ROCM_HOME = _find_rocm_home()
127+
MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
128+
IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False
88129
CUDA_HOME = _find_cuda_home()
89130
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
90131
# PyTorch releases have the version pattern major.minor.patch, whereas when
@@ -101,6 +142,14 @@ def _find_cuda_home():
101142
'--expt-relaxed-constexpr'
102143
]
103144

145+
COMMON_HIPCC_FLAGS = [
146+
'-fPIC',
147+
'-D__HIP_PLATFORM_HCC__=1',
148+
'-DCUDA_HAS_FP16=1',
149+
'-D__HIP_NO_HALF_OPERATORS__=1',
150+
'-D__HIP_NO_HALF_CONVERSIONS__=1',
151+
]
152+
104153
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
105154

106155

@@ -243,12 +292,15 @@ def __init__(self, *args, **kwargs):
243292
super(BuildExtension, self).__init__(*args, **kwargs)
244293
self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)
245294

246-
self.use_ninja = kwargs.get('use_ninja', True)
295+
self.use_ninja = kwargs.get('use_ninja', False if IS_HIP_EXTENSION else True)
247296
if self.use_ninja:
248297
# Test if we can use ninja. Fallback otherwise.
249298
msg = ('Attempted to use ninja as the BuildExtension backend but '
250299
'{}. Falling back to using the slow distutils backend.')
251-
if not _is_ninja_available():
300+
if IS_HIP_EXTENSION:
301+
warnings.warn(msg.format('HIP extensions is not supported yet for ninja.'))
302+
self.use_ninja = False
303+
elif not _is_ninja_available():
252304
warnings.warn(msg.format('we could not find ninja.'))
253305
self.use_ninja = False
254306

@@ -259,8 +311,8 @@ def build_extensions(self):
259311
self._define_torch_extension_name(extension)
260312
self._add_gnu_cpp_abi_flag(extension)
261313

262-
# Register .cu and .cuh as valid source extensions.
263-
self.compiler.src_extensions += ['.cu', '.cuh']
314+
# Register .cu, .cuh and .hip as valid source extensions.
315+
self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
264316
# Save the original _compile method for later.
265317
if self.compiler.compiler_type == 'msvc':
266318
self.compiler._cpp_extensions += ['.cu', '.cuh']
@@ -289,15 +341,20 @@ def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
289341
try:
290342
original_compiler = self.compiler.compiler_so
291343
if _is_cuda_file(src):
292-
nvcc = _join_cuda_home('bin', 'nvcc')
344+
nvcc = (_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc'))
293345
if not isinstance(nvcc, list):
294346
nvcc = [nvcc]
295347
self.compiler.set_executable('compiler_so', nvcc)
296348
if isinstance(cflags, dict):
297349
cflags = cflags['nvcc']
298-
cflags = unix_cuda_flags(cflags)
350+
if IS_HIP_EXTENSION:
351+
cflags = cflags + _get_rocm_arch_flags(cflags)
352+
else:
353+
cflags = unix_cuda_flags(cflags)
299354
elif isinstance(cflags, dict):
300355
cflags = cflags['cxx']
356+
if IS_HIP_EXTENSION:
357+
cflags = cflags + COMMON_HIPCC_FLAGS
301358
append_std14_if_no_std_present(cflags)
302359

303360
original_compile(obj, src, ext, cc_args, cflags, pp_opts)
@@ -649,13 +706,17 @@ def CUDAExtension(name, sources, *args, **kwargs):
649706
kwargs['library_dirs'] = library_dirs
650707

651708
libraries = kwargs.get('libraries', [])
652-
libraries.append('cudart')
653709
libraries.append('c10')
654-
libraries.append('c10_cuda')
655710
libraries.append('torch')
656711
libraries.append('torch_cpu')
657-
libraries.append('torch_cuda')
658712
libraries.append('torch_python')
713+
if IS_HIP_EXTENSION:
714+
libraries.append('c10_hip')
715+
libraries.append('torch_hip')
716+
else:
717+
libraries.append('cudart')
718+
libraries.append('c10_cuda')
719+
libraries.append('torch_cuda')
659720
kwargs['libraries'] = libraries
660721

661722
include_dirs = kwargs.get('include_dirs', [])
@@ -689,7 +750,12 @@ def include_paths(cuda=False):
689750
os.path.join(lib_include, 'TH'),
690751
os.path.join(lib_include, 'THC')
691752
]
692-
if cuda:
753+
if cuda and IS_HIP_EXTENSION:
754+
paths.append(os.path.join(lib_include, 'THH'))
755+
paths.append(_join_rocm_home('include'))
756+
if MIOPEN_HOME is not None:
757+
paths.append(os.path.join(MIOPEN_HOME, 'include'))
758+
elif cuda:
693759
cuda_home_include = _join_cuda_home('include')
694760
# if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
695761
# but gcc doesn't like having /usr/include passed explicitly
@@ -718,7 +784,10 @@ def library_paths(cuda=False):
718784
lib_path = os.path.join(torch_path, 'lib')
719785
paths.append(lib_path)
720786

721-
if cuda:
787+
if cuda and IS_HIP_EXTENSION:
788+
lib_dir = 'lib'
789+
paths.append(_join_rocm_home(lib_dir))
790+
elif cuda:
722791
if IS_WINDOWS:
723792
lib_dir = 'lib/x64'
724793
else:
@@ -1251,6 +1320,22 @@ def _get_cuda_arch_flags(cflags=None):
12511320
return list(set(flags))
12521321

12531322

1323+
def _get_rocm_arch_flags(cflags=None):
1324+
# If cflags is given, there may already be user-provided arch flags in it
1325+
# (from `extra_compile_args`)
1326+
if cflags is not None:
1327+
for flag in cflags:
1328+
if 'amdgpu-target' in flag:
1329+
return ['-fno-gpu-rdc']
1330+
return [
1331+
'--amdgpu-target=gfx803',
1332+
'--amdgpu-target=gfx900',
1333+
'--amdgpu-target=gfx906',
1334+
'--amdgpu-target=gfx908',
1335+
'-fno-gpu-rdc'
1336+
]
1337+
1338+
12541339
def _get_build_directory(name, verbose):
12551340
root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
12561341
if root_extensions_directory is None:
@@ -1567,4 +1652,7 @@ def _join_cuda_home(*paths):
15671652

15681653

15691654
def _is_cuda_file(path):
1570-
return os.path.splitext(path)[1] in ['.cu', '.cuh']
1655+
valid_ext = ['.cu', '.cuh']
1656+
if IS_HIP_EXTENSION:
1657+
valid_ext.append('.hip')
1658+
return os.path.splitext(path)[1] in valid_ext

torch/utils/hipify/hipify_python.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def preprocess(
119119
all_files,
120120
show_detailed=False,
121121
show_progress=True,
122-
hip_clang_launch=False):
122+
hip_clang_launch=False,
123+
is_pytorch_extension=False):
123124
"""
124125
Call preprocessor on selected files.
125126
@@ -131,7 +132,7 @@ def preprocess(
131132
stats = {"unsupported_calls": [], "kernel_launches": []}
132133

133134
for filepath in all_files:
134-
result = preprocessor(output_directory, filepath, stats, hip_clang_launch)
135+
result = preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension)
135136
# Show what happened
136137
if show_progress:
137138
print(
@@ -605,7 +606,7 @@ def pattern(self):
605606
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
606607
RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh
607608

608-
def preprocessor(output_directory, filepath, stats, hip_clang_launch):
609+
def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension):
609610
""" Executes the CUDA -> HIP conversion on the specified file. """
610611
fin_path = os.path.join(output_directory, filepath)
611612
with open(fin_path, 'r') as fin:
@@ -616,14 +617,18 @@ def preprocessor(output_directory, filepath, stats, hip_clang_launch):
616617
os.makedirs(os.path.dirname(fout_path))
617618

618619
# unsupported_calls statistics reporting is broken atm
619-
if is_pytorch_file(filepath):
620-
def pt_repl(m):
621-
return PYTORCH_MAP[m.group(0)]
620+
def pt_repl(m):
621+
return PYTORCH_MAP[m.group(0)]
622+
623+
if is_pytorch_extension:
622624
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
623625
else:
624-
def c2_repl(m):
625-
return CAFFE2_MAP[m.group(0)]
626-
output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
626+
if is_pytorch_file(filepath):
627+
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
628+
else:
629+
def c2_repl(m):
630+
return CAFFE2_MAP[m.group(0)]
631+
output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
627632

628633
# Header rewrites
629634
def mk_repl(templ):
@@ -775,6 +780,7 @@ def hipify(
775780
ignores=(),
776781
show_progress=True,
777782
hip_clang_launch=False,
783+
is_pytorch_extension=False,
778784
):
779785
if project_directory == "":
780786
project_directory = os.getcwd()
@@ -803,4 +809,5 @@ def hipify(
803809
all_files,
804810
show_detailed=show_detailed,
805811
show_progress=show_progress,
806-
hip_clang_launch=hip_clang_launch)
812+
hip_clang_launch=hip_clang_launch,
813+
is_pytorch_extension=is_pytorch_extension)

0 commit comments

Comments (0)