Skip to content

Commit 90f37ed

Browse files
committed
[BugFix] Fix env.full_done_spec~s~
ghstack-source-id: ba0d371d10b3f46ec1172fbec639ccc4d5559659
Pull Request resolved: #2815
(cherry picked from commit f5c0666)
1 parent 09640d9 commit 90f37ed

File tree

2 files changed

+82
-56
lines changed

2 files changed

+82
-56
lines changed

Diff for: torchrl/envs/batched_envs.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -1521,15 +1521,10 @@ def _step_and_maybe_reset_no_buffers(
15211521

15221522
results = [None] * len(workers_range)
15231523

1524-
consumed_indices = []
1525-
events = set(workers_range)
1526-
while len(consumed_indices) < len(workers_range):
1527-
for i in list(events):
1528-
if self._events[i].is_set():
1529-
results[i] = self.parent_channels[i].recv()
1530-
self._events[i].clear()
1531-
consumed_indices.append(i)
1532-
events.discard(i)
1524+
self._wait_for_workers(workers_range)
1525+
1526+
for i, w in enumerate(workers_range):
1527+
results[i] = self.parent_channels[w].recv()
15331528

15341529
out_next, out_root = zip(*(future for future in results))
15351530
out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(

Diff for: torchrl/envs/common.py

+78-47
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from copy import deepcopy
1111
from functools import partial, wraps
12-
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
12+
from typing import Any, Callable, Iterator
1313

1414
import numpy as np
1515
import torch
@@ -476,7 +476,7 @@ def __init__(
476476
self,
477477
*,
478478
device: DEVICE_TYPING = None,
479-
batch_size: Optional[torch.Size] = None,
479+
batch_size: torch.Size | None = None,
480480
run_type_checks: bool = False,
481481
allow_done_after_reset: bool = False,
482482
spec_locked: bool = True,
@@ -587,10 +587,10 @@ def auto_specs_(
587587
policy: Callable[[TensorDictBase], TensorDictBase],
588588
*,
589589
tensordict: TensorDictBase | None = None,
590-
action_key: NestedKey | List[NestedKey] = "action",
591-
done_key: NestedKey | List[NestedKey] | None = None,
592-
observation_key: NestedKey | List[NestedKey] = "observation",
593-
reward_key: NestedKey | List[NestedKey] = "reward",
590+
action_key: NestedKey | list[NestedKey] = "action",
591+
done_key: NestedKey | list[NestedKey] | None = None,
592+
observation_key: NestedKey | list[NestedKey] = "observation",
593+
reward_key: NestedKey | list[NestedKey] = "reward",
594594
):
595595
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
596596
@@ -692,7 +692,7 @@ def auto_specs_(
692692
if full_action_spec is not None:
693693
self.full_action_spec = full_action_spec
694694
if full_done_spec is not None:
695-
self.full_done_specs = full_done_spec
695+
self.full_done_spec = full_done_spec
696696
if full_observation_spec is not None:
697697
self.full_observation_spec = full_observation_spec
698698
if full_reward_spec is not None:
@@ -704,8 +704,7 @@ def auto_specs_(
704704

705705
@wraps(check_env_specs_func)
706706
def check_env_specs(self, *args, **kwargs):
707-
return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
708-
kwargs["return_contiguous"] = return_contiguous
707+
kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
709708
return check_env_specs_func(self, *args, **kwargs)
710709

711710
check_env_specs.__doc__ = check_env_specs_func.__doc__
@@ -850,8 +849,7 @@ def ndim(self):
850849

851850
def append_transform(
852851
self,
853-
transform: "Transform" # noqa: F821
854-
| Callable[[TensorDictBase], TensorDictBase],
852+
transform: Transform | Callable[[TensorDictBase], TensorDictBase], # noqa: F821
855853
) -> EnvBase:
856854
"""Returns a transformed environment where the callable/transform passed is applied.
857855
@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:
995993

996994
@property
997995
@_cache_value
998-
def action_keys(self) -> List[NestedKey]:
996+
def action_keys(self) -> list[NestedKey]:
999997
"""The action keys of an environment.
1000998
1001999
By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:
10081006

10091007
@property
10101008
@_cache_value
1011-
def state_keys(self) -> List[NestedKey]:
1009+
def state_keys(self) -> list[NestedKey]:
10121010
"""The state keys of an environment.
10131011
10141012
By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
12051203
# Reward spec
12061204
@property
12071205
@_cache_value
1208-
def reward_keys(self) -> List[NestedKey]:
1206+
def reward_keys(self) -> list[NestedKey]:
12091207
"""The reward keys of an environment.
12101208
12111209
By default, there will only be one key named "reward".
@@ -1215,6 +1213,20 @@ def reward_keys(self) -> List[NestedKey]:
12151213
reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth)
12161214
return reward_keys
12171215

1216+
@property
1217+
@_cache_value
1218+
def observation_keys(self) -> list[NestedKey]:
1219+
"""The observation keys of an environment.
1220+
1221+
By default, there will only be one key named "observation".
1222+
1223+
Keys are sorted by depth in the data tree.
1224+
"""
1225+
observation_keys = sorted(
1226+
self.full_observation_spec.keys(True, True), key=_repr_by_depth
1227+
)
1228+
return observation_keys
1229+
12181230
@property
12191231
def reward_key(self):
12201232
"""The reward key of an environment.
@@ -1402,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
14021414
# done spec
14031415
@property
14041416
@_cache_value
1405-
def done_keys(self) -> List[NestedKey]:
1417+
def done_keys(self) -> list[NestedKey]:
14061418
"""The done keys of an environment.
14071419
14081420
By default, there will only be one key named "done".
@@ -2190,8 +2202,8 @@ def register_gym(
21902202
id: str,
21912203
*,
21922204
entry_point: Callable | None = None,
2193-
transform: "Transform" | None = None, # noqa: F821
2194-
info_keys: List[NestedKey] | None = None,
2205+
transform: Transform | None = None, # noqa: F821
2206+
info_keys: list[NestedKey] | None = None,
21952207
backend: str = None,
21962208
to_numpy: bool = False,
21972209
reward_threshold: float | None = None,
@@ -2380,8 +2392,8 @@ def _register_gym(
23802392
cls,
23812393
id,
23822394
entry_point: Callable | None = None,
2383-
transform: "Transform" | None = None, # noqa: F821
2384-
info_keys: List[NestedKey] | None = None,
2395+
transform: Transform | None = None, # noqa: F821
2396+
info_keys: list[NestedKey] | None = None,
23852397
to_numpy: bool = False,
23862398
reward_threshold: float | None = None,
23872399
nondeterministic: bool = False,
@@ -2422,8 +2434,8 @@ def _register_gym( # noqa: F811
24222434
cls,
24232435
id,
24242436
entry_point: Callable | None = None,
2425-
transform: "Transform" | None = None, # noqa: F821
2426-
info_keys: List[NestedKey] | None = None,
2437+
transform: Transform | None = None, # noqa: F821
2438+
info_keys: list[NestedKey] | None = None,
24272439
to_numpy: bool = False,
24282440
reward_threshold: float | None = None,
24292441
nondeterministic: bool = False,
@@ -2470,8 +2482,8 @@ def _register_gym( # noqa: F811
24702482
cls,
24712483
id,
24722484
entry_point: Callable | None = None,
2473-
transform: "Transform" | None = None, # noqa: F821
2474-
info_keys: List[NestedKey] | None = None,
2485+
transform: Transform | None = None, # noqa: F821
2486+
info_keys: list[NestedKey] | None = None,
24752487
to_numpy: bool = False,
24762488
reward_threshold: float | None = None,
24772489
nondeterministic: bool = False,
@@ -2523,8 +2535,8 @@ def _register_gym( # noqa: F811
25232535
cls,
25242536
id,
25252537
entry_point: Callable | None = None,
2526-
transform: "Transform" | None = None, # noqa: F821
2527-
info_keys: List[NestedKey] | None = None,
2538+
transform: Transform | None = None, # noqa: F821
2539+
info_keys: list[NestedKey] | None = None,
25282540
to_numpy: bool = False,
25292541
reward_threshold: float | None = None,
25302542
nondeterministic: bool = False,
@@ -2579,8 +2591,8 @@ def _register_gym( # noqa: F811
25792591
cls,
25802592
id,
25812593
entry_point: Callable | None = None,
2582-
transform: "Transform" | None = None, # noqa: F821
2583-
info_keys: List[NestedKey] | None = None,
2594+
transform: Transform | None = None, # noqa: F821
2595+
info_keys: list[NestedKey] | None = None,
25842596
to_numpy: bool = False,
25852597
reward_threshold: float | None = None,
25862598
nondeterministic: bool = False,
@@ -2637,8 +2649,8 @@ def _register_gym( # noqa: F811
26372649
cls,
26382650
id,
26392651
entry_point: Callable | None = None,
2640-
transform: "Transform" | None = None, # noqa: F821
2641-
info_keys: List[NestedKey] | None = None,
2652+
transform: Transform | None = None, # noqa: F821
2653+
info_keys: list[NestedKey] | None = None,
26422654
to_numpy: bool = False,
26432655
reward_threshold: float | None = None,
26442656
nondeterministic: bool = False,
@@ -2695,7 +2707,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
26952707

26962708
def reset(
26972709
self,
2698-
tensordict: Optional[TensorDictBase] = None,
2710+
tensordict: TensorDictBase | None = None,
26992711
**kwargs,
27002712
) -> TensorDictBase:
27012713
"""Resets the environment.
@@ -2804,8 +2816,8 @@ def numel(self) -> int:
28042816
return prod(self.batch_size)
28052817

28062818
def set_seed(
2807-
self, seed: Optional[int] = None, static_seed: bool = False
2808-
) -> Optional[int]:
2819+
self, seed: int | None = None, static_seed: bool = False
2820+
) -> int | None:
28092821
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
28102822
28112823
Args:
@@ -2826,7 +2838,7 @@ def set_seed(
28262838
return seed
28272839

28282840
@abc.abstractmethod
2829-
def _set_seed(self, seed: Optional[int]):
2841+
def _set_seed(self, seed: int | None):
28302842
raise NotImplementedError
28312843

28322844
def set_state(self):
@@ -2841,7 +2853,26 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
28412853
f"got {tensordict.batch_size} and {self.batch_size}"
28422854
)
28432855

2844-
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
2856+
def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
2857+
"""Generates all possible actions from the action spec.
2858+
2859+
This only works in environments with fully discrete actions.
2860+
2861+
Args:
2862+
tensordict (TensorDictBase, optional): If given, :meth:`~.reset`
2863+
is called with this tensordict.
2864+
2865+
Returns:
2866+
a tensordict object with the "action" entry updated with a batch of
2867+
all possible actions. The actions are stacked together in the
2868+
leading dimension.
2869+
"""
2870+
if tensordict is not None:
2871+
self.reset(tensordict)
2872+
2873+
return self.full_action_spec.enumerate(use_mask=True)
2874+
2875+
def rand_action(self, tensordict: TensorDictBase | None = None):
28452876
"""Performs a random action given the action_spec attribute.
28462877
28472878
Args:
@@ -2875,7 +2906,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
28752906
tensordict.update(r)
28762907
return tensordict
28772908

2878-
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
2909+
def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
28792910
"""Performs a random step in the environment given the action_spec attribute.
28802911
28812912
Args:
@@ -2911,15 +2942,15 @@ def _has_dynamic_specs(self) -> bool:
29112942
def rollout(
29122943
self,
29132944
max_steps: int,
2914-
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
2915-
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
2945+
policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
2946+
callback: Callable[[TensorDictBase, ...], Any] | None = None,
29162947
*,
29172948
auto_reset: bool = True,
29182949
auto_cast_to_device: bool = False,
29192950
break_when_any_done: bool | None = None,
29202951
break_when_all_done: bool | None = None,
29212952
return_contiguous: bool | None = False,
2922-
tensordict: Optional[TensorDictBase] = None,
2953+
tensordict: TensorDictBase | None = None,
29232954
set_truncated: bool = False,
29242955
out=None,
29252956
trust_policy: bool = False,
@@ -3441,7 +3472,7 @@ def _rollout_nonstop(
34413472

34423473
def step_and_maybe_reset(
34433474
self, tensordict: TensorDictBase
3444-
) -> Tuple[TensorDictBase, TensorDictBase]:
3475+
) -> tuple[TensorDictBase, TensorDictBase]:
34453476
"""Runs a step in the environment and (partially) resets it if needed.
34463477
34473478
Args:
@@ -3549,7 +3580,7 @@ def empty_cache(self):
35493580

35503581
@property
35513582
@_cache_value
3552-
def reset_keys(self) -> List[NestedKey]:
3583+
def reset_keys(self) -> list[NestedKey]:
35533584
"""Returns a list of reset keys.
35543585
35553586
Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3706,14 +3737,14 @@ class _EnvWrapper(EnvBase):
37063737
"""
37073738

37083739
git_url: str = ""
3709-
available_envs: Dict[str, Any] = {}
3740+
available_envs: dict[str, Any] = {}
37103741
libname: str = ""
37113742

37123743
def __init__(
37133744
self,
37143745
*args,
37153746
device: DEVICE_TYPING = None,
3716-
batch_size: Optional[torch.Size] = None,
3747+
batch_size: torch.Size | None = None,
37173748
allow_done_after_reset: bool = False,
37183749
spec_locked: bool = True,
37193750
**kwargs,
@@ -3762,7 +3793,7 @@ def _sync_device(self):
37623793
return sync_func
37633794

37643795
@abc.abstractmethod
3765-
def _check_kwargs(self, kwargs: Dict):
3796+
def _check_kwargs(self, kwargs: dict):
37663797
raise NotImplementedError
37673798

37683799
def __getattr__(self, attr: str) -> Any:
@@ -3788,7 +3819,7 @@ def __getattr__(self, attr: str) -> Any:
37883819
)
37893820

37903821
@abc.abstractmethod
3791-
def _init_env(self) -> Optional[int]:
3822+
def _init_env(self) -> int | None:
37923823
"""Runs all the necessary steps such that the environment is ready to use.
37933824
37943825
This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3802,7 +3833,7 @@ def _init_env(self) -> Optional[int]:
38023833
raise NotImplementedError
38033834

38043835
@abc.abstractmethod
3805-
def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
3836+
def _build_env(self, **kwargs) -> gym.Env: # noqa: F821
38063837
"""Creates an environment from the target library and stores it with the `_env` attribute.
38073838
38083839
When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3811,7 +3842,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
38113842
raise NotImplementedError
38123843

38133844
@abc.abstractmethod
3814-
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
3845+
def _make_specs(self, env: gym.Env) -> None: # noqa: F821
38153846
raise NotImplementedError
38163847

38173848
def close(self) -> None:
@@ -3825,7 +3856,7 @@ def close(self) -> None:
38253856

38263857
def make_tensordict(
38273858
env: _EnvWrapper,
3828-
policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
3859+
policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
38293860
) -> TensorDictBase:
38303861
"""Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
38313862

0 commit comments

Comments
 (0)