From f830be9dbca7df5fa5a628a1eb3d26d4f23f1ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Fri, 24 Jan 2025 18:46:49 -0500 Subject: [PATCH] Update docstrings and docs --- docs/source/api/config/frameworks.rst | 5 +++-- skrl/__init__.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/source/api/config/frameworks.rst b/docs/source/api/config/frameworks.rst index 72095643..15d773cd 100644 --- a/docs/source/api/config/frameworks.rst +++ b/docs/source/api/config/frameworks.rst @@ -92,11 +92,12 @@ API .. py:data:: skrl.config.jax.device :type: jax.Device - :value: "cuda:${LOCAL_RANK}" | "cpu" + :value: "cuda:${JAX_LOCAL_RANK}" | "cpu" Default device. - The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise. + The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise. + However, in a distributed environment, it is the device local to process with index ``JAX_RANK``. .. py:data:: skrl.config.jax.backend :type: str diff --git a/skrl/__init__.py b/skrl/__init__.py index 91d113d7..5931b8eb 100644 --- a/skrl/__init__.py +++ b/skrl/__init__.py @@ -203,6 +203,10 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``). + .. warning:: + + This method returns (forces to use) the device local to process in a distributed environment. + :param device: Device specification. If the specified device is ``None`` or it cannot be resolved, the default available device will be returned instead. @@ -233,8 +237,8 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": def device(self) -> "jax.Device": """Default device. - The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) - if CUDA is available, ``cpu`` otherwise. + The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise. + However, in a distributed environment, it is the device local to process with index ``JAX_RANK``. """ self._device = self.parse_device(self._device) return self._device