Skip to content

Commit e09ed4d

Browse files
committed
Get and force the use of the device local to the process in distributed runs
1 parent d57c8ea commit e09ed4d

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

skrl/__init__.py

+15
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,12 @@ def __init__(self) -> None:
188188
process_id=self._rank,
189189
local_device_ids=self._local_rank,
190190
)
191+
# get the device local to this process (looked up by its process index/rank)
192+
try:
193+
self._device = jax.local_devices(process_index=self._rank)[0]
194+
logger.info(f"Using device local to process with index/rank {self._rank} ({self._device})")
195+
except Exception as e:
196+
logger.warning(f"Failed to get the device local to process with index/rank {self._rank}: {e}")
191197

192198
@staticmethod
193199
def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
@@ -204,6 +210,15 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
204210
"""
205211
import jax
206212

213+
# in distributed runs, force the use of the device local to this process
# (the requested `device` is only considered if that lookup fails)
214+
if config.jax.is_distributed:
215+
try:
216+
return jax.local_devices(process_index=config.jax.rank)[0]
217+
except Exception as e:
218+
logger.warning(
219+
f"Failed to get the device local to process with index/rank {config.jax.rank}: {e}"
220+
)
221+
207222
if isinstance(device, jax.Device):
208223
return device
209224
elif isinstance(device, str):

0 commit comments

Comments
 (0)