diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml
index 1565e49707e..7b3c6b9dcb0 100644
--- a/.github/workflows/test-linux.yml
+++ b/.github/workflows/test-linux.yml
@@ -53,38 +53,6 @@ jobs:
         ## setup_env.sh
         bash .github/unittest/linux/scripts/run_all.sh
-  tests-cpu-oldget:
-    # Tests that TD_GET_DEFAULTS_TO_NONE=0 works fine as this will be the default for TD up to 0.7
-    strategy:
-      matrix:
-        python_version: ["3.12"]
-      fail-fast: false
-    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
-    with:
-      runner: linux.12xlarge
-      repository: pytorch/rl
-      docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04"
-      timeout: 90
-      script: |
-        if [[ "${{ github.ref }}" =~ release/* ]]; then
-          export RELEASE=1
-          export TORCH_VERSION=stable
-        else
-          export RELEASE=0
-          export TORCH_VERSION=nightly
-        fi
-        export TD_GET_DEFAULTS_TO_NONE=0
-
-        # Set env vars from matrix
-        export PYTHON_VERSION=${{ matrix.python_version }}
-        export CU_VERSION="cpu"
-
-        echo "PYTHON_VERSION: $PYTHON_VERSION"
-        echo "CU_VERSION: $CU_VERSION"
-
-        ## setup_env.sh
-        bash .github/unittest/linux/scripts/run_all.sh
-
   tests-gpu:
     strategy:
       matrix:
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 0610d1229d7..81b25edc2fd 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -163,6 +163,81 @@ provides more information on how to design a custom environment from scratch.
     GymLikeEnv
     EnvMetaData
 
+Partial steps and partial resets
+--------------------------------
+
+TorchRL allows environments to reset some but not all of the environments in a batch, or to run a step in some but
+not all of them. If there is only one environment in the batch, a partial reset / partial step is also allowed, with
+the behavior detailed below.
+
+Batching environments and locking the batch
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. _ref_batch_locked:
+
+Before detailing what partial resets and partial steps do, we must distinguish the case where an environment has
+a batch size of its own (mostly stateful environments) from the case where the environment is merely a module that,
+given an input of arbitrary size, batches the operations over all elements (mostly stateless environments).
+
+This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input
+tensordicts to have the same batch size as the env's. Typical examples of these environments are
+:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are, by contrast, allowed to work with any input size.
+Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`.
+
+Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the
+tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous
+input.
+
+Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with
+partial steps easily: they simply pass the actions to the sub-environments that are required to be executed.
+
+In all other cases, TorchRL assumes that the environment handles the partial steps correctly.
+
+.. warning:: This means that custom environments may silently run the non-required steps, as there is no way for
+    TorchRL to control what happens within the `_step` method!
+
+Partial Steps
+~~~~~~~~~~~~~
+
+.. _ref_partial_steps:
+
+Partial steps are controlled via the temporary key `"_step"`, which points to a boolean mask of the
+size of the tensordict that holds it. The classes equipped to deal with this are:
+
+- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the
+  action to and only to the environments where `"_step"` is `True`;
+- Batch-unlocked environments;
+- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step`
+  method will first look for a `"_step"` entry and, if present, act accordingly.
+  If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by
+  :class:`~torchrl.envs.TransformedEnv`'s own `_step` method, which will skip the `base_env.step` as well as any further
+  transformation.
+
+When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous
+content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that
+if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for
+all the non-stepped elements. Within batched environments, data collectors, and rollout utils, this issue is not
+observed because these classes handle the passing of data properly.
+
+Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`,
+as the environments with a `True` done state will need to be skipped during calls to `_step`.
+
+The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips.
+
+Partial Resets
+~~~~~~~~~~~~~~
+
+.. _ref_partial_resets:
+
+Partial resets work pretty much like partial steps, but with the `"_reset"` entry.
+
+The same restrictions that apply to partial steps apply to partial resets.
+
+Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `True`,
+as the environments with a `True` done state will need to be reset, but not the others.
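+
+As a quick illustration, here is a minimal sketch of a partial step on a batched environment (it assumes Gym is
+installed; the `"_step"` mask below is arbitrary)::
+
+    import torch
+    from torchrl.envs import GymEnv, SerialEnv
+
+    env = SerialEnv(4, lambda: GymEnv("Pendulum-v1"))
+    td = env.reset()
+    td.set("action", env.full_action_spec[env.action_key].one())
+    # Step only the second and fourth sub-environments
+    td.set("_step", torch.tensor([False, True, False, True]))
+    td = env.step(td)
+    # The non-stepped entries of td["next"] are copied over from the input
+    # (or zero-valued if the input lacks them)
+
+See the following section for a deep dive into partial resets within batched and vectorized environments.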
+ Vectorized envs --------------- @@ -886,6 +961,7 @@ to be able to create this other composition: CenterCrop ClipTransform Compose + ConditionalSkip Crop DTypeCastTransform DeviceCastTransform diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 624a0f098e1..d933cb72216 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -538,6 +538,8 @@ def _step( class ContinuousActionVecMockEnv(_MockEnv): + adapt_dtype: bool = True + @classmethod def __new__( cls, @@ -635,7 +637,14 @@ def _step( while done.shape != tensordict.shape: done = done.any(-1) done = reward = done.unsqueeze(-1) - tensordict.set("reward", reward.to(torch.get_default_dtype())) + tensordict.set( + "reward", + reward.to( + self.reward_spec.dtype + if self.adapt_dtype + else torch.get_default_dtype() + ).expand(self.reward_spec.shape), + ) tensordict.set("done", done) tensordict.set("terminated", done) return tensordict diff --git a/test/test_env.py b/test/test_env.py index 3e4bb0febca..ad02467d6ab 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -235,6 +235,7 @@ class TestEnvBase: def test_run_type_checks(self): env = ContinuousActionVecMockEnv() + env.adapt_dtype = False env._run_type_checks = False check_env_specs(env) env._run_type_checks = True @@ -4112,17 +4113,21 @@ def test_parallel_partial_steps( use_buffers=use_buffers, device=device, ) - td = penv.reset() - psteps = torch.zeros(4, dtype=torch.bool) - psteps[[1, 3]] = True - td.set("_step", psteps) - - td.set("action", penv.full_action_spec[penv.action_key].one()) - td = penv.step(td) - assert (td[0].get("next") == 0).all() - assert (td[1].get("next") != 0).any() - assert (td[2].get("next") == 0).all() - assert (td[3].get("next") != 0).any() + try: + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.full_action_spec[penv.action_key].one()) + td = penv.step(td) + assert_allclose_td(td[0].get("next"), td[0], intersection=True) + assert (td[1].get("next") != 0).any() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) + assert (td[3].get("next") != 0).any() + finally: + penv.close() + del penv @pytest.mark.parametrize("use_buffers", [False, True]) def test_parallel_partial_step_and_maybe_reset( @@ -4135,17 +4140,21 @@ def test_parallel_partial_step_and_maybe_reset( use_buffers=use_buffers, device=device, ) - td = penv.reset() - psteps = torch.zeros(4, dtype=torch.bool) - psteps[[1, 3]] = True - td.set("_step", psteps) - - td.set("action", penv.full_action_spec[penv.action_key].one()) - td, tdreset = penv.step_and_maybe_reset(td) - assert (td[0].get("next") == 0).all() - assert (td[1].get("next") != 0).any() - assert (td[2].get("next") == 0).all() - assert (td[3].get("next") != 0).any() + try: + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool, device=td.get("done").device) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.full_action_spec[penv.action_key].one()) + td, tdreset = penv.step_and_maybe_reset(td) + assert_allclose_td(td[0].get("next"), td[0], intersection=True) + assert (td[1].get("next") != 0).any() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) + assert (td[3].get("next") != 0).any() + finally: + penv.close() + del penv @pytest.mark.parametrize("use_buffers", [False, True]) def test_serial_partial_steps(self, use_buffers, device, env_device): @@ -4156,17 +4165,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device): use_buffers=use_buffers, 
device=device, ) - td = penv.reset() - psteps = torch.zeros(4, dtype=torch.bool) - psteps[[1, 3]] = True - td.set("_step", psteps) - - td.set("action", penv.full_action_spec[penv.action_key].one()) - td = penv.step(td) - assert (td[0].get("next") == 0).all() - assert (td[1].get("next") != 0).any() - assert (td[2].get("next") == 0).all() - assert (td[3].get("next") != 0).any() + try: + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.full_action_spec[penv.action_key].one()) + td = penv.step(td) + assert_allclose_td(td[0].get("next"), td[0], intersection=True) + assert (td[1].get("next") != 0).any() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) + assert (td[3].get("next") != 0).any() + finally: + penv.close() + del penv @pytest.mark.parametrize("use_buffers", [False, True]) def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device): @@ -4184,9 +4197,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi td.set("action", penv.full_action_spec[penv.action_key].one()) td = penv.step(td) - assert (td[0].get("next") == 0).all() + assert_allclose_td(td[0].get("next"), td[0], intersection=True) assert (td[1].get("next") != 0).any() - assert (td[2].get("next") == 0).all() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) assert (td[3].get("next") != 0).any() diff --git a/test/test_transforms.py b/test/test_transforms.py index ff910434c5d..de59fee8c4b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import abc import argparse +import collections import contextlib import importlib.util @@ -58,6 +59,7 @@ CenterCrop, ClipTransform, Compose, + ConditionalSkip, Crop, DeviceCastTransform, DiscreteActionProjection, @@ -13451,6 +13453,217 @@ def test_composite_reward_spec(self) -> None: assert transform.transform_reward_spec(reward_spec) == expected_reward_spec +class TestConditionalSkip(TransformBase): + def check_non_tensor_match(self, td): + q = collections.deque() + obs_str = td["obs_str"] + obs = td["observation"] + q.extend(list(zip(obs_str, obs.unbind(0)))) + next_obs_str = td["next", "obs_str"] + next_obs = td["next", "observation"] + q.extend(zip(next_obs_str, next_obs.unbind(0))) + while len(q): + o_str, o = q.popleft() + if isinstance(o_str, list): + q.extend(zip(o_str, o.unbind(0))) + else: + assert o_str == str(o), (obs, obs_str, next_obs, next_obs_str) + + class ToString(Transform): + def _apply_transform(self, obs: torch.Tensor) -> None: + return NonTensorData(str(obs), device=obs.device) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + return self._call(tensordict_reset) + + def transform_observation_spec( + self, observation_spec: TensorSpec + ) -> TensorSpec: + observation_spec["obs_str"] = NonTensor( + example_data="a string!", + shape=observation_spec.shape, + device=observation_spec.device, + ) + return observation_spec + + class CountinEnvWithString(TransformedEnv): + def __init__(self, *args, **kwargs): + base_env = CountingEnv() + super().__init__( + base_env, + TestConditionalSkip.ToString( + in_keys=["observation"], out_keys=["obs_str"] + ), + ) + + @pytest.mark.parametrize("bwad", [False, True]) + def test_single_trans_env_check(self, bwad): + env = TestConditionalSkip.CountinEnvWithString() + base_env = TransformedEnv( + env, + Compose( + 
StepCounter(step_count_key="other_count"), + ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1), + ), + ) + env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False) + env.set_seed(0) + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((1,))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert (r["step_count"] == torch.arange(10).view(10, 1)).all() + assert (r["other_count"] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) + + @pytest.mark.parametrize("bwad", [False, True]) + def test_serial_trans_env_check(self, bwad): + def make_env(i): + env = TestConditionalSkip.CountinEnvWithString() + base_env = TransformedEnv( + env, + Compose( + StepCounter(step_count_key="other_count"), + ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)), + ), + ) + return TransformedEnv( + base_env, + StepCounter(), + auto_unwrap=False, + ) + + env = SerialEnv(2, [partial(make_env, i=0), partial(make_env, i=1)]) + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((2, 1))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all() + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) + + @pytest.mark.parametrize("bwad", [False, True]) + def test_parallel_trans_env_check(self, bwad): + def make_env(i): + env = TestConditionalSkip.CountinEnvWithString() + base_env = TransformedEnv( + env, + Compose( + StepCounter(step_count_key="other_count"), + ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)), + ), + ) + return TransformedEnv( + base_env, + StepCounter(), + auto_unwrap=False, + ) + + env = ParallelEnv( + 2, [partial(make_env, i=0), partial(make_env, i=1)], mp_start_method=mp_ctx + ) + try: + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((2, 1))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert ( + r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1) + ).all() + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) + finally: + env.close() + del env + + @pytest.mark.parametrize("bwad", [False, True]) + def test_trans_serial_env_check(self, bwad): + def make_env(): + env = TestConditionalSkip.CountinEnvWithString(max_steps=100) + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) + return base_env + + base_env = SerialEnv(2, [make_env, make_env]) + + def cond(td): + sc = td["step_count"] + torch.tensor([[0], [1]]) + return sc.squeeze() % 2 == 0 + + env = TransformedEnv(base_env, ConditionalSkip(cond)) + env = TransformedEnv(env, StepCounter(), auto_unwrap=False) + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((2, 1))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all() + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) + + @pytest.mark.parametrize("bwad", [True, False]) + @pytest.mark.parametrize("buffers", [True, False]) + def test_trans_parallel_env_check(self, bwad, buffers): + def make_env(): + env = 
TestConditionalSkip.CountinEnvWithString(max_steps=100) + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) + return base_env + + base_env = ParallelEnv( + 2, [make_env, make_env], mp_start_method=mp_ctx, use_buffers=buffers + ) + try: + + def cond(td): + sc = td["step_count"] + torch.tensor([[0], [1]]) + return sc.squeeze() % 2 == 0 + + env = TransformedEnv(base_env, ConditionalSkip(cond)) + env = TransformedEnv(env, StepCounter(), auto_unwrap=False) + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((2, 1))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert ( + r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1) + ).all() + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) + finally: + base_env.close() + del base_env + + def test_transform_no_env(self): + t = ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) + assert not t._inv_call(TensorDict())["_step"] + assert t._inv_call(TensorDict())["_step"].shape == () + assert t._inv_call(TensorDict(batch_size=(2, 3)))["_step"].shape == (2, 3) + + def test_transform_compose(self): + t = Compose( + ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) + ) + assert not t._inv_call(TensorDict())["_step"] + assert t._inv_call(TensorDict())["_step"].shape == () + assert t._inv_call(TensorDict(batch_size=(2, 3)))["_step"].shape == (2, 3) + + def test_transform_env(self): + # tested above + return + + def test_transform_model(self): + t = Compose( + ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) + ) + with pytest.raises(NotImplementedError): + t(TensorDict())["_step"] + + def test_transform_rb(self): + return + + def test_transform_inverse(self): + return + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 52f0cfdbbc5..6db86c9b0ab 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -56,6 +56,7 @@ CenterCrop, ClipTransform, Compose, + ConditionalSkip, Crop, DeviceCastTransform, DiscreteActionProjection, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 963c640c322..7bc9b0c1a5a 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -29,6 +29,7 @@ TensorDictBase, unravel_key, ) +from tensordict.base import _is_leaf_nontensor from tensordict.utils import _zip_strict from torch import multiprocessing as mp from torchrl._utils import ( @@ -1089,7 +1090,7 @@ def _step( self, tensordict: TensorDict, ) -> TensorDict: - partial_steps = tensordict.get("_step", None) + partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1162,9 +1163,31 @@ def select_and_clone(name, tensor): out_tds.append(out_td) out = LazyStackedTensorDict.maybe_dense_stack(out_tds) - if partial_steps is not None: + if partial_steps is not None and not partial_steps.all(): result = out.new_zeros(tensordict_save.shape) - result[partial_steps] = out + # Copy the observation data from the previous step as placeholder + + def select_and_clone(x, y): + if y is not None: + if x.device != y.device: + x = x.to(y.device) + else: + x = x.clone() + return x + + prev = 
tensordict_save._fast_apply( + select_and_clone, + result, + filter_empty=True, + device=result.device, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + default=None, + ) + + result.update(prev) + if partial_steps.any(): + result[partial_steps] = out return result return out @@ -1529,7 +1552,29 @@ def _step_and_maybe_reset_no_buffers( ) if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) - result[partial_steps] = out + + def select_and_clone(x, y): + if y is not None: + if x.device != y.device: + x = x.to(y.device) + else: + x = x.clone() + return x + + prev = tensordict_save._fast_apply( + select_and_clone, + result, + filter_empty=True, + device=result.device, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + default=None, + ) + + result.update(prev) + + if partial_steps.any(): + result[partial_steps] = out return result return out @@ -1543,7 +1588,7 @@ def step_and_maybe_reset( # return self._step_and_maybe_reset_no_buffers(tensordict) return super().step_and_maybe_reset(tensordict) - partial_steps = tensordict.get("_step", None) + partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1608,10 +1653,11 @@ def step_and_maybe_reset( data = [{} for _ in workers_range] if self._non_tensor_keys: - for i in workers_range: - data[i]["non_tensor_data"] = tensordict[i].select( - *self._non_tensor_keys, strict=False - ) + for i, td in zip( + workers_range, + tensordict.select(*self._non_tensor_keys, strict=False).unbind(0), + ): + data[i]["non_tensor_data"] = td self._sync_m2w() for i, _data in zip(workers_range, data): @@ -1645,6 +1691,14 @@ def step_and_maybe_reset( device=device, filter_empty=True, ) + if tensordict.device != device: + tensordict = tensordict._fast_apply( + lambda x: x.to(device, non_blocking=self.non_blocking) + if x.device != device + else x, + device=device, + filter_empty=True, + ) self._sync_w2m() else: next_td = next_td.clone().clear_device_() @@ -1661,8 +1715,47 @@ def step_and_maybe_reset( if partial_steps is not None: result = tensordict.new_zeros(tensordict_save.shape) result_ = tensordict_.new_zeros(tensordict_save.shape) - result[partial_steps] = tensordict - result_[partial_steps] = tensordict_ + + def select_and_transfer(x, y): + if y is not None: + return ( + x.to(y.device, non_blocking=self.non_blocking) + if x.device != y.device + else x.clone() + ) + + old_r_copy = tensordict_save._fast_apply( + select_and_transfer, + result, + filter_empty=True, + device=device, + default=None, + ) + old_r_copy.set( + "next", + tensordict_save._fast_apply( + select_and_transfer, + next_td, + filter_empty=True, + device=device, + default=None, + ), + ) + result.update(old_r_copy) + result_.update( + tensordict_save._fast_apply( + select_and_transfer, + result_, + filter_empty=True, + device=device, + default=None, + ) + ) + self._sync_w2m() + + if partial_steps.any(): + result[partial_steps] = tensordict + result_[partial_steps] = tensordict_ return result, result_ return tensordict, tensordict_ @@ -1700,7 +1793,7 @@ def _wait_for_workers(self, workers_range): def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - partial_steps = tensordict.get("_step", None) + partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1727,7 +1820,29 @@ def _step_no_buffers( out = out.to(self.device, 
non_blocking=self.non_blocking) if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) - result[partial_steps] = out + + def select_and_clone(x, y): + if y is not None: + if x.device != y.device: + x = x.to(y.device) + else: + x = x.clone() + return x + + prev = tensordict_save._fast_apply( + select_and_clone, + result, + filter_empty=True, + device=result.device, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + default=None, + ) + + result.update(prev) + + if partial_steps.any(): + result[partial_steps] = out return result return out @@ -1744,7 +1859,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. - partial_steps = tensordict.get("_step", None) + partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1774,6 +1889,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: shared_tensordict_parent.update_( tensordict, + # We also update the output keys because they can be implicitly used, eg + # during partial steps to fill in values keys_to_update=list(self._env_input_keys), non_blocking=self.non_blocking, ) @@ -1820,10 +1937,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: data = [{} for _ in range(self.num_workers)] if self._non_tensor_keys: - for i in workers_range: - data[i]["non_tensor_data"] = tensordict[i].select( - *self._non_tensor_keys, strict=False - ) + for i, td in zip( + workers_range, + tensordict.select(*self._non_tensor_keys, strict=False).unbind(0), + ): + data[i]["non_tensor_data"] = td self._sync_m2w() @@ -1875,7 +1993,28 @@ def select_and_clone(name, tensor): self._sync_w2m() if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) - result[partial_steps] = out + + def select_and_clone(x, y): + if y is not None: + if x.device != y.device: + x = x.to(y.device) + else: + x = x.clone() + return x + + prev = tensordict_save._fast_apply( + select_and_clone, + result, + filter_empty=True, + device=result.device, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + default=None, + ) + + result.update(prev) + if partial_steps.any(): + result[partial_steps] = out return result return out @@ -2284,13 +2423,18 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if non_tensor_data is not None: input.update(non_tensor_data) - next_td = env._step(input) + input = env.step(input.copy()) + next_td = input.get("next") next_shared_tensordict.update_(next_td, non_blocking=non_blocking) + if event is not None: event.record() event.synchronize() mp_event.set() + # Make sure the root is updated + root_shared_tensordict.update_(env._step_mdp(input)) + if _non_tensor_keys: child_pipe.send( ("non_tensor", next_td.select(*_non_tensor_keys, strict=False)) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 4690772db7c..1bb8de1ca8e 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1933,6 +1933,32 @@ def state_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.state_spec = spec + def _skip_tensordict(self, tensordict): + # Creates a "skip" tensordict, ie a placeholder for when a step is skipped + next_tensordict = self.full_done_spec.zero() + next_tensordict.update(self.full_observation_spec.zero()) + next_tensordict.update(self.full_reward_spec.zero()) + + # 
Copy the data from tensordict in `next`
+        def select_and_clone(x, y):
+            if y is not None:
+                if y.device == x.device:
+                    return x.clone()
+                return x.to(y.device)
+
+        next_tensordict.update(
+            tensordict._fast_apply(
+                select_and_clone,
+                next_tensordict,
+                device=next_tensordict.device,
+                batch_size=next_tensordict.batch_size,
+                default=None,
+                filter_empty=True,
+                is_leaf=_is_leaf_nontensor,
+            )
+        )
+        return next_tensordict
+
     def step(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Makes a step in the environment.
@@ -1953,25 +1979,33 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
         """
         # sanity check
         self._assert_tensordict_shape(tensordict)
-        partial_steps = None
+        partial_steps = tensordict.pop("_step", None)
 
-        if not self.batch_locked:
-            # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here
-            partial_steps = tensordict.get("_step", None)
-            if partial_steps is not None:
+        next_tensordict = None
+
+        if partial_steps is not None:
+            if not self.batch_locked:
+                # Batched envs have their own way of dealing with this - batched envs that are not batch-locked may fail here
                 if partial_steps.all():
                     partial_steps = None
                 else:
                     tensordict_batch_size = tensordict.batch_size
                     partial_steps = partial_steps.view(tensordict_batch_size)
                     tensordict = tensordict[partial_steps]
-            else:
+            else:
+                if not partial_steps.any():
+                    next_tensordict = self._skip_tensordict(tensordict)
+                else:
+                    # trust that the _step can handle this!
+                    tensordict.set("_step", partial_steps)
+                tensordict_batch_size = self.batch_size
 
         next_preset = tensordict.get("next", None)
 
-        next_tensordict = self._step(tensordict)
-        next_tensordict = self._step_proc_data(next_tensordict)
+        if next_tensordict is None:
+            next_tensordict = self._step(tensordict)
+            next_tensordict = self._step_proc_data(next_tensordict)
         if next_preset is not None:
             # tensordict could already have a "next" key
             # this could be done more efficiently by not excluding but just passing
             # the necessary keys
             next_tensordict.update(
                 next_preset.exclude(*next_tensordict.keys(True, True))
             )
         tensordict.set("next", next_tensordict)
-        if partial_steps is not None:
+        if partial_steps is not None and tensordict_batch_size != self.batch_size:
             result = tensordict.new_zeros(tensordict_batch_size)
-            result[partial_steps] = tensordict
+
+            def select_and_clone(x, y):
+                if y is not None:
+                    if x.device == y.device:
+                        return x.clone()
+                    return x.to(y.device)
+
+            result.update(
+                tensordict._fast_apply(
+                    select_and_clone,
+                    result,
+                    device=result.device,
+                    filter_empty=True,
+                    default=None,
+                    batch_size=result.batch_size,
+                    is_leaf=_is_leaf_nontensor,
+                )
+            )
+            if partial_steps.any():
+                result[partial_steps] = tensordict
             return result
         return tensordict
@@ -2873,9 +2926,17 @@ def rollout(
                 policy device before the policy is used. Default is ``False``.
             break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the done
                 states. If ``False``, then the done environments are reset automatically. Default is ``True``.
+
+                .. seealso:: The :ref:`Partial resets <ref_partial_resets>` section of the documentation gives more
+                    information about partial resets.
+
             break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any of
                 the done states. If ``False``, break if at least one environment reaches any of the done states.
                 Default is ``False``.
+
+                .. seealso:: The :ref:`Partial steps <ref_partial_steps>` section of the documentation gives more
+                    information about partial steps.
+
             return_contiguous (bool): if False, a LazyStackedTensorDict will be returned.
                 Default is `True` if the env does not have dynamic specs, otherwise `False`.
             tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 7ee142fe811..12d5d03508f 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -20,6 +20,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalSkip,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 7b3bab47227..40878e15bd7 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -44,6 +44,7 @@
     unravel_key,
     unravel_key_list,
 )
+from tensordict.base import _is_leaf_nontensor
 from tensordict.nn import dispatch, TensorDictModuleBase
 from tensordict.utils import (
     _unravel_key_to_tuple,
@@ -635,7 +636,7 @@ def parent(self) -> Optional[EnvBase]:
             )
             parent, _ = container._rebuild_up_to(self)
         elif isinstance(container, TransformedEnv):
-            parent = TransformedEnv(container.base_env)
+            parent = TransformedEnv(container.base_env, auto_unwrap=False)
         else:
             raise ValueError(f"container is of type {type(container)}")
         self.__dict__["_parent"] = parent
@@ -693,10 +694,19 @@ class TransformedEnv(EnvBase, metaclass=_TEnvPostInit):
             in which case this value should be set to `False`.
             Default is `True`.
 
+    Keyword Args:
+        auto_unwrap (bool, optional): if ``True``, wrapping a transformed env in a transformed env
+            unwraps the transforms of the inner TransformedEnv into the outer one (the new instance).
+            Defaults to ``True``.
+
     Examples:
         >>> env = GymEnv("Pendulum-v0")
         >>> transform = RewardScaling(0.0, 1.0)
         >>> transformed_env = TransformedEnv(env, transform)
+        >>> # check auto-unwrap
+        >>> transformed_env = TransformedEnv(transformed_env, StepCounter())
+        >>> # The inner env has been unwrapped
+        >>> assert isinstance(transformed_env.base_env, GymEnv)
 
     """
 
@@ -705,6 +715,8 @@ def __init__(
         self,
         env: EnvBase,
         transform: Optional[Transform] = None,
         cache_specs: bool = True,
+        *,
+        auto_unwrap: bool = True,
         **kwargs,
     ):
         self._transform = None
@@ -717,7 +729,7 @@
         # Type matching must be exact here, because subtyping could introduce differences in behavior that must
         # be contained within the subclass.
- if type(env) is TransformedEnv and type(self) is TransformedEnv: + if type(env) is TransformedEnv and type(self) is TransformedEnv and auto_unwrap: self._set_env(env.base_env, device) if type(transform) is not Compose: # we don't use isinstance as some transforms may be subclassed from @@ -923,17 +935,70 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # tensordict = tensordict.clone(False) next_preset = tensordict.get("next", None) tensordict_in = self.transform.inv(tensordict) - next_tensordict = self.base_env._step(tensordict_in) - if next_preset is not None: - # tensordict could already have a "next" key - # this could be done more efficiently by not excluding but just passing - # the necessary keys - next_tensordict.update( - next_preset.exclude(*next_tensordict.keys(True, True)) - ) - self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict) - # we want the input entries to remain unchanged - next_tensordict = self.transform._step(tensordict, next_tensordict) + + # It could be that the step must be skipped + partial_steps = tensordict_in.pop("_step", None) + next_tensordict = None + tensordict_batch_size = None + if partial_steps is not None: + if not self.batch_locked: + # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here + if partial_steps.all(): + partial_steps = None + else: + tensordict_batch_size = tensordict_in.batch_size + partial_steps = partial_steps.view(tensordict_batch_size) + tensordict_in_save = tensordict_in[~partial_steps] + tensordict_in = tensordict_in[partial_steps] + else: + if not partial_steps.any(): + next_tensordict = self._skip_tensordict(tensordict_in) + # No need to copy anything + partial_steps = None + elif not partial_steps.all(): + # trust that the _step can handle this! 
+                    tensordict_in.set("_step", partial_steps)
+                    # The filling should be handled by the sub-env
+                    partial_steps = None
+                else:
+                    partial_steps = None
+                tensordict_batch_size = self.batch_size
+
+        if next_tensordict is None:
+            next_tensordict = self.base_env._step(tensordict_in)
+            if next_preset is not None:
+                # tensordict could already have a "next" key
+                # this could be done more efficiently by not excluding but just passing
+                # the necessary keys
+                next_tensordict.update(
+                    next_preset.exclude(*next_tensordict.keys(True, True))
+                )
+            self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict)
+            # we want the input entries to remain unchanged
+            next_tensordict = self.transform._step(tensordict, next_tensordict)
+
+        if partial_steps is not None:
+            result = next_tensordict.new_zeros(tensordict_batch_size)
+
+            def select_and_clone(x, y):
+                if y is not None:
+                    if x.device == y.device:
+                        return x.clone()
+                    return x.to(y.device)
+
+            if not partial_steps.all():
+                result[~partial_steps] = tensordict_in_save._fast_apply(
+                    select_and_clone,
+                    result,
+                    device=result.device,
+                    filter_empty=True,
+                    default=None,
+                    is_leaf=_is_leaf_nontensor,
+                )
+            if partial_steps.any():
+                result[partial_steps] = next_tensordict
+            next_tensordict = result
+
         return next_tensordict
 
     def set_seed(
@@ -1410,7 +1475,7 @@ def _rebuild_up_to(self, final_transform):
                 # returns None if there is no parent env
                 return None, None
             elif isinstance(container, TransformedEnv):
-                out = TransformedEnv(container.base_env)
+                out = TransformedEnv(container.base_env, auto_unwrap=False)
             elif container is None:
                 # returns None if there is no parent env
                 return None, None
@@ -10218,3 +10283,118 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase:
             )
 
         return (self.weights * reward).sum(dim=-1)
+
+
+class ConditionalSkip(Transform):
+    """A transform that skips steps in the env if certain conditions are met.
+
+    This transform writes the result of `cond(tensordict)` in the `"_step"` entry of the
+    tensordict passed as input to the `TransformedEnv.base_env._step` method.
+    If the `base_env` is not batch-locked (generally speaking, it is stateless), the tensordict is
+    reduced to the elements that need to go through the step.
+    If it is batch-locked (generally speaking, it is stateful), the step is skipped altogether if no
+    value in `"_step"` is ``True``. Otherwise, it is trusted that the environment will account for the
+    `"_step"` signal accordingly.
+
+    .. note:: The skip will affect transforms that modify the environment output too, i.e., any transform
+        that is to be executed on the tensordict returned by :meth:`~torchrl.envs.EnvBase.step` will be
+        skipped if the condition is met. To mitigate this effect when it is not desirable, one can wrap
+        the transformed env in another transformed env, since the skip only affects the first-degree parent
+        of the ``ConditionalSkip`` transform. See the example below.
+
+    Args:
+        cond (Callable[[TensorDictBase], bool | torch.Tensor]): a callable for the tensordict input
+            that checks whether the next env step must be skipped (`True` = skipped, `False` = execute
+            env.step).
+
+    Examples:
+        >>> import torch
+        >>>
+        >>> from torchrl.envs import GymEnv
+        >>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv, Compose
+        >>>
+        >>> torch.manual_seed(0)
+        >>>
+        >>> base_env = TransformedEnv(
+        ...     GymEnv("Pendulum-v1"),
+        ...     StepCounter(step_count_key="inner_count"),
+        ... )
+        >>> middle_env = TransformedEnv(
+        ...     base_env,
+        ...     Compose(
+        ...         StepCounter(step_count_key="middle_count"),
+        ...         ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1),
+        ...     ),
+        ...     auto_unwrap=False)  # makes sure that transformed envs are properly wrapped
+        >>> env = TransformedEnv(
+        ...     middle_env,
+        ...     StepCounter(step_count_key="step_count"),
+        ...     auto_unwrap=False)
+        >>> env.set_seed(0)
+        >>>
+        >>> r = env.rollout(10)
+        >>> print(r["observation"])
+        tensor([[-0.9670, -0.2546, -0.9669],
+                [-0.9802, -0.1981, -1.1601],
+                [-0.9802, -0.1981, -1.1601],
+                [-0.9926, -0.1214, -1.5556],
+                [-0.9926, -0.1214, -1.5556],
+                [-0.9994, -0.0335, -1.7622],
+                [-0.9994, -0.0335, -1.7622],
+                [-0.9984,  0.0561, -1.7933],
+                [-0.9984,  0.0561, -1.7933],
+                [-0.9895,  0.1445, -1.7779]])
+        >>> print(r["inner_count"])
+        tensor([[0],
+                [1],
+                [1],
+                [2],
+                [2],
+                [3],
+                [3],
+                [4],
+                [4],
+                [5]])
+        >>> print(r["middle_count"])
+        tensor([[0],
+                [1],
+                [1],
+                [2],
+                [2],
+                [3],
+                [3],
+                [4],
+                [4],
+                [5]])
+        >>> print(r["step_count"])
+        tensor([[0],
+                [1],
+                [2],
+                [3],
+                [4],
+                [5],
+                [6],
+                [7],
+                [8],
+                [9]])
+
+    """
+
+    def __init__(self, cond: Callable[[TensorDict], bool | torch.Tensor]):
+        super().__init__()
+        self.cond = cond
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Run cond
+        cond = self.cond(tensordict)
+        # Write result in step
+        tensordict["_step"] = tensordict.get("_step", True) & ~cond
+        if tensordict["_step"].shape != tensordict.batch_size:
+            tensordict["_step"] = tensordict["_step"].view(tensordict.batch_size)
+        return tensordict
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        raise NotImplementedError(
+            FORWARD_NOT_IMPLEMENTED.format(self.__class__.__name__)
+        )