Skip to content

Commit f5c0666

Browse files
committed
[BugFix] Fix env.full_done_spec~s~
ghstack-source-id: ba0d371d10b3f46ec1172fbec639ccc4d5559659
Pull Request resolved: #2815
1 parent 59e8545 commit f5c0666

File tree

2 files changed

+51
-60
lines changed

2 files changed

+51
-60
lines changed

Diff for: torchrl/envs/batched_envs.py

+4 -9
Original file line number | Diff line number | Diff line change
@@ -1569,15 +1569,10 @@ def _step_and_maybe_reset_no_buffers(
15691569

15701570
results = [None] * len(workers_range)
15711571

1572-
consumed_indices = []
1573-
events = set(workers_range)
1574-
while len(consumed_indices) < len(workers_range):
1575-
for i in list(events):
1576-
if self._events[i].is_set():
1577-
results[i] = self.parent_channels[i].recv()
1578-
self._events[i].clear()
1579-
consumed_indices.append(i)
1580-
events.discard(i)
1572+
self._wait_for_workers(workers_range)
1573+
1574+
for i, w in enumerate(workers_range):
1575+
results[i] = self.parent_channels[w].recv()
15811576

15821577
out_next, out_root = zip(*(future for future in results))
15831578
out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(

Diff for: torchrl/envs/common.py

+47 -51
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from copy import deepcopy
1111
from functools import partial, wraps
12-
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
12+
from typing import Any, Callable, Iterator
1313

1414
import numpy as np
1515
import torch
@@ -476,7 +476,7 @@ def __init__(
476476
self,
477477
*,
478478
device: DEVICE_TYPING = None,
479-
batch_size: Optional[torch.Size] = None,
479+
batch_size: torch.Size | None = None,
480480
run_type_checks: bool = False,
481481
allow_done_after_reset: bool = False,
482482
spec_locked: bool = True,
@@ -587,10 +587,10 @@ def auto_specs_(
587587
policy: Callable[[TensorDictBase], TensorDictBase],
588588
*,
589589
tensordict: TensorDictBase | None = None,
590-
action_key: NestedKey | List[NestedKey] = "action",
591-
done_key: NestedKey | List[NestedKey] | None = None,
592-
observation_key: NestedKey | List[NestedKey] = "observation",
593-
reward_key: NestedKey | List[NestedKey] = "reward",
590+
action_key: NestedKey | list[NestedKey] = "action",
591+
done_key: NestedKey | list[NestedKey] | None = None,
592+
observation_key: NestedKey | list[NestedKey] = "observation",
593+
reward_key: NestedKey | list[NestedKey] = "reward",
594594
):
595595
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
596596
@@ -692,7 +692,7 @@ def auto_specs_(
692692
if full_action_spec is not None:
693693
self.full_action_spec = full_action_spec
694694
if full_done_spec is not None:
695-
self.full_done_specs = full_done_spec
695+
self.full_done_spec = full_done_spec
696696
if full_observation_spec is not None:
697697
self.full_observation_spec = full_observation_spec
698698
if full_reward_spec is not None:
@@ -704,8 +704,7 @@ def auto_specs_(
704704

705705
@wraps(check_env_specs_func)
706706
def check_env_specs(self, *args, **kwargs):
707-
return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
708-
kwargs["return_contiguous"] = return_contiguous
707+
kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
709708
return check_env_specs_func(self, *args, **kwargs)
710709

711710
check_env_specs.__doc__ = check_env_specs_func.__doc__
@@ -850,8 +849,7 @@ def ndim(self):
850849

851850
def append_transform(
852851
self,
853-
transform: "Transform" # noqa: F821
854-
| Callable[[TensorDictBase], TensorDictBase],
852+
transform: Transform | Callable[[TensorDictBase], TensorDictBase], # noqa: F821
855853
) -> EnvBase:
856854
"""Returns a transformed environment where the callable/transform passed is applied.
857855
@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:
995993

996994
@property
997995
@_cache_value
998-
def action_keys(self) -> List[NestedKey]:
996+
def action_keys(self) -> list[NestedKey]:
999997
"""The action keys of an environment.
1000998
1001999
By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:
10081006

10091007
@property
10101008
@_cache_value
1011-
def state_keys(self) -> List[NestedKey]:
1009+
def state_keys(self) -> list[NestedKey]:
10121010
"""The state keys of an environment.
10131011
10141012
By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
12051203
# Reward spec
12061204
@property
12071205
@_cache_value
1208-
def reward_keys(self) -> List[NestedKey]:
1206+
def reward_keys(self) -> list[NestedKey]:
12091207
"""The reward keys of an environment.
12101208
12111209
By default, there will only be one key named "reward".
@@ -1217,7 +1215,7 @@ def reward_keys(self) -> List[NestedKey]:
12171215

12181216
@property
12191217
@_cache_value
1220-
def observation_keys(self) -> List[NestedKey]:
1218+
def observation_keys(self) -> list[NestedKey]:
12211219
"""The observation keys of an environment.
12221220
12231221
By default, there will only be one key named "observation".
@@ -1416,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
14161414
# done spec
14171415
@property
14181416
@_cache_value
1419-
def done_keys(self) -> List[NestedKey]:
1417+
def done_keys(self) -> list[NestedKey]:
14201418
"""The done keys of an environment.
14211419
14221420
By default, there will only be one key named "done".
@@ -2205,8 +2203,8 @@ def register_gym(
22052203
id: str,
22062204
*,
22072205
entry_point: Callable | None = None,
2208-
transform: "Transform" | None = None, # noqa: F821
2209-
info_keys: List[NestedKey] | None = None,
2206+
transform: Transform | None = None, # noqa: F821
2207+
info_keys: list[NestedKey] | None = None,
22102208
backend: str = None,
22112209
to_numpy: bool = False,
22122210
reward_threshold: float | None = None,
@@ -2395,8 +2393,8 @@ def _register_gym(
23952393
cls,
23962394
id,
23972395
entry_point: Callable | None = None,
2398-
transform: "Transform" | None = None, # noqa: F821
2399-
info_keys: List[NestedKey] | None = None,
2396+
transform: Transform | None = None, # noqa: F821
2397+
info_keys: list[NestedKey] | None = None,
24002398
to_numpy: bool = False,
24012399
reward_threshold: float | None = None,
24022400
nondeterministic: bool = False,
@@ -2437,8 +2435,8 @@ def _register_gym( # noqa: F811
24372435
cls,
24382436
id,
24392437
entry_point: Callable | None = None,
2440-
transform: "Transform" | None = None, # noqa: F821
2441-
info_keys: List[NestedKey] | None = None,
2438+
transform: Transform | None = None, # noqa: F821
2439+
info_keys: list[NestedKey] | None = None,
24422440
to_numpy: bool = False,
24432441
reward_threshold: float | None = None,
24442442
nondeterministic: bool = False,
@@ -2485,8 +2483,8 @@ def _register_gym( # noqa: F811
24852483
cls,
24862484
id,
24872485
entry_point: Callable | None = None,
2488-
transform: "Transform" | None = None, # noqa: F821
2489-
info_keys: List[NestedKey] | None = None,
2486+
transform: Transform | None = None, # noqa: F821
2487+
info_keys: list[NestedKey] | None = None,
24902488
to_numpy: bool = False,
24912489
reward_threshold: float | None = None,
24922490
nondeterministic: bool = False,
@@ -2538,8 +2536,8 @@ def _register_gym( # noqa: F811
25382536
cls,
25392537
id,
25402538
entry_point: Callable | None = None,
2541-
transform: "Transform" | None = None, # noqa: F821
2542-
info_keys: List[NestedKey] | None = None,
2539+
transform: Transform | None = None, # noqa: F821
2540+
info_keys: list[NestedKey] | None = None,
25432541
to_numpy: bool = False,
25442542
reward_threshold: float | None = None,
25452543
nondeterministic: bool = False,
@@ -2594,8 +2592,8 @@ def _register_gym( # noqa: F811
25942592
cls,
25952593
id,
25962594
entry_point: Callable | None = None,
2597-
transform: "Transform" | None = None, # noqa: F821
2598-
info_keys: List[NestedKey] | None = None,
2595+
transform: Transform | None = None, # noqa: F821
2596+
info_keys: list[NestedKey] | None = None,
25992597
to_numpy: bool = False,
26002598
reward_threshold: float | None = None,
26012599
nondeterministic: bool = False,
@@ -2652,8 +2650,8 @@ def _register_gym( # noqa: F811
26522650
cls,
26532651
id,
26542652
entry_point: Callable | None = None,
2655-
transform: "Transform" | None = None, # noqa: F821
2656-
info_keys: List[NestedKey] | None = None,
2653+
transform: Transform | None = None, # noqa: F821
2654+
info_keys: list[NestedKey] | None = None,
26572655
to_numpy: bool = False,
26582656
reward_threshold: float | None = None,
26592657
nondeterministic: bool = False,
@@ -2710,7 +2708,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
27102708

27112709
def reset(
27122710
self,
2713-
tensordict: Optional[TensorDictBase] = None,
2711+
tensordict: TensorDictBase | None = None,
27142712
**kwargs,
27152713
) -> TensorDictBase:
27162714
"""Resets the environment.
@@ -2819,8 +2817,8 @@ def numel(self) -> int:
28192817
return prod(self.batch_size)
28202818

28212819
def set_seed(
2822-
self, seed: Optional[int] = None, static_seed: bool = False
2823-
) -> Optional[int]:
2820+
self, seed: int | None = None, static_seed: bool = False
2821+
) -> int | None:
28242822
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
28252823
28262824
Args:
@@ -2841,7 +2839,7 @@ def set_seed(
28412839
return seed
28422840

28432841
@abc.abstractmethod
2844-
def _set_seed(self, seed: Optional[int]):
2842+
def _set_seed(self, seed: int | None):
28452843
raise NotImplementedError
28462844

28472845
def set_state(self):
@@ -2856,9 +2854,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
28562854
f"got {tensordict.batch_size} and {self.batch_size}"
28572855
)
28582856

2859-
def all_actions(
2860-
self, tensordict: Optional[TensorDictBase] = None
2861-
) -> TensorDictBase:
2857+
def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
28622858
"""Generates all possible actions from the action spec.
28632859
28642860
This only works in environments with fully discrete actions.
@@ -2877,7 +2873,7 @@ def all_actions(
28772873

28782874
return self.full_action_spec.enumerate(use_mask=True)
28792875

2880-
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
2876+
def rand_action(self, tensordict: TensorDictBase | None = None):
28812877
"""Performs a random action given the action_spec attribute.
28822878
28832879
Args:
@@ -2911,7 +2907,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
29112907
tensordict.update(r)
29122908
return tensordict
29132909

2914-
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
2910+
def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
29152911
"""Performs a random step in the environment given the action_spec attribute.
29162912
29172913
Args:
@@ -2947,15 +2943,15 @@ def _has_dynamic_specs(self) -> bool:
29472943
def rollout(
29482944
self,
29492945
max_steps: int,
2950-
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
2951-
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
2946+
policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
2947+
callback: Callable[[TensorDictBase, ...], Any] | None = None,
29522948
*,
29532949
auto_reset: bool = True,
29542950
auto_cast_to_device: bool = False,
29552951
break_when_any_done: bool | None = None,
29562952
break_when_all_done: bool | None = None,
29572953
return_contiguous: bool | None = False,
2958-
tensordict: Optional[TensorDictBase] = None,
2954+
tensordict: TensorDictBase | None = None,
29592955
set_truncated: bool = False,
29602956
out=None,
29612957
trust_policy: bool = False,
@@ -3485,7 +3481,7 @@ def _rollout_nonstop(
34853481

34863482
def step_and_maybe_reset(
34873483
self, tensordict: TensorDictBase
3488-
) -> Tuple[TensorDictBase, TensorDictBase]:
3484+
) -> tuple[TensorDictBase, TensorDictBase]:
34893485
"""Runs a step in the environment and (partially) resets it if needed.
34903486
34913487
Args:
@@ -3606,7 +3602,7 @@ def empty_cache(self):
36063602

36073603
@property
36083604
@_cache_value
3609-
def reset_keys(self) -> List[NestedKey]:
3605+
def reset_keys(self) -> list[NestedKey]:
36103606
"""Returns a list of reset keys.
36113607
36123608
Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3763,14 +3759,14 @@ class _EnvWrapper(EnvBase):
37633759
"""
37643760

37653761
git_url: str = ""
3766-
available_envs: Dict[str, Any] = {}
3762+
available_envs: dict[str, Any] = {}
37673763
libname: str = ""
37683764

37693765
def __init__(
37703766
self,
37713767
*args,
37723768
device: DEVICE_TYPING = None,
3773-
batch_size: Optional[torch.Size] = None,
3769+
batch_size: torch.Size | None = None,
37743770
allow_done_after_reset: bool = False,
37753771
spec_locked: bool = True,
37763772
**kwargs,
@@ -3819,7 +3815,7 @@ def _sync_device(self):
38193815
return sync_func
38203816

38213817
@abc.abstractmethod
3822-
def _check_kwargs(self, kwargs: Dict):
3818+
def _check_kwargs(self, kwargs: dict):
38233819
raise NotImplementedError
38243820

38253821
def __getattr__(self, attr: str) -> Any:
@@ -3845,7 +3841,7 @@ def __getattr__(self, attr: str) -> Any:
38453841
)
38463842

38473843
@abc.abstractmethod
3848-
def _init_env(self) -> Optional[int]:
3844+
def _init_env(self) -> int | None:
38493845
"""Runs all the necessary steps such that the environment is ready to use.
38503846
38513847
This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3859,7 +3855,7 @@ def _init_env(self) -> Optional[int]:
38593855
raise NotImplementedError
38603856

38613857
@abc.abstractmethod
3862-
def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
3858+
def _build_env(self, **kwargs) -> gym.Env: # noqa: F821
38633859
"""Creates an environment from the target library and stores it with the `_env` attribute.
38643860
38653861
When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3868,7 +3864,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
38683864
raise NotImplementedError
38693865

38703866
@abc.abstractmethod
3871-
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
3867+
def _make_specs(self, env: gym.Env) -> None: # noqa: F821
38723868
raise NotImplementedError
38733869

38743870
def close(self, *, raise_if_closed: bool = True) -> None:
@@ -3882,7 +3878,7 @@ def close(self, *, raise_if_closed: bool = True) -> None:
38823878

38833879
def make_tensordict(
38843880
env: _EnvWrapper,
3885-
policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
3881+
policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
38863882
) -> TensorDictBase:
38873883
"""Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
38883884

0 commit comments

Comments
 (0)