Skip to content

Commit f830be9

Browse files
committed
Update docstrings and docs
1 parent e09ed4d commit f830be9

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

docs/source/api/config/frameworks.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,12 @@ API
9292

9393
.. py:data:: skrl.config.jax.device
9494
:type: jax.Device
95-
:value: "cuda:${LOCAL_RANK}" | "cpu"
95+
:value: "cuda:${JAX_LOCAL_RANK}" | "cpu"
9696

9797
Default device.
9898

99-
The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.
99+
The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
100+
However, in a distributed environment, it is the device local to process with index ``JAX_RANK``.
100101

101102
.. py:data:: skrl.config.jax.backend
102103
:type: str

skrl/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
203203
204204
This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).
205205
206+
.. warning::
207+
208+
This method returns (forces to use) the device local to process in a distributed environment.
209+
206210
:param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
207211
the default available device will be returned instead.
208212
@@ -233,8 +237,8 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
233237
def device(self) -> "jax.Device":
234238
"""Default device.
235239
236-
The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment)
237-
if CUDA is available, ``cpu`` otherwise.
240+
The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
241+
However, in a distributed environment, it is the device local to process with index ``JAX_RANK``.
238242
"""
239243
self._device = self.parse_device(self._device)
240244
return self._device

0 commit comments

Comments
 (0)