Skip to content

Commit 6ae8d43

Browse files
committed
[BugFix] Fix behavior of partial, nested dones in PEnv and TEnv
ghstack-source-id: e36d1c8 Pull-Request-resolved: #2959
1 parent 36f34da commit 6ae8d43

File tree

4 files changed

+142
-12
lines changed

4 files changed

+142
-12
lines changed

test/mocking_classes.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from tensordict import tensorclass, TensorDict, TensorDictBase
1414
from tensordict.nn import TensorDictModuleBase
1515
from tensordict.utils import expand_right, NestedKey
16-
16+
from torchrl._utils import logger as torchrl_logger
1717
from torchrl.data import (
1818
Binary,
1919
Bounded,
@@ -2533,3 +2533,58 @@ def __next__(self):
25332533
else:
25342534
tokens = tensors
25352535
return {"tokens": tokens, "attention_mask": tokens != 0}
2536+
2537+
2538+
class MockNestedResetEnv(EnvBase):
    """To test behaviour of envs with nested done states - where the root done prevails over others.

    The env exposes two nested done entries, ``("agent_1", "done")`` and
    ``("agent_2", "done")``, and optionally a root-level ``"done"`` entry.
    An internal step counter decides when the episode-ending flag is raised.

    Args:
        num_steps: number of steps after which the episode-ending done flag
            becomes ``True``.
        done_at_root: if ``True``, a root ``"done"`` entry is added to the done
            spec and carries the episode-ending flag; otherwise that flag is
            written to ``("agent_1", "done")``.
    """

    def __init__(self, num_steps: int, done_at_root: bool) -> None:
        super().__init__(device="cpu")
        self._num_steps = num_steps
        self._counter = 0
        self.done_at_root = done_at_root
        self.done_spec = Composite(
            {
                ("agent_1", "done"): Binary(1, dtype=torch.bool),
                ("agent_2", "done"): Binary(1, dtype=torch.bool),
            }
        )
        if done_at_root:
            self.full_done_spec["done"] = Binary(1, dtype=torch.bool)

    def _reset(self, tensordict: TensorDict = None, **kwargs) -> TensorDict:
        """Reset the step counter and return all-``False`` done flags.

        Args:
            tensordict: optional reset input; it is only logged.
            **kwargs: accepted for signature compatibility and ignored.

        Returns:
            A tensordict holding ``False`` for both agents' done entries and,
            when ``done_at_root`` is set, for the root ``"done"`` entry too.
        """
        torchrl_logger.info(f"Reset after {self._counter} steps!")
        if tensordict is not None:
            torchrl_logger.info(f"tensordict at reset {tensordict.to_dict()}")
        self._counter = 0
        result = TensorDict(
            {
                ("agent_1", "done"): torch.tensor([False], dtype=torch.bool),
                ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
            },
        )
        if self.done_at_root:
            result["done"] = torch.tensor([False], dtype=torch.bool)
        return result

    def _step(self, tensordict: TensorDict) -> TensorDict:
        """Advance the counter by one and emit the done flags.

        With ``done_at_root`` set, the root ``"done"`` becomes ``True`` once the
        counter reaches ``num_steps`` while ``("agent_1", "done")`` is always
        ``True`` - so the root entry, not the leaf, is the one that reflects the
        end of the episode. Without a root entry, ``("agent_1", "done")`` carries
        the counter-based flag. ``("agent_2", "done")`` is always ``False``.
        """
        self._counter += 1
        done = torch.tensor([self._counter >= self._num_steps], dtype=torch.bool)
        if self.done_at_root:
            return TensorDict(
                {
                    "done": done,
                    ("agent_1", "done"): torch.tensor([True], dtype=torch.bool),
                    ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
                },
            )
        return TensorDict(
            {
                ("agent_1", "done"): done,
                ("agent_2", "done"): torch.tensor([False], dtype=torch.bool),
            },
        )

    def _set_seed(self, seed=None):
        """No-op: the env is fully deterministic, so the seed is ignored.

        ``seed`` is accepted (and defaulted) so that a caller forwarding a seed
        to this hook does not fail with a ``TypeError``.
        """
        pass

test/test_env.py

+59
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from torchrl.envs.transforms.transforms import (
6666
AutoResetEnv,
6767
AutoResetTransform,
68+
InitTracker,
6869
Tokenizer,
6970
Transform,
7071
UnsqueezeTransform,
@@ -143,6 +144,7 @@
143144
HistoryTransform,
144145
MockBatchedLockedEnv,
145146
MockBatchedUnLockedEnv,
147+
MockNestedResetEnv,
146148
MockSerialEnv,
147149
MultiKeyCountingEnv,
148150
MultiKeyCountingEnvPolicy,
@@ -184,6 +186,7 @@
184186
HistoryTransform,
185187
MockBatchedLockedEnv,
186188
MockBatchedUnLockedEnv,
189+
MockNestedResetEnv,
187190
MockSerialEnv,
188191
MultiKeyCountingEnv,
189192
MultiKeyCountingEnvPolicy,
@@ -2925,6 +2928,62 @@ def test_nested_reset(self, nest_done, has_root_done, batch_size):
29252928
env.rollout(100)
29262929
env.rollout(100, break_when_any_done=False)
29272930

2931+
@pytest.mark.parametrize("done_at_root", [True, False])
def test_nested_partial_resets(self, maybe_fork_ParallelEnv, done_at_root):
    """Check partial resets of nested-done envs in a parallel env, with and without a transform.

    Two workers are created with episode lengths 4 and 6 and rolled out for 6
    steps through ``step_and_maybe_reset``. The bare parallel env and the same
    env wrapped with ``InitTracker`` must agree on counters and done flags.
    """

    def make_env(num_steps):
        return MockNestedResetEnv(num_steps, done_at_root)

    def make_parallel_env():
        # NOTE: we expect the env[0] to reset after 4 steps, env[1] to reset after 6 steps.
        return maybe_fork_ParallelEnv(
            2,
            create_env_fn=make_env,
            create_env_kwargs=[{"num_steps": i} for i in [4, 6]],
        )

    def manual_rollout(env: EnvBase, num_steps: int):
        # Step manually so that resets go through step_and_maybe_reset.
        steps = []
        td = env.reset()
        for _ in range(num_steps):
            td, next_td = env.step_and_maybe_reset(td)
            steps.append(td)
            td = next_td
        return TensorDict.stack(steps)

    parallel_env = make_parallel_env()
    transformed_env = TransformedEnv(
        env=make_parallel_env(),
        transform=InitTracker(),
    )

    try:
        parallel_td = manual_rollout(parallel_env, 6)
        transformed_td = manual_rollout(transformed_env, 6)

        # We expect env[0] to have been reset and executed 2 steps.
        # We expect env[1] to have just been reset (0 steps).
        assert parallel_env._counter() == [2, 0]
        assert transformed_env._counter() == [2, 0]
        if done_at_root:
            assert parallel_env._simple_done
            assert transformed_env._simple_done
            # We expect each env to have reached a done state once.
            assert parallel_td["next", "done"].sum().item() == 2
            assert transformed_td["next", "done"].sum().item() == 2
            assert_allclose_td(transformed_td, parallel_td, intersection=True)
        else:
            assert not parallel_env._simple_done
            assert not transformed_env._simple_done

            assert ("next", "done") not in parallel_td
            assert ("next", "done") not in transformed_td
            assert parallel_td["next", "agent_1", "done"].sum().item() == 2
            assert transformed_td["next", "agent_1", "done"].sum().item() == 2
            assert_allclose_td(transformed_td, parallel_td, intersection=True)
    finally:
        # Always release the workers, even when an assertion above fails.
        # An env that was never started is still closed, hence the guard.
        for env in (parallel_env, transformed_env):
            if not env.is_closed:
                env.close()
2986+
29282987

29292988
class TestHeteroEnvs:
29302989
@pytest.mark.parametrize("batch_size", [(), (32,), (1, 2)])

torchrl/envs/batched_envs.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -995,9 +995,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
995995
for elt in list_of_kwargs:
996996
elt.update(kwargs)
997997
if tensordict is not None:
998-
needs_resetting = _aggregate_end_of_traj(
999-
tensordict, reset_keys=self.reset_keys
1000-
)
998+
if "_reset" in tensordict.keys():
999+
needs_resetting = tensordict["_reset"]
1000+
else:
1001+
needs_resetting = _aggregate_end_of_traj(
1002+
tensordict, reset_keys=self.reset_keys
1003+
)
10011004
if needs_resetting.ndim > 2:
10021005
needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
10031006
if needs_resetting.ndim > 1:
@@ -2114,9 +2117,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
21142117
elt.update(kwargs)
21152118

21162119
if tensordict is not None:
2117-
needs_resetting = _aggregate_end_of_traj(
2118-
tensordict, reset_keys=self.reset_keys
2119-
)
2120+
if "_reset" in tensordict.keys():
2121+
needs_resetting = tensordict["_reset"]
2122+
else:
2123+
needs_resetting = _aggregate_end_of_traj(
2124+
tensordict, reset_keys=self.reset_keys
2125+
)
21202126
if needs_resetting.ndim > 2:
21212127
needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
21222128
if needs_resetting.ndim > 1:

torchrl/envs/common.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -2821,12 +2821,25 @@ def _reset_check_done(self, tensordict, tensordict_reset):
28212821
# we iterate over (reset_key, (done_key, truncated_key)) and check that all
28222822
# values where reset was true now have a done set to False.
28232823
# If no reset was present, all done and truncated must be False
2824+
2825+
# Once we checked a root, we don't check its leaves - so keep track of the roots. Fortunately, we sort the done
2826+
# keys in the done_keys_group from root to leaf
2827+
prefix_complete = set()
28242828
for reset_key, done_key_group in zip(self.reset_keys, self.done_keys_groups):
2829+
skip = False
2830+
if isinstance(reset_key, tuple):
2831+
for i in range(len(reset_key) - 1):
2832+
if reset_key[:i] in prefix_complete:
2833+
skip = True
2834+
break
2835+
if skip:
2836+
continue
28252837
reset_value = (
28262838
tensordict.get(reset_key, default=None)
28272839
if tensordict is not None
28282840
else None
28292841
)
2842+
prefix_complete.add(() if isinstance(reset_key, str) else reset_key[:-1])
28302843
if reset_value is not None:
28312844
for done_key in done_key_group:
28322845
done_val = tensordict_reset.get(done_key)
@@ -3580,11 +3593,8 @@ def step_and_maybe_reset(
35803593
@_cache_value
35813594
def _simple_done(self):
35823595
key_set = set(self.full_done_spec.keys())
3583-
_simple_done = key_set == {
3584-
"done",
3585-
"truncated",
3586-
"terminated",
3587-
} or key_set == {"done", "terminated"}
3596+
3597+
_simple_done = "done" in key_set and "terminated" in key_set
35883598
return _simple_done
35893599

35903600
def any_done(self, tensordict: TensorDictBase) -> bool:

0 commit comments

Comments
 (0)