Skip to content

Commit 2c901b3

Browse files
authored
Triaging ROCm wheel build (#2161)
* Enable ROCm support in build workflow and specify runner configuration for MI300 GPU * Refactor HIP source directory handling in setup.py and remove deprecated runner configuration from build workflow * Refactor HIP source collection in setup.py for improved readability
1 parent 81e48a3 commit 2c901b3

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

.github/workflows/build_wheels_linux.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
os: linux
2929
with-cpu: enable
3030
with-cuda: enable
31-
with-rocm: disable
31+
with-rocm: enable
3232
with-xpu: enable
3333
# Note: if free-threaded python is required add py3.13t here
3434
python-versions: '["3.9"]'

setup.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -311,16 +311,17 @@ def get_extensions():
311311
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
312312
)
313313

314-
extensions_hip_dir = os.path.join(
315-
extensions_dir, "cuda", "tensor_core_tiled_layout"
316-
)
317-
hip_sources = list(
318-
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
319-
)
320-
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
321-
hip_sources += list(
322-
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
323-
)
314+
# Define HIP source directories
315+
hip_source_dirs = [
316+
os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout"),
317+
# TODO: Add sparse_marlin back in once we have a ROCm build for it
318+
# os.path.join(extensions_dir, "cuda", "sparse_marlin")
319+
]
320+
321+
# Collect all HIP sources from the defined directories
322+
hip_sources = []
323+
for hip_dir in hip_source_dirs:
324+
hip_sources.extend(glob.glob(os.path.join(hip_dir, "*.cu"), recursive=True))
324325

325326
# Collect CUDA source files if needed
326327
if not IS_ROCM and use_cuda:

0 commit comments

Comments
 (0)