Skip to content

Commit add4a42

Browse files
tbennunjataylo
authored andcommitted
Respect ROCR_VISIBLE_DEVICES on AMD GPU device discovery (pytorch#140320)
Fixes pytorch#140318 Pull Request resolved: pytorch#140320 Approved by: https://github.com/eqy, https://github.com/jithunnair-amd, https://github.com/jataylo, https://github.com/jeffdaily Co-authored-by: Jack Taylor <[email protected]>
1 parent 37c4b19 commit add4a42

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

test/test_cuda.py

+2
Original file line numberDiff line numberDiff line change
@@ -3300,6 +3300,8 @@ def test_hip_device_count(self):
33003300
{"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
33013301
{"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"},
33023302
{"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"},
3303+
{"ROCR_VISIBLE_DEVICES": "1,2,3", "HIP_VISIBLE_DEVICES": "0"},
3304+
{"ROCR_VISIBLE_DEVICES": "0"},
33033305
]
33043306

33053307
for env_config in custom_envs:

torch/cuda/__init__.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,24 @@ def _parse_visible_devices() -> Union[List[int], List[str]]:
646646

647647
if torch.version.hip:
648648
hip_devices = os.getenv("HIP_VISIBLE_DEVICES")
649+
rocr_devices = os.getenv("ROCR_VISIBLE_DEVICES")
650+
651+
if rocr_devices is not None:
652+
# Mostly required for ROCm to make sure ROCr visible devices
653+
# is respected, this ensures we do not return a list of devices
654+
# that exceeds the total available supplied via ROCR_VISIBLE_DEVICES
655+
var = rocr_devices
656+
649657
if hip_devices is not None:
650-
var = hip_devices
658+
# If ROCr devices have been set, the hip visible devices would
659+
# be a subset of those. HIP_VISIBLE_DEVICES can only contain
660+
# integer indices so we can use the ROCr visible devices as a key
661+
if rocr_devices is not None:
662+
hip_device_list = [int(dev) for dev in hip_devices.split(",")]
663+
rocr_device_list = rocr_devices.split(",")
664+
var = ",".join(rocr_device_list[dev] for dev in hip_device_list)
665+
else:
666+
var = hip_devices
651667

652668
if var is None:
653669
return list(range(64))

0 commit comments

Comments
 (0)