From e09ed4da71f03e0ef05542cac62f8e62e9ca7c47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Fri, 24 Jan 2025 18:20:15 -0500 Subject: [PATCH 1/3] Get and force the use of the device local to process in distributed runs --- skrl/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/skrl/__init__.py b/skrl/__init__.py index bd424ffe..91d113d7 100644 --- a/skrl/__init__.py +++ b/skrl/__init__.py @@ -188,6 +188,12 @@ def __init__(self) -> None: process_id=self._rank, local_device_ids=self._local_rank, ) + # get the device local to process + try: + self._device = jax.local_devices(process_index=self._rank)[0] + logger.info(f"Using device local to process with index/rank {self._rank} ({self._device})") + except Exception as e: + logger.warning(f"Failed to get the device local to process with index/rank {self._rank}: {e}") @staticmethod def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": @@ -204,6 +210,15 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": """ import jax + # force the use of the device local to process in distributed runs + if config.jax.is_distributed: + try: + return jax.local_devices(process_index=config.jax.rank)[0] + except Exception as e: + logger.warning( + f"Failed to get the device local to process with index/rank {config.jax.rank}: {e}" + ) + if isinstance(device, jax.Device): return device elif isinstance(device, str): From f830be9dbca7df5fa5a628a1eb3d26d4f23f1ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Fri, 24 Jan 2025 18:46:49 -0500 Subject: [PATCH 2/3] Update docstrings and docs --- docs/source/api/config/frameworks.rst | 5 +++-- skrl/__init__.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/source/api/config/frameworks.rst b/docs/source/api/config/frameworks.rst index 72095643..15d773cd 100644 --- a/docs/source/api/config/frameworks.rst +++ b/docs/source/api/config/frameworks.rst @@ -92,11 +92,12 @@ API .. py:data:: skrl.config.jax.device :type: jax.Device - :value: "cuda:${LOCAL_RANK}" | "cpu" + :value: "cuda:${JAX_LOCAL_RANK}" | "cpu" Default device. - The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise. + The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise. + However, in a distributed environment, it is the device local to process with index ``JAX_RANK``. .. py:data:: skrl.config.jax.backend :type: str diff --git a/skrl/__init__.py b/skrl/__init__.py index 91d113d7..5931b8eb 100644 --- a/skrl/__init__.py +++ b/skrl/__init__.py @@ -203,6 +203,10 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``). + .. warning:: + + This method returns (forces to use) the device local to process in a distributed environment. + :param device: Device specification. If the specified device is ``None`` or it cannot be resolved, the default available device will be returned instead. @@ -233,8 +237,8 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device": def device(self) -> "jax.Device": """Default device. - The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) - if CUDA is available, ``cpu`` otherwise. + The default device, unless specified, is ``cuda:0`` if CUDA is available, ``cpu`` otherwise. + However, in a distributed environment, it is the device local to process with index ``JAX_RANK``. """ self._device = self.parse_device(self._device) return self._device From 1f4bb66e4dd0d23323e8a70a8e747cb708626399 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Mon, 27 Jan 2025 19:39:25 -0500 Subject: [PATCH 3/3] Increase PATCH version and update CHANGELOG --- .github/ISSUE_TEMPLATE/bug_report.yaml | 1 + CHANGELOG.md | 4 ++++ docs/source/conf.py | 2 +- pyproject.toml | 2 +- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml index c303faf0..33459361 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yaml +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -30,6 +30,7 @@ body: description: The skrl version can be obtained with the command `pip show skrl`. options: - --- + - 1.4.1 - 1.4.0 - 1.3.0 - 1.2.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index e2b81d65..5ae853ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [1.4.1] - Unreleased +### Fixed +- Force the use of the device local to process in distributed runs in JAX + ## [1.4.0] - 2025-01-16 ### Added - Utilities to operate on Gymnasium spaces (`Box`, `Discrete`, `MultiDiscrete`, `Tuple` and `Dict`) diff --git a/docs/source/conf.py b/docs/source/conf.py index b7670098..43c9c803 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,7 +16,7 @@ if skrl.__version__ != "unknown": release = version = skrl.__version__ else: - release = version = "1.4.0" + release = version = "1.4.1" master_doc = "index" diff --git a/pyproject.toml b/pyproject.toml index 5e0da0f2..4ca9ef33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "skrl" -version = "1.4.0" +version = "1.4.1" description = "Modular and flexible library for reinforcement learning on PyTorch and JAX" readme = "README.md" requires-python = ">=3.6"