Skip to content

Commit

Permalink
Update docstrings and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jan 24, 2025
1 parent e09ed4d commit f830be9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
5 changes: 3 additions & 2 deletions docs/source/api/config/frameworks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ API

.. py:data:: skrl.config.jax.device
:type: jax.Device
:value: "cuda:${LOCAL_RANK}" | "cpu"
:value: "cuda:${JAX_LOCAL_RANK}" | "cpu"

Default device.

The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.
The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
However, in a distributed environment, it is the device local to process with index ``JAX_RANK``.

.. py:data:: skrl.config.jax.backend
:type: str
Expand Down
8 changes: 6 additions & 2 deletions skrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).
.. warning::
This method returns (forces to use) the device local to process in a distributed environment.
:param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
the default available device will be returned instead.
Expand Down Expand Up @@ -233,8 +237,8 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
def device(self) -> "jax.Device":
"""Default device.
The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment)
if CUDA is available, ``cpu`` otherwise.
The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise.
However, in a distributed environment, it is the device local to process with index ``JAX_RANK``.
"""
self._device = self.parse_device(self._device)
return self._device
Expand Down

0 comments on commit f830be9

Please sign in to comment.