From c5770b0a575e07d16a27df1cb61ccf5682a42dcc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 10 Feb 2025 16:25:35 +0000 Subject: [PATCH 01/14] Update [ghstack-poisoned] --- test/test_transforms.py | 154 ++++++++++++++++++++++++++ torchrl/envs/__init__.py | 1 + torchrl/envs/batched_envs.py | 29 ++++- torchrl/envs/common.py | 37 +++++-- torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 130 +++++++++++++++++++++- 6 files changed, 336 insertions(+), 16 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index ff910434c5d..5d7fa5d325d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -58,6 +58,7 @@ CenterCrop, ClipTransform, Compose, + ConditionalSkip, Crop, DeviceCastTransform, DiscreteActionProjection, @@ -13451,6 +13452,159 @@ def test_composite_reward_spec(self) -> None: assert transform.transform_reward_spec(reward_spec) == expected_reward_spec +class TestConditionalSkip(TransformBase): + @pytest.mark.parametrize("bwad", [False, True]) + def test_single_trans_env_check(self, bwad): + env = CountingEnv() + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) + env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False) + env = env.append_transform( + ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1) + ) + env.set_seed(0) + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((1,))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert (r["step_count"] == torch.arange(10).view(10, 1)).all() + assert (r["other_count"] == torch.arange(1, 11).view(10, 1) // 2).all() + + @pytest.mark.parametrize("bwad", [False, True]) + def test_serial_trans_env_check(self, bwad): + def make_env(i): + env = CountingEnv() + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) + return TransformedEnv( + base_env, + Compose( + StepCounter(), + ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)), + ), + auto_unwrap=False, + ) + + env = SerialEnv(2, [partial(make_env, i=0), partial(make_env, i=1)]) + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((2, 1))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all() + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + + @pytest.mark.parametrize("bwad", [False, True]) + def test_parallel_trans_env_check(self, bwad): + def make_env(i): + env = CountingEnv() + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) + return TransformedEnv( + base_env, + Compose( + StepCounter(), + ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)), + ), + auto_unwrap=False, + ) + + env = ParallelEnv( + 2, [partial(make_env, i=0), partial(make_env, i=1)], mp_start_method=mp_ctx + ) + try: + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((2, 1))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert ( + r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1) + ).all() + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + finally: + env.close() + del env + + @pytest.mark.parametrize("bwad", [False, True]) + def test_trans_serial_env_check(self, bwad): + def make_env(): + env = CountingEnv(max_steps=100) + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) + return base_env + + base_env = SerialEnv(2, [make_env, make_env]) + + def cond(td): + sc = td["step_count"] + torch.tensor([[0], [1]]) + return sc.squeeze() % 2 == 0 + + env = TransformedEnv(base_env, Compose(StepCounter(), ConditionalSkip(cond))) + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((2, 1))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all() + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + + @pytest.mark.parametrize("bwad", [True, False]) + @pytest.mark.parametrize("buffers", [True, False]) + def test_trans_parallel_env_check(self, bwad, buffers): + def make_env(): + env = CountingEnv(max_steps=100) + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) + return base_env + + base_env = ParallelEnv( + 2, [make_env, make_env], mp_start_method=mp_ctx, use_buffers=buffers + ) + try: + + def cond(td): + sc = td["step_count"] + torch.tensor([[0], [1]]) + return sc.squeeze() % 2 == 0 + + env = TransformedEnv( + base_env, Compose(StepCounter(), ConditionalSkip(cond)) + ) + env.check_env_specs() + policy = lambda td: td.set("action", torch.ones((2, 1))) + r = env.rollout(10, policy, break_when_any_done=bwad) + assert ( + r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1) + ).all() + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + finally: + base_env.close() + del base_env + + def test_transform_no_env(self): + t = ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) + assert not t._inv_call(TensorDict())["_step"] + assert t._inv_call(TensorDict())["_step"].shape == () + assert t._inv_call(TensorDict(batch_size=(2, 3)))["_step"].shape == (2, 3) + + def test_transform_compose(self): + t = Compose( + ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) + ) + assert not t._inv_call(TensorDict())["_step"] + assert t._inv_call(TensorDict())["_step"].shape == () + assert t._inv_call(TensorDict(batch_size=(2, 3)))["_step"].shape == (2, 3) + + def test_transform_env(self): + # tested above + return + + def test_transform_model(self): + t = Compose( + ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) + ) + with pytest.raises(NotImplementedError): + t(TensorDict())["_step"] + + def test_transform_rb(self): + return + + def test_transform_inverse(self): + return + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 52f0cfdbbc5..6db86c9b0ab 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -56,6 +56,7 @@ CenterCrop, ClipTransform, Compose, + ConditionalSkip, Crop, DeviceCastTransform, DiscreteActionProjection, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 5475f42c61a..47163928ca0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1089,7 +1089,7 @@ def _step( self, tensordict: TensorDict, ) -> TensorDict: - partial_steps = tensordict.get("_step", None) + partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1164,6 +1164,10 @@ def select_and_clone(name, tensor): if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) + # Copy the observation data from the previous step as placeholder + result.update( + tensordict_save.select(*result.keys(True, True), strict=False).clone() + ) result[partial_steps] = out return result @@ -1529,6 +1533,9 @@ def _step_and_maybe_reset_no_buffers( ) if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) + result.update( + tensordict_save.select(*result.keys(True, True), strict=False).clone() + ) result[partial_steps] = out return result return out @@ -1543,7 +1550,7 @@ def step_and_maybe_reset( # return self._step_and_maybe_reset_no_buffers(tensordict) return super().step_and_maybe_reset(tensordict) - partial_steps = tensordict.get("_step", None) + partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1661,6 +1668,14 @@ def step_and_maybe_reset( if partial_steps is not None: result = tensordict.new_zeros(tensordict_save.shape) result_ = tensordict_.new_zeros(tensordict_save.shape) + + result.update( + tensordict_save.select(*result.keys(True, True), strict=False).clone() + ) + result_.update( + tensordict_save.select(*result_.keys(True, True), strict=False).clone() + ) + result[partial_steps] = tensordict result_[partial_steps] = tensordict_ return result, result_ @@ -1700,7 +1715,7 @@ def _wait_for_workers(self, workers_range): def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - partial_steps = tensordict.get("_step", None) + partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1727,6 +1742,9 @@ def _step_no_buffers( out = out.to(self.device, non_blocking=self.non_blocking) if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) + result.update( + tensordict_save.select(*result.keys(True, True), strict=False).clone() + ) result[partial_steps] = out return result return out @@ -1744,7 +1762,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. - partial_steps = tensordict.get("_step", None) + partial_steps = tensordict.get("_step") tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1875,6 +1893,9 @@ def select_and_clone(name, tensor): self._sync_w2m() if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) + result.update( + tensordict_save.select(*result.keys(True, True), strict=False).clone() + ) result[partial_steps] = out return result return out diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 14be04ef985..49cfe59d4f3 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1933,6 +1933,17 @@ def state_spec_unbatched(self, spec: Composite): spec = spec.expand(self.batch_size + spec.shape) self.state_spec = spec + def _skip_tensordict(self, tensordict): + # Creates a "skip" tensordict, ie a placeholder for when a step is skipped + next_tensordict = self.full_done_spec.zero() + next_tensordict.update(self.full_observation_spec.zero()) + next_tensordict.update(self.full_reward_spec.zero()) + # Copy the data from tensordict in `next` + next_tensordict.update( + tensordict.select(*next_tensordict.keys(True, True), strict=False) + ) + return next_tensordict + def step(self, tensordict: TensorDictBase) -> TensorDictBase: """Makes a step in the environment. @@ -1953,25 +1964,33 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: """ # sanity check self._assert_tensordict_shape(tensordict) - partial_steps = None + partial_steps = tensordict.pop("_step", None) - if not self.batch_locked: - # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here - partial_steps = tensordict.get("_step", None) - if partial_steps is not None: + next_tensordict = None + + if partial_steps is not None: + if not self.batch_locked: + # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here if partial_steps.all(): partial_steps = None else: tensordict_batch_size = tensordict.batch_size partial_steps = partial_steps.view(tensordict_batch_size) tensordict = tensordict[partial_steps] - else: + else: + if not partial_steps.any(): + next_tensordict = self._skip_tensordic(tensordict) + else: + # trust that the _step can handle this! + tensordict.set("_step", partial_steps) + tensordict_batch_size = self.batch_size next_preset = tensordict.get("next", None) - next_tensordict = self._step(tensordict) - next_tensordict = self._step_proc_data(next_tensordict) + if next_tensordict is None: + next_tensordict = self._step(tensordict) + next_tensordict = self._step_proc_data(next_tensordict) if next_preset is not None: # tensordict could already have a "next" key # this could be done more efficiently by not excluding but just passing @@ -1980,7 +1999,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: next_preset.exclude(*next_tensordict.keys(True, True)) ) tensordict.set("next", next_tensordict) - if partial_steps is not None: + if partial_steps is not None and tensordict_batch_size != self.batch_size: result = tensordict.new_zeros(tensordict_batch_size) result[partial_steps] = tensordict return result diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 7ee142fe811..12d5d03508f 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -20,6 +20,7 @@ CenterCrop, ClipTransform, Compose, + ConditionalSkip, Crop, DeviceCastTransform, DiscreteActionProjection, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 883c3030ba5..395299d30f5 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -693,10 +693,19 @@ class TransformedEnv(EnvBase, metaclass=_TEnvPostInit): in which case this value should be set to `False`. Default is `True`. + Keyword Args: + auto_unwrap (bool, optional): if ``True``, wrapping a transformed env in transformed env + unwraps the transforms of the inner TransformedEnv in the outer one (the new instance). + Defaults to ``True`` + Examples: >>> env = GymEnv("Pendulum-v0") >>> transform = RewardScaling(0.0, 1.0) >>> transformed_env = TransformedEnv(env, transform) + >>> # check auto-unwrap + >>> transformed_env = TransformedEnv(transformed_env, StepCounter()) + >>> # The inner env has been unwrapped + >>> assert isinstance(transformed_env.base_env, GymEnv) """ @@ -705,6 +714,8 @@ def __init__( env: EnvBase, transform: Optional[Transform] = None, cache_specs: bool = True, + *, + auto_unwrap: bool = True, **kwargs, ): self._transform = None @@ -717,7 +728,7 @@ def __init__( # Type matching must be exact here, because subtyping could introduce differences in behavior that must # be contained within the subclass. - if type(env) is TransformedEnv and type(self) is TransformedEnv: + if type(env) is TransformedEnv and type(self) is TransformedEnv and auto_unwrap: self._set_env(env.base_env, device) if type(transform) is not Compose: # we don't use isinstance as some transforms may be subclassed from @@ -923,7 +934,36 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # tensordict = tensordict.clone(False) next_preset = tensordict.get("next", None) tensordict_in = self.transform.inv(tensordict) - next_tensordict = self.base_env._step(tensordict_in) + + # It could be that the step must be skipped + partial_steps = tensordict_in.pop("_step", None) + next_tensordict = None + tensordict_batch_size = None + if partial_steps is not None: + if not self.batch_locked: + # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here + if partial_steps.all(): + partial_steps = None + else: + tensordict_batch_size = tensordict_in.batch_size + partial_steps = partial_steps.view(tensordict_batch_size) + tensordict_in = tensordict_in[partial_steps] + else: + if not partial_steps.any(): + next_tensordict = self._skip_tensordict(tensordict_in) + elif not partial_steps.all(): + # trust that the _step can handle this! + tensordict_in.set("_step", partial_steps) + tensordict_batch_size = self.batch_size + + if next_tensordict is None: + next_tensordict = self.base_env._step(tensordict_in) + + if partial_steps is not None and tensordict_batch_size != self.batch_size: + result = next_tensordict.new_zeros(tensordict_batch_size) + result[partial_steps] = next_tensordict + next_tensordict = result + if next_preset is not None: # tensordict could already have a "next" key # this could be done more efficiently by not excluding but just passing @@ -1410,7 +1450,7 @@ def _rebuild_up_to(self, final_transform): # returns None if there is no parent env return None, None elif isinstance(container, TransformedEnv): - out = TransformedEnv(container.base_env) + out = TransformedEnv(container.base_env, auto_unwrap=False) elif container is None: # returns None if there is no parent env return None, None @@ -10213,3 +10253,87 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase: ) return (self.weights * reward).sum(dim=-1) + + +class ConditionalSkip(Transform): + """A transform that skips steps in the env if certain conditions are met. + + This transform writes the result of `cond(tensordict)` in the `"_step"` entry of the + tensordict passed as input to the `TransformedEnv.base_env._step` method. + If the `base_env` is not batch-locked (generally speaking, it is stateless), the tensordict is + reduced to its element that need to go through the step. + If it is batch-locked (generally speaking, it is stateful), the step is skipped altogether if no + value in `"_step"` is ``True``. Otherwise, it is trusted that the environment will account for the + `"_step"` signal accordingly. + + Args: + cond (Callable[[TensorDictBase], bool | torch.Tensor]): a callable for the tensordict input + that checks whether the next env step must be skipped (`True` = skipped, `False` = execute + env.step). + + Examples: + >>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv + >>> from torchrl.envs import GymEnv + >>> import torch + >>> + >>> torch.manual_seed(0) + >>> + >>> base_env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(step_count_key="other_count")) + >>> env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False) + >>> env = env.append_transform(ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1)) + >>> env.set_seed(0) + >>> + >>> r = env.rollout(10) + >>> print(r["observation"]) + tensor([[-0.9670, -0.2546, -0.9669], + [-0.9802, -0.1981, -1.1601], + [-0.9802, -0.1981, -1.1601], + [-0.9926, -0.1214, -1.5556], + [-0.9926, -0.1214, -1.5556], + [-0.9994, -0.0335, -1.7622], + [-0.9994, -0.0335, -1.7622], + [-0.9984, 0.0561, -1.7933], + [-0.9984, 0.0561, -1.7933], + [-0.9895, 0.1445, -1.7779]]) + >>> print(r["step_count"]) + tensor([[0], + [1], + [2], + [3], + [4], + [5], + [6], + [7], + [8], + [9]]) + >>> print(r["other_count"]) + tensor([[0], + [1], + [1], + [2], + [2], + [3], + [3], + [4], + [4], + [5]]) + + """ + + def __init__(self, cond: Callable[[TensorDict], bool | torch.Tensor]): + super().__init__() + self.cond = cond + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + # Run cond + cond = self.cond(tensordict) + # Write result in step + tensordict["_step"] = tensordict.get("_step", True) & ~cond + if not tensordict["_step"].shape == tensordict.batch_size: + tensordict["_step"] = tensordict["_step"].view(tensordict.batch_size) + return tensordict + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + raise NotImplementedError( + FORWARD_NOT_IMPLEMENTED.format(self.__class__.__name__) + ) From 6970cb632563fb0a565c97ebd1bd82a7e107777f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 10 Feb 2025 17:39:53 +0000 Subject: [PATCH 02/14] Update [ghstack-poisoned] --- test/test_env.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 3e4bb0febca..b9672c0229c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4119,9 +4119,9 @@ def test_parallel_partial_steps( td.set("action", penv.full_action_spec[penv.action_key].one()) td = penv.step(td) - assert (td[0].get("next") == 0).all() + assert_allclose_td(td[0].get("next"), td[0], intersection=True) assert (td[1].get("next") != 0).any() - assert (td[2].get("next") == 0).all() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) assert (td[3].get("next") != 0).any() @pytest.mark.parametrize("use_buffers", [False, True]) @@ -4142,9 +4142,9 @@ def test_parallel_partial_step_and_maybe_reset( td.set("action", penv.full_action_spec[penv.action_key].one()) td, tdreset = penv.step_and_maybe_reset(td) - assert (td[0].get("next") == 0).all() + assert_allclose_td(td[0].get("next"), td[0], intersection=True) assert (td[1].get("next") != 0).any() - assert (td[2].get("next") == 0).all() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) assert (td[3].get("next") != 0).any() @pytest.mark.parametrize("use_buffers", [False, True]) @@ -4163,9 +4163,9 @@ def test_serial_partial_steps(self, use_buffers, device, env_device): td.set("action", penv.full_action_spec[penv.action_key].one()) td = penv.step(td) - assert (td[0].get("next") == 0).all() + assert_allclose_td(td[0].get("next"), td[0], intersection=True) assert (td[1].get("next") != 0).any() - assert (td[2].get("next") == 0).all() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) assert (td[3].get("next") != 0).any() @pytest.mark.parametrize("use_buffers", [False, True]) @@ -4184,9 +4184,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi td.set("action", penv.full_action_spec[penv.action_key].one()) td = penv.step(td) - assert (td[0].get("next") == 0).all() + assert_allclose_td(td[0].get("next"), td[0], intersection=True) assert (td[1].get("next") != 0).any() - assert (td[2].get("next") == 0).all() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) assert (td[3].get("next") != 0).any() From 7e600882e51291004172671e162b5130cf3c92e1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 10 Feb 2025 18:59:14 +0000 Subject: [PATCH 03/14] Update [ghstack-poisoned] --- .github/workflows/test-linux.yml | 32 ------------- docs/source/reference/envs.rst | 1 + test/test_env.py | 78 ++++++++++++++++++-------------- torchrl/envs/batched_envs.py | 6 ++- 4 files changed, 50 insertions(+), 67 deletions(-) diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index 1565e49707e..7b3c6b9dcb0 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -53,38 +53,6 @@ jobs: ## setup_env.sh bash .github/unittest/linux/scripts/run_all.sh - tests-cpu-oldget: - # Tests that TD_GET_DEFAULTS_TO_NONE=0 works fine as this will be the default for TD up to 0.7 - strategy: - matrix: - python_version: ["3.12"] - fail-fast: false - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - with: - runner: linux.12xlarge - repository: pytorch/rl - docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04" - timeout: 90 - script: | - if [[ "${{ github.ref }}" =~ release/* ]]; then - export RELEASE=1 - export TORCH_VERSION=stable - else - export RELEASE=0 - export TORCH_VERSION=nightly - fi - export TD_GET_DEFAULTS_TO_NONE=0 - - # Set env vars from matrix - export PYTHON_VERSION=${{ matrix.python_version }} - export CU_VERSION="cpu" - - echo "PYTHON_VERSION: $PYTHON_VERSION" - echo "CU_VERSION: $CU_VERSION" - - ## setup_env.sh - bash .github/unittest/linux/scripts/run_all.sh - tests-gpu: strategy: matrix: diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 8b64b87f8bd..d37664a7c4d 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -885,6 +885,7 @@ to be able to create this other composition: CenterCrop ClipTransform Compose + ConditionalSkip Crop DTypeCastTransform DeviceCastTransform diff --git a/test/test_env.py b/test/test_env.py index b9672c0229c..614141e5af9 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4112,17 +4112,21 @@ def test_parallel_partial_steps( use_buffers=use_buffers, device=device, ) - td = penv.reset() - psteps = torch.zeros(4, dtype=torch.bool) - psteps[[1, 3]] = True - td.set("_step", psteps) - - td.set("action", penv.full_action_spec[penv.action_key].one()) - td = penv.step(td) - assert_allclose_td(td[0].get("next"), td[0], intersection=True) - assert (td[1].get("next") != 0).any() - assert_allclose_td(td[2].get("next"), td[2], intersection=True) - assert (td[3].get("next") != 0).any() + try: + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.full_action_spec[penv.action_key].one()) + td = penv.step(td) + assert_allclose_td(td[0].get("next"), td[0], intersection=True) + assert (td[1].get("next") != 0).any() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) + assert (td[3].get("next") != 0).any() + finally: + penv.close() + del penv @pytest.mark.parametrize("use_buffers", [False, True]) def test_parallel_partial_step_and_maybe_reset( @@ -4135,17 +4139,21 @@ def test_parallel_partial_step_and_maybe_reset( use_buffers=use_buffers, device=device, ) - td = penv.reset() - psteps = torch.zeros(4, dtype=torch.bool) - psteps[[1, 3]] = True - td.set("_step", psteps) - - td.set("action", penv.full_action_spec[penv.action_key].one()) - td, tdreset = penv.step_and_maybe_reset(td) - assert_allclose_td(td[0].get("next"), td[0], intersection=True) - assert (td[1].get("next") != 0).any() - assert_allclose_td(td[2].get("next"), td[2], intersection=True) - assert (td[3].get("next") != 0).any() + try: + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.full_action_spec[penv.action_key].one()) + td, tdreset = penv.step_and_maybe_reset(td) + assert_allclose_td(td[0].get("next"), td[0], intersection=True) + assert (td[1].get("next") != 0).any() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) + assert (td[3].get("next") != 0).any() + finally: + penv.close() + del penv @pytest.mark.parametrize("use_buffers", [False, True]) def test_serial_partial_steps(self, use_buffers, device, env_device): @@ -4156,17 +4164,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device): use_buffers=use_buffers, device=device, ) - td = penv.reset() - psteps = torch.zeros(4, dtype=torch.bool) - psteps[[1, 3]] = True - td.set("_step", psteps) - - td.set("action", penv.full_action_spec[penv.action_key].one()) - td = penv.step(td) - assert_allclose_td(td[0].get("next"), td[0], intersection=True) - assert (td[1].get("next") != 0).any() - assert_allclose_td(td[2].get("next"), td[2], intersection=True) - assert (td[3].get("next") != 0).any() + try: + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_step", psteps) + + td.set("action", penv.full_action_spec[penv.action_key].one()) + td = penv.step(td) + assert_allclose_td(td[0].get("next"), td[0], intersection=True) + assert (td[1].get("next") != 0).any() + assert_allclose_td(td[2].get("next"), td[2], intersection=True) + assert (td[3].get("next") != 0).any() + finally: + penv.close() + del penv @pytest.mark.parametrize("use_buffers", [False, True]) def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 47163928ca0..8f7531a86d0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1669,9 +1669,11 @@ def step_and_maybe_reset( result = tensordict.new_zeros(tensordict_save.shape) result_ = tensordict_.new_zeros(tensordict_save.shape) - result.update( - tensordict_save.select(*result.keys(True, True), strict=False).clone() + old_r_copy = tensordict_save.clone().set( + "next", + tensordict_save.select(*next_td.keys(True, True), strict=False).clone(), ) + result.update(old_r_copy) result_.update( tensordict_save.select(*result_.keys(True, True), strict=False).clone() ) From 388c2dfade2f695ed8891ae3dc362174ebd77920 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 10 Feb 2025 21:46:56 +0000 Subject: [PATCH 04/14] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 8f7531a86d0..10b0020e82c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1669,14 +1669,30 @@ def step_and_maybe_reset( result = tensordict.new_zeros(tensordict_save.shape) result_ = tensordict_.new_zeros(tensordict_save.shape) - old_r_copy = tensordict_save.clone().set( + def select_and_transfer(x, y): + if y is not None: + return ( + x.to(y.device, non_blocking=self.non_blocking) + if x.device != y.device + else x.clone() + ) + + old_r_copy = tensordict_save._fast_apply( + select_and_transfer, result, filter_empty=True, device=device + ) + old_r_copy.set( "next", - tensordict_save.select(*next_td.keys(True, True), strict=False).clone(), + tensordict_save._fast_apply( + select_and_transfer, next_td, filter_empty=True, device=device + ), ) result.update(old_r_copy) result_.update( - tensordict_save.select(*result_.keys(True, True), strict=False).clone() + tensordict_save._fast_apply( + select_and_transfer, result_, filter_empty=True, device=device + ) ) + self._sync_w2m() result[partial_steps] = tensordict result_[partial_steps] = tensordict_ From 16d647bad3d4bc3293b72cf8c1ccb35cb6346fb2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 10 Feb 2025 21:52:29 +0000 Subject: [PATCH 05/14] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 10b0020e82c..c9ca1749460 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1678,18 +1678,30 @@ def select_and_transfer(x, y): ) old_r_copy = tensordict_save._fast_apply( - select_and_transfer, result, filter_empty=True, device=device + select_and_transfer, + result, + filter_empty=True, + device=device, + default=None, ) old_r_copy.set( "next", tensordict_save._fast_apply( - select_and_transfer, next_td, filter_empty=True, device=device + select_and_transfer, + next_td, + filter_empty=True, + device=device, + default=None, ), ) result.update(old_r_copy) result_.update( tensordict_save._fast_apply( - select_and_transfer, result_, filter_empty=True, device=device + select_and_transfer, + result_, + filter_empty=True, + device=device, + default=None, ) ) self._sync_w2m() From 44ec09b2b64c83c3f93e9320c25778a953961561 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 11 Feb 2025 12:43:14 +0000 Subject: [PATCH 06/14] Update [ghstack-poisoned] --- test/test_transforms.py | 41 ++++++++------ torchrl/envs/transforms/transforms.py | 78 +++++++++++++++++++-------- 2 files changed, 80 insertions(+), 39 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 5d7fa5d325d..d0f6f7f20a8 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -13456,11 +13456,14 @@ class TestConditionalSkip(TransformBase): @pytest.mark.parametrize("bwad", [False, True]) def test_single_trans_env_check(self, bwad): env = CountingEnv() - base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) - env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False) - env = env.append_transform( - ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1) + base_env = TransformedEnv( + env, + Compose( + StepCounter(step_count_key="other_count"), + ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1), + ), ) + env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False) env.set_seed(0) env.check_env_specs() policy = lambda td: td.set("action", torch.ones((1,))) @@ -13472,13 +13475,16 @@ def test_single_trans_env_check(self, bwad): def test_serial_trans_env_check(self, bwad): def make_env(i): env = CountingEnv() - base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) - return TransformedEnv( - base_env, + base_env = TransformedEnv( + env, Compose( - StepCounter(), + StepCounter(step_count_key="other_count"), ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)), ), + ) + return TransformedEnv( + base_env, + StepCounter(), auto_unwrap=False, ) @@ -13494,13 +13500,16 @@ def make_env(i): def test_parallel_trans_env_check(self, bwad): def make_env(i): env = CountingEnv() - base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) - return TransformedEnv( - base_env, + base_env = TransformedEnv( + env, Compose( - StepCounter(), + StepCounter(step_count_key="other_count"), ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)), ), + ) + return TransformedEnv( + base_env, + StepCounter(), auto_unwrap=False, ) @@ -13533,7 +13542,8 @@ def cond(td): sc = td["step_count"] + torch.tensor([[0], [1]]) return sc.squeeze() % 2 == 0 - env = TransformedEnv(base_env, Compose(StepCounter(), ConditionalSkip(cond))) + env = TransformedEnv(base_env, ConditionalSkip(cond)) + env = TransformedEnv(env, StepCounter(), auto_unwrap=False) env.check_env_specs() policy = lambda td: td.set("action", torch.ones((2, 1))) r = env.rollout(10, policy, break_when_any_done=bwad) @@ -13558,9 +13568,8 @@ def cond(td): sc = td["step_count"] + torch.tensor([[0], [1]]) return sc.squeeze() % 2 == 0 - env = TransformedEnv( - base_env, Compose(StepCounter(), ConditionalSkip(cond)) - ) + env = TransformedEnv(base_env, ConditionalSkip(cond)) + env = TransformedEnv(env, StepCounter(), auto_unwrap=False) env.check_env_specs() policy = lambda td: td.set("action", torch.ones((2, 1))) r = env.rollout(10, policy, break_when_any_done=bwad) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 395299d30f5..a361bc17953 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -635,7 +635,7 @@ def parent(self) -> Optional[EnvBase]: ) parent, _ = container._rebuild_up_to(self) elif isinstance(container, TransformedEnv): - parent = TransformedEnv(container.base_env) + parent = TransformedEnv(container.base_env, auto_unwrap=False) else: raise ValueError(f"container is of type {type(container)}") self.__dict__["_parent"] = parent @@ -958,22 +958,22 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if next_tensordict is None: next_tensordict = self.base_env._step(tensordict_in) + if next_preset is not None: + # tensordict could already have a "next" key + # this could be done more efficiently by not excluding but just passing + # the necessary keys + next_tensordict.update( + next_preset.exclude(*next_tensordict.keys(True, True)) + ) + self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict) + # we want the input entries to remain unchanged + next_tensordict = self.transform._step(tensordict, next_tensordict) if partial_steps is not None and tensordict_batch_size != self.batch_size: result = next_tensordict.new_zeros(tensordict_batch_size) result[partial_steps] = next_tensordict next_tensordict = result - if next_preset is not None: - # tensordict could already have a "next" key - # this could be done more efficiently by not excluding but just passing - # the necessary keys - next_tensordict.update( - next_preset.exclude(*next_tensordict.keys(True, True)) - ) - self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict) - # we want the input entries to remain unchanged - next_tensordict = self.transform._step(tensordict, next_tensordict) return next_tensordict def set_seed( @@ -9079,6 +9079,7 @@ class _CallableTransform(Transform): # A wrapper around a custom callable to make it possible to transform any data type def __init__(self, func): super().__init__() + raise RuntimeError(isinstance(func, Transform), func) self.func = func def forward(self, *args, **kwargs): @@ -10266,21 +10267,40 @@ class ConditionalSkip(Transform): value in `"_step"` is ``True``. Otherwise, it is trusted that the environment will account for the `"_step"` signal accordingly. + .. note:: The skip will affect transforms that modify the environment output too, i.e., any transform + that is to be exectued on the tensordict returned by :meth:`~torchrl.envs.EnvBase.step` will be + skipped if the condition is met. To palliate this effect if it is not desirable, one can wrap + the transformed env in another transformed env, since the skip only affects the first-degree parent + of the ``ConditionalSkip`` transform. See example below. + Args: cond (Callable[[TensorDictBase], bool | torch.Tensor]): a callable for the tensordict input that checks whether the next env step must be skipped (`True` = skipped, `False` = execute env.step). Examples: - >>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv - >>> from torchrl.envs import GymEnv >>> import torch >>> + >>> from torchrl.envs import GymEnv + >>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv, Compose + >>> >>> torch.manual_seed(0) >>> - >>> base_env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(step_count_key="other_count")) - >>> env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False) - >>> env = env.append_transform(ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1)) + >>> base_env = TransformedEnv( + ... GymEnv("Pendulum-v1"), + ... StepCounter(step_count_key="inner_count"), + ... ) + >>> middle_env = TransformedEnv( + ... base_env, + ... Compose( + ... StepCounter(step_count_key="middle_count"), + ... ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1), + ... ), + ... auto_unwrap=False) # makes sure that transformed envs are properly wrapped + >>> env = TransformedEnv( + ... middle_env, + ... StepCounter(step_count_key="step_count"), + ... auto_unwrap=False) >>> env.set_seed(0) >>> >>> r = env.rollout(10) @@ -10295,18 +10315,18 @@ class ConditionalSkip(Transform): [-0.9984, 0.0561, -1.7933], [-0.9984, 0.0561, -1.7933], [-0.9895, 0.1445, -1.7779]]) - >>> print(r["step_count"]) + >>> print(r["inner_count"]) tensor([[0], [1], + [1], + [2], [2], [3], + [3], [4], - [5], - [6], - [7], - [8], - [9]]) - >>> print(r["other_count"]) + [4], + [5]]) + >>> print(r["middle_count"]) tensor([[0], [1], [1], @@ -10317,6 +10337,18 @@ class ConditionalSkip(Transform): [4], [4], [5]]) + >>> print(r["step_count"]) + tensor([[0], + [1], + [2], + [3], + [4], + [5], + [6], + [7], + [8], + [9]]) + """ From 786cb23f41cf155b9b2ae5ea0c8ef570081db3bc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 12 Feb 2025 15:35:22 +0000 Subject: [PATCH 07/14] Update [ghstack-poisoned] --- docs/source/reference/envs.rst | 76 +++++++++++++++ test/test_transforms.py | 60 +++++++++++- torchrl/envs/batched_envs.py | 129 +++++++++++++++++++++----- torchrl/envs/common.py | 46 ++++++++- torchrl/envs/transforms/transforms.py | 29 +++++- 5 files changed, 307 insertions(+), 33 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index d37664a7c4d..589130955e3 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -163,6 +163,81 @@ provides more information on how to design a custom environment from scratch. GymLikeEnv EnvMetaData +Partial steps and partial resets +-------------------------------- + +TorchRL allows environments to reset some but not all the environments, or run a step in one but not all environments. +If there is only one environment in the batch, then a partial reset / step is also allowed with the behavior detailed +below. + +Batching environments and locking the batch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. _ref_batch_locked: + +Before detailing what partial resets and partial steps do, we must distinguish cases where an environment has +a batch size of its own (mostly stateful environments) or when the environment is just a mere module that, given an +input of arbitrary size, batches the operations over all elements (mostly stateless environments). + +This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input +tensordicts to have the same batch-size as the env's. Typical examples of these environments are +:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size. +Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`. + +Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the +tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous +input. + +Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with +partial steps easily, they just pass the actions to the sub-environments that are required to be executed. + +In all other cases, TorchRL assumes that the environment handles the partial steps correctly. + +.. warning:: This means that custom environments may silently run the non-required steps as there is no way for torchrl + to control what happens within the `_step` method! + +Partial Steps +~~~~~~~~~~~~~ + +.. _ref_partial_steps: + +Partial steps are controlled via the temporary key `"_step"` which points to a boolean mask of the +size of the tensordict that holds it. The classes armed to deal with this are: + +- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the + action to and only to the environments where `"_step"` is `True`; +- Batch-unlocked environments; +- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step` + method will first look for a `"_step"` entry and, if present, act accordingly. + If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by + :class:`~torchrl.envs.TransformedEnv`'s own `_step` method which will skip the `base_env.step` as well as any further + transformation. + +When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous +content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that +if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for +all the non-stepped elements. Within batched environments, data collectors and rollouts utils, this is an issue that +is not observed because these classes handle the passing of data properly. + +Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`, +as the environments with a `True` done state will need to be skipped during calls to `_step`. + +The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips. + +Partial Resets +~~~~~~~~~~~~~~ + +.. _ref_partial_resets: + +Partial resets work pretty much like partial steps, but with the `"_reset"` entry. + +The same restrictions of partial steps apply to partial resets. + +Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `True`, +as the environments with a `True` done state will need to be reset, but not others. + +See te following paragraph for a deep dive in partial resets within batched and vectorized environments. + Vectorized envs --------------- @@ -212,6 +287,7 @@ component (sub-environments or agents) should be reset. This allows to reset some but not all of the components. The ``"_reset"`` key has two distinct functionalities: + 1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may not be present in the input tensordict. TorchRL's convention is that the absence of the ``"_reset"`` key at a given ``"done"`` level indicates diff --git a/test/test_transforms.py b/test/test_transforms.py index d0f6f7f20a8..de59fee8c4b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import abc import argparse +import collections import contextlib import importlib.util @@ -13453,9 +13454,53 @@ def test_composite_reward_spec(self) -> None: class TestConditionalSkip(TransformBase): + def check_non_tensor_match(self, td): + q = collections.deque() + obs_str = td["obs_str"] + obs = td["observation"] + q.extend(list(zip(obs_str, obs.unbind(0)))) + next_obs_str = td["next", "obs_str"] + next_obs = td["next", "observation"] + q.extend(zip(next_obs_str, next_obs.unbind(0))) + while len(q): + o_str, o = q.popleft() + if isinstance(o_str, list): + q.extend(zip(o_str, o.unbind(0))) + else: + assert o_str == str(o), (obs, obs_str, next_obs, next_obs_str) + + class ToString(Transform): + def _apply_transform(self, obs: torch.Tensor) -> None: + return NonTensorData(str(obs), device=obs.device) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + return self._call(tensordict_reset) + + def transform_observation_spec( + self, observation_spec: TensorSpec + ) -> TensorSpec: + observation_spec["obs_str"] = NonTensor( + example_data="a string!", + shape=observation_spec.shape, + device=observation_spec.device, + ) + return observation_spec + + class CountinEnvWithString(TransformedEnv): + def __init__(self, *args, **kwargs): + base_env = CountingEnv() + super().__init__( + base_env, + TestConditionalSkip.ToString( + in_keys=["observation"], out_keys=["obs_str"] + ), + ) + @pytest.mark.parametrize("bwad", [False, True]) def test_single_trans_env_check(self, bwad): - env = CountingEnv() + env = TestConditionalSkip.CountinEnvWithString() base_env = TransformedEnv( env, Compose( @@ -13470,11 +13515,12 @@ def test_single_trans_env_check(self, bwad): r = env.rollout(10, policy, break_when_any_done=bwad) assert (r["step_count"] == torch.arange(10).view(10, 1)).all() assert (r["other_count"] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) @pytest.mark.parametrize("bwad", [False, True]) def test_serial_trans_env_check(self, bwad): def make_env(i): - env = CountingEnv() + env = TestConditionalSkip.CountinEnvWithString() base_env = TransformedEnv( env, Compose( @@ -13495,11 +13541,12 @@ def make_env(i): assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all() assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) @pytest.mark.parametrize("bwad", [False, True]) def test_parallel_trans_env_check(self, bwad): def make_env(i): - env = CountingEnv() + env = TestConditionalSkip.CountinEnvWithString() base_env = TransformedEnv( env, Compose( @@ -13525,6 +13572,7 @@ def make_env(i): ).all() assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) finally: env.close() del env @@ -13532,7 +13580,7 @@ def make_env(i): @pytest.mark.parametrize("bwad", [False, True]) def test_trans_serial_env_check(self, bwad): def make_env(): - env = CountingEnv(max_steps=100) + env = TestConditionalSkip.CountinEnvWithString(max_steps=100) base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) return base_env @@ -13550,12 +13598,13 @@ def cond(td): assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all() assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) @pytest.mark.parametrize("bwad", [True, False]) @pytest.mark.parametrize("buffers", [True, False]) def test_trans_parallel_env_check(self, bwad, buffers): def make_env(): - env = CountingEnv(max_steps=100) + env = TestConditionalSkip.CountinEnvWithString(max_steps=100) base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) return base_env @@ -13578,6 +13627,7 @@ def cond(td): ).all() assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() + self.check_non_tensor_match(r) finally: base_env.close() del base_env diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index c9ca1749460..b060eb1b096 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -29,6 +29,7 @@ TensorDictBase, unravel_key, ) +from tensordict.base import _is_leaf_nontensor from tensordict.utils import _zip_strict from torch import multiprocessing as mp from torchrl._utils import ( @@ -1162,13 +1163,32 @@ def select_and_clone(name, tensor): out_tds.append(out_td) out = LazyStackedTensorDict.maybe_dense_stack(out_tds) - if partial_steps is not None: + if partial_steps is not None and not partial_steps.all(): result = out.new_zeros(tensordict_save.shape) # Copy the observation data from the previous step as placeholder - result.update( - tensordict_save.select(*result.keys(True, True), strict=False).clone() + + def select_and_clone(x, y): + if y is not None: + if x.device != y.device: + x = x.to(y.device) + else: + x = x.clone() + return x + + prev = tensordict_save._fast_apply( + select_and_clone, + result, + filter_empty=True, + device=result.device, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + default=None, ) - result[partial_steps] = out + + result.update(prev) + if partial_steps.any(): + result[partial_steps] = out + assert result[partial_steps]["obs_str"] == out["obs_str"] return result return out @@ -1533,10 +1553,29 @@ def _step_and_maybe_reset_no_buffers( ) if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) - result.update( - tensordict_save.select(*result.keys(True, True), strict=False).clone() + + def select_and_clone(x, y): + if y is not None: + if x.device != y.device: + x = x.to(y.device) + else: + x = x.clone() + return x + + prev = tensordict_save._fast_apply( + select_and_clone, + result, + filter_empty=True, + device=result.device, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + default=None, ) - result[partial_steps] = out + + result.update(prev) + + if partial_steps.any(): + result[partial_steps] = out return result return out @@ -1615,10 +1654,11 @@ def step_and_maybe_reset( data = [{} for _ in workers_range] if self._non_tensor_keys: - for i in workers_range: - data[i]["non_tensor_data"] = tensordict[i].select( - *self._non_tensor_keys, strict=False - ) + for i, td in zip( + workers_range, + tensordict.select(*self._non_tensor_keys, strict=False).unbind(0), + ): + data[i]["non_tensor_data"] = td self._sync_m2w() for i, _data in zip(workers_range, data): @@ -1706,8 +1746,9 @@ def select_and_transfer(x, y): ) self._sync_w2m() - result[partial_steps] = tensordict - result_[partial_steps] = tensordict_ + if partial_steps.any(): + result[partial_steps] = tensordict + result_[partial_steps] = tensordict_ return result, result_ return tensordict, tensordict_ @@ -1772,10 +1813,29 @@ def _step_no_buffers( out = out.to(self.device, non_blocking=self.non_blocking) if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) - result.update( - tensordict_save.select(*result.keys(True, True), strict=False).clone() + + def select_and_clone(x, y): + if y is not None: + if x.device != y.device: + x = x.to(y.device) + else: + x = x.clone() + return x + + prev = tensordict_save._fast_apply( + select_and_clone, + result, + filter_empty=True, + device=result.device, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + default=None, ) - result[partial_steps] = out + + result.update(prev) + + if partial_steps.any(): + result[partial_steps] = out return result return out @@ -1822,7 +1882,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: shared_tensordict_parent.update_( tensordict, - keys_to_update=list(self._env_input_keys), + # We also update the output keys because they can be implicitly used, eg + # during partial steps to fill in values + keys_to_update=list(self._env_input_keys) + list(self._env_output_keys), non_blocking=self.non_blocking, ) next_td_passthrough = tensordict.get("next", None) @@ -1868,10 +1930,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: data = [{} for _ in range(self.num_workers)] if self._non_tensor_keys: - for i in workers_range: - data[i]["non_tensor_data"] = tensordict[i].select( - *self._non_tensor_keys, strict=False - ) + for i, td in zip( + workers_range, + tensordict.select(*self._non_tensor_keys, strict=False).unbind(0), + ): + data[i]["non_tensor_data"] = td self._sync_m2w() @@ -1923,10 +1986,28 @@ def select_and_clone(name, tensor): self._sync_w2m() if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) - result.update( - tensordict_save.select(*result.keys(True, True), strict=False).clone() + + def select_and_clone(x, y): + if y is not None: + if x.device != y.device: + x = x.to(y.device) + else: + x = x.clone() + return x + + prev = tensordict_save._fast_apply( + select_and_clone, + result, + filter_empty=True, + device=result.device, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + default=None, ) - result[partial_steps] = out + + result.update(prev) + if partial_steps.any(): + result[partial_steps] = out return result return out diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 49cfe59d4f3..e8b41ae977d 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1938,9 +1938,24 @@ def _skip_tensordict(self, tensordict): next_tensordict = self.full_done_spec.zero() next_tensordict.update(self.full_observation_spec.zero()) next_tensordict.update(self.full_reward_spec.zero()) + # Copy the data from tensordict in `next` + def select_and_clone(x, y): + if y is not None: + if y.device == x.device: + return x.clone() + return x.to(y.device) + next_tensordict.update( - tensordict.select(*next_tensordict.keys(True, True), strict=False) + tensordict._fast_apply( + select_and_clone, + next_tensordict, + device=next_tensordict.device, + batch_size=next_tensordict.batch_size, + default=None, + filter_empty=True, + is_leaf=_is_leaf_nontensor, + ) ) return next_tensordict @@ -2001,7 +2016,26 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set("next", next_tensordict) if partial_steps is not None and tensordict_batch_size != self.batch_size: result = tensordict.new_zeros(tensordict_batch_size) - result[partial_steps] = tensordict + + def select_and_clone(x, y): + if y is not None: + if x.device == y.device: + return x.clone() + return x.to(y.device) + + result.update( + tensordict._fast_apply( + select_and_clone, + result, + device=result.device, + filter_empty=True, + default=None, + batch_size=result.batch_size, + is_leaf=_is_leaf_nontensor, + ) + ) + if partial_steps.any(): + result[partial_steps] = tensordict return result return tensordict @@ -2892,9 +2926,17 @@ def rollout( policy device before the policy is used. Default is ``False``. break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the done states. If ``False``, then the done environments are reset automatically. Default is ``True``. + + .. seealso:: The :ref:`Partial resets ` of the documentation gives more + information about partial resets. + break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any of the done states. If ``False``, break if at least one environment reaches any of the done states. Default is ``False``. + + .. seealso:: The :ref:`Partial steps ` of the documentation gives more + information about partial resets. + return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is `True` if the env does not have dynamic specs, otherwise `False`. tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index a361bc17953..28f23a839ba 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -44,6 +44,7 @@ unravel_key, unravel_key_list, ) +from tensordict.base import _is_leaf_nontensor from tensordict.nn import dispatch, TensorDictModuleBase from tensordict.utils import ( _unravel_key_to_tuple, @@ -947,13 +948,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_batch_size = tensordict_in.batch_size partial_steps = partial_steps.view(tensordict_batch_size) + tensordict_save = tensordict_in[~partial_steps] tensordict_in = tensordict_in[partial_steps] else: if not partial_steps.any(): next_tensordict = self._skip_tensordict(tensordict_in) + # No need to copy anything + partial_steps = None elif not partial_steps.all(): # trust that the _step can handle this! tensordict_in.set("_step", partial_steps) + # The filling should be handled by the sub-env + partial_steps = None + else: + partial_steps = None tensordict_batch_size = self.batch_size if next_tensordict is None: @@ -969,9 +977,26 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # we want the input entries to remain unchanged next_tensordict = self.transform._step(tensordict, next_tensordict) - if partial_steps is not None and tensordict_batch_size != self.batch_size: + if partial_steps is not None: result = next_tensordict.new_zeros(tensordict_batch_size) - result[partial_steps] = next_tensordict + + def select_and_clone(x, y): + if y is not None: + if x.device == y.device: + return x.clone() + return x.to(y.device) + + if not partial_steps.all(): + result[~partial_steps] = tensordict_save._fast_apply( + select_and_clone, + result, + device=result.device, + filter_empty=True, + default=None, + is_leaf=_is_leaf_nontensor, + ) + if partial_steps.any(): + result[partial_steps] = next_tensordict next_tensordict = result return next_tensordict From f65968089495c5c61e24ce5d6db111c93d33ad59 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 12 Feb 2025 16:06:57 +0000 Subject: [PATCH 08/14] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b060eb1b096..77c86ac9013 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1188,7 +1188,6 @@ def select_and_clone(x, y): result.update(prev) if partial_steps.any(): result[partial_steps] = out - assert result[partial_steps]["obs_str"] == out["obs_str"] return result return out From fb06a42d9943f01b2d72deee54bf77478775d41c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 12 Feb 2025 17:21:07 +0000 Subject: [PATCH 09/14] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 77c86ac9013..f916ab11acb 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1875,6 +1875,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict[partial_steps].to( self.shared_tensordict_parent.device ) + shared_tensordict_parent = shared_tensordict_parent[partial_steps] else: workers_range = range(self.num_workers) shared_tensordict_parent = self.shared_tensordict_parent @@ -1883,7 +1884,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict, # We also update the output keys because they can be implicitly used, eg # during partial steps to fill in values - keys_to_update=list(self._env_input_keys) + list(self._env_output_keys), + keys_to_update=list(self._env_input_keys), non_blocking=self.non_blocking, ) next_td_passthrough = tensordict.get("next", None) @@ -2415,13 +2416,18 @@ def look_for_cuda(tensor, has_cuda=has_cuda): if non_tensor_data is not None: input.update(non_tensor_data) - next_td = env._step(input) + input = env.step(input.copy()) + next_td = input.get("next") next_shared_tensordict.update_(next_td, non_blocking=non_blocking) + if event is not None: event.record() event.synchronize() mp_event.set() + # Make sure the root is updated + root_shared_tensordict.update_(env._step_mdp(input)) + if _non_tensor_keys: child_pipe.send( ("non_tensor", next_td.select(*_non_tensor_keys, strict=False)) From 239616f6e46f647a41d203ce3c127ce0f6b72455 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 12 Feb 2025 17:51:24 +0000 Subject: [PATCH 10/14] Update [ghstack-poisoned] --- torchrl/envs/transforms/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 28f23a839ba..ae5a63c3e59 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -948,7 +948,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_batch_size = tensordict_in.batch_size partial_steps = partial_steps.view(tensordict_batch_size) - tensordict_save = tensordict_in[~partial_steps] + tensordict_in_save = tensordict_in[~partial_steps] tensordict_in = tensordict_in[partial_steps] else: if not partial_steps.any(): @@ -987,7 +987,7 @@ def select_and_clone(x, y): return x.to(y.device) if not partial_steps.all(): - result[~partial_steps] = tensordict_save._fast_apply( + result[~partial_steps] = tensordict_in_save._fast_apply( select_and_clone, result, device=result.device, From 2aa759ab860e7f1eac249e0c3074cef8c7055596 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 12 Feb 2025 21:21:45 +0000 Subject: [PATCH 11/14] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 1 - torchrl/envs/transforms/transforms.py | 1 - 2 files changed, 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f916ab11acb..004e6e94a73 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1875,7 +1875,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict[partial_steps].to( self.shared_tensordict_parent.device ) - shared_tensordict_parent = shared_tensordict_parent[partial_steps] else: workers_range = range(self.num_workers) shared_tensordict_parent = self.shared_tensordict_parent diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ae5a63c3e59..d710035a27d 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -9104,7 +9104,6 @@ class _CallableTransform(Transform): # A wrapper around a custom callable to make it possible to transform any data type def __init__(self, func): super().__init__() - raise RuntimeError(isinstance(func, Transform), func) self.func = func def forward(self, *args, **kwargs): From 25970415db93d77667157dfda266ce39aec4d93b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 13 Feb 2025 15:58:43 +0000 Subject: [PATCH 12/14] Update [ghstack-poisoned] --- test/mocking_classes.py | 9 ++++++++- test/test_env.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index a1515390eb6..d933cb72216 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -538,6 +538,8 @@ def _step( class ContinuousActionVecMockEnv(_MockEnv): + adapt_dtype: bool = True + @classmethod def __new__( cls, @@ -636,7 +638,12 @@ def _step( done = done.any(-1) done = reward = done.unsqueeze(-1) tensordict.set( - "reward", reward.to(self.reward_spec.dtype).expand(self.reward_spec.shape) + "reward", + reward.to( + self.reward_spec.dtype + if self.adapt_dtype + else torch.get_default_dtype() + ).expand(self.reward_spec.shape), ) tensordict.set("done", done) tensordict.set("terminated", done) diff --git a/test/test_env.py b/test/test_env.py index 614141e5af9..71482675450 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -235,6 +235,7 @@ class TestEnvBase: def test_run_type_checks(self): env = ContinuousActionVecMockEnv() + env.adapt_dtype = False env._run_type_checks = False check_env_specs(env) env._run_type_checks = True From 6e7adf858bad76fd401e6a0f79df647324434e13 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 13 Feb 2025 16:09:30 +0000 Subject: [PATCH 13/14] Update [ghstack-poisoned] --- test/test_env.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_env.py b/test/test_env.py index 71482675450..fd5e0521979 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4142,12 +4142,13 @@ def test_parallel_partial_step_and_maybe_reset( ) try: td = penv.reset() - psteps = torch.zeros(4, dtype=torch.bool) + psteps = torch.zeros(4, dtype=torch.bool, device=td.get("done").device) psteps[[1, 3]] = True td.set("_step", psteps) td.set("action", penv.full_action_spec[penv.action_key].one()) td, tdreset = penv.step_and_maybe_reset(td) + print(td) assert_allclose_td(td[0].get("next"), td[0], intersection=True) assert (td[1].get("next") != 0).any() assert_allclose_td(td[2].get("next"), td[2], intersection=True) From a2e1fc467043d487cc2d55338b9cff0aa6ba0a47 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 13 Feb 2025 16:31:18 +0000 Subject: [PATCH 14/14] Update [ghstack-poisoned] --- test/test_env.py | 1 - torchrl/envs/batched_envs.py | 8 ++++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/test_env.py b/test/test_env.py index fd5e0521979..ad02467d6ab 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4148,7 +4148,6 @@ def test_parallel_partial_step_and_maybe_reset( td.set("action", penv.full_action_spec[penv.action_key].one()) td, tdreset = penv.step_and_maybe_reset(td) - print(td) assert_allclose_td(td[0].get("next"), td[0], intersection=True) assert (td[1].get("next") != 0).any() assert_allclose_td(td[2].get("next"), td[2], intersection=True) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 81e140dc8eb..7bc9b0c1a5a 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1691,6 +1691,14 @@ def step_and_maybe_reset( device=device, filter_empty=True, ) + if tensordict.device != device: + tensordict = tensordict._fast_apply( + lambda x: x.to(device, non_blocking=self.non_blocking) + if x.device != device + else x, + device=device, + filter_empty=True, + ) self._sync_w2m() else: next_td = next_td.clone().clear_device_()