Skip to content

[BugFix] Fix env.full_done_spec~s~ #2815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,15 +1569,10 @@ def _step_and_maybe_reset_no_buffers(

results = [None] * len(workers_range)

consumed_indices = []
events = set(workers_range)
while len(consumed_indices) < len(workers_range):
for i in list(events):
if self._events[i].is_set():
results[i] = self.parent_channels[i].recv()
self._events[i].clear()
consumed_indices.append(i)
events.discard(i)
self._wait_for_workers(workers_range)

for i, w in enumerate(workers_range):
results[i] = self.parent_channels[w].recv()

out_next, out_root = zip(*(future for future in results))
out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(
Expand Down
98 changes: 47 additions & 51 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
from copy import deepcopy
from functools import partial, wraps
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from typing import Any, Callable, Iterator

import numpy as np
import torch
Expand Down Expand Up @@ -476,7 +476,7 @@ def __init__(
self,
*,
device: DEVICE_TYPING = None,
batch_size: Optional[torch.Size] = None,
batch_size: torch.Size | None = None,
run_type_checks: bool = False,
allow_done_after_reset: bool = False,
spec_locked: bool = True,
Expand Down Expand Up @@ -587,10 +587,10 @@ def auto_specs_(
policy: Callable[[TensorDictBase], TensorDictBase],
*,
tensordict: TensorDictBase | None = None,
action_key: NestedKey | List[NestedKey] = "action",
done_key: NestedKey | List[NestedKey] | None = None,
observation_key: NestedKey | List[NestedKey] = "observation",
reward_key: NestedKey | List[NestedKey] = "reward",
action_key: NestedKey | list[NestedKey] = "action",
done_key: NestedKey | list[NestedKey] | None = None,
observation_key: NestedKey | list[NestedKey] = "observation",
reward_key: NestedKey | list[NestedKey] = "reward",
):
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.

Expand Down Expand Up @@ -692,7 +692,7 @@ def auto_specs_(
if full_action_spec is not None:
self.full_action_spec = full_action_spec
if full_done_spec is not None:
self.full_done_specs = full_done_spec
self.full_done_spec = full_done_spec
if full_observation_spec is not None:
self.full_observation_spec = full_observation_spec
if full_reward_spec is not None:
Expand All @@ -704,8 +704,7 @@ def auto_specs_(

@wraps(check_env_specs_func)
def check_env_specs(self, *args, **kwargs):
return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
kwargs["return_contiguous"] = return_contiguous
kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
return check_env_specs_func(self, *args, **kwargs)

check_env_specs.__doc__ = check_env_specs_func.__doc__
Expand Down Expand Up @@ -850,8 +849,7 @@ def ndim(self):

def append_transform(
self,
transform: "Transform" # noqa: F821
| Callable[[TensorDictBase], TensorDictBase],
transform: Transform | Callable[[TensorDictBase], TensorDictBase], # noqa: F821
) -> EnvBase:
"""Returns a transformed environment where the callable/transform passed is applied.

Expand Down Expand Up @@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:

@property
@_cache_value
def action_keys(self) -> List[NestedKey]:
def action_keys(self) -> list[NestedKey]:
"""The action keys of an environment.

By default, there will only be one key named "action".
Expand All @@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:

@property
@_cache_value
def state_keys(self) -> List[NestedKey]:
def state_keys(self) -> list[NestedKey]:
"""The state keys of an environment.

By default, there will only be one key named "state".
Expand Down Expand Up @@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
# Reward spec
@property
@_cache_value
def reward_keys(self) -> List[NestedKey]:
def reward_keys(self) -> list[NestedKey]:
"""The reward keys of an environment.

By default, there will only be one key named "reward".
Expand All @@ -1217,7 +1215,7 @@ def reward_keys(self) -> List[NestedKey]:

@property
@_cache_value
def observation_keys(self) -> List[NestedKey]:
def observation_keys(self) -> list[NestedKey]:
"""The observation keys of an environment.

By default, there will only be one key named "observation".
Expand Down Expand Up @@ -1416,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
# done spec
@property
@_cache_value
def done_keys(self) -> List[NestedKey]:
def done_keys(self) -> list[NestedKey]:
"""The done keys of an environment.

By default, there will only be one key named "done".
Expand Down Expand Up @@ -2205,8 +2203,8 @@ def register_gym(
id: str,
*,
entry_point: Callable | None = None,
transform: "Transform" | None = None, # noqa: F821
info_keys: List[NestedKey] | None = None,
transform: Transform | None = None, # noqa: F821
info_keys: list[NestedKey] | None = None,
backend: str = None,
to_numpy: bool = False,
reward_threshold: float | None = None,
Expand Down Expand Up @@ -2395,8 +2393,8 @@ def _register_gym(
cls,
id,
entry_point: Callable | None = None,
transform: "Transform" | None = None, # noqa: F821
info_keys: List[NestedKey] | None = None,
transform: Transform | None = None, # noqa: F821
info_keys: list[NestedKey] | None = None,
to_numpy: bool = False,
reward_threshold: float | None = None,
nondeterministic: bool = False,
Expand Down Expand Up @@ -2437,8 +2435,8 @@ def _register_gym( # noqa: F811
cls,
id,
entry_point: Callable | None = None,
transform: "Transform" | None = None, # noqa: F821
info_keys: List[NestedKey] | None = None,
transform: Transform | None = None, # noqa: F821
info_keys: list[NestedKey] | None = None,
to_numpy: bool = False,
reward_threshold: float | None = None,
nondeterministic: bool = False,
Expand Down Expand Up @@ -2485,8 +2483,8 @@ def _register_gym( # noqa: F811
cls,
id,
entry_point: Callable | None = None,
transform: "Transform" | None = None, # noqa: F821
info_keys: List[NestedKey] | None = None,
transform: Transform | None = None, # noqa: F821
info_keys: list[NestedKey] | None = None,
to_numpy: bool = False,
reward_threshold: float | None = None,
nondeterministic: bool = False,
Expand Down Expand Up @@ -2538,8 +2536,8 @@ def _register_gym( # noqa: F811
cls,
id,
entry_point: Callable | None = None,
transform: "Transform" | None = None, # noqa: F821
info_keys: List[NestedKey] | None = None,
transform: Transform | None = None, # noqa: F821
info_keys: list[NestedKey] | None = None,
to_numpy: bool = False,
reward_threshold: float | None = None,
nondeterministic: bool = False,
Expand Down Expand Up @@ -2594,8 +2592,8 @@ def _register_gym( # noqa: F811
cls,
id,
entry_point: Callable | None = None,
transform: "Transform" | None = None, # noqa: F821
info_keys: List[NestedKey] | None = None,
transform: Transform | None = None, # noqa: F821
info_keys: list[NestedKey] | None = None,
to_numpy: bool = False,
reward_threshold: float | None = None,
nondeterministic: bool = False,
Expand Down Expand Up @@ -2652,8 +2650,8 @@ def _register_gym( # noqa: F811
cls,
id,
entry_point: Callable | None = None,
transform: "Transform" | None = None, # noqa: F821
info_keys: List[NestedKey] | None = None,
transform: Transform | None = None, # noqa: F821
info_keys: list[NestedKey] | None = None,
to_numpy: bool = False,
reward_threshold: float | None = None,
nondeterministic: bool = False,
Expand Down Expand Up @@ -2710,7 +2708,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

def reset(
self,
tensordict: Optional[TensorDictBase] = None,
tensordict: TensorDictBase | None = None,
**kwargs,
) -> TensorDictBase:
"""Resets the environment.
Expand Down Expand Up @@ -2819,8 +2817,8 @@ def numel(self) -> int:
return prod(self.batch_size)

def set_seed(
self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
self, seed: int | None = None, static_seed: bool = False
) -> int | None:
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).

Args:
Expand All @@ -2841,7 +2839,7 @@ def set_seed(
return seed

@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
def _set_seed(self, seed: int | None):
raise NotImplementedError

def set_state(self):
Expand All @@ -2856,9 +2854,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
f"got {tensordict.batch_size} and {self.batch_size}"
)

def all_actions(
self, tensordict: Optional[TensorDictBase] = None
) -> TensorDictBase:
def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
"""Generates all possible actions from the action spec.

This only works in environments with fully discrete actions.
Expand All @@ -2877,7 +2873,7 @@ def all_actions(

return self.full_action_spec.enumerate(use_mask=True)

def rand_action(self, tensordict: Optional[TensorDictBase] = None):
def rand_action(self, tensordict: TensorDictBase | None = None):
"""Performs a random action given the action_spec attribute.

Args:
Expand Down Expand Up @@ -2911,7 +2907,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
tensordict.update(r)
return tensordict

def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
"""Performs a random step in the environment given the action_spec attribute.

Args:
Expand Down Expand Up @@ -2947,15 +2943,15 @@ def _has_dynamic_specs(self) -> bool:
def rollout(
self,
max_steps: int,
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
callback: Callable[[TensorDictBase, ...], Any] | None = None,
*,
auto_reset: bool = True,
auto_cast_to_device: bool = False,
break_when_any_done: bool | None = None,
break_when_all_done: bool | None = None,
return_contiguous: bool | None = False,
tensordict: Optional[TensorDictBase] = None,
tensordict: TensorDictBase | None = None,
set_truncated: bool = False,
out=None,
trust_policy: bool = False,
Expand Down Expand Up @@ -3485,7 +3481,7 @@ def _rollout_nonstop(

def step_and_maybe_reset(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:
) -> tuple[TensorDictBase, TensorDictBase]:
"""Runs a step in the environment and (partially) resets it if needed.

Args:
Expand Down Expand Up @@ -3606,7 +3602,7 @@ def empty_cache(self):

@property
@_cache_value
def reset_keys(self) -> List[NestedKey]:
def reset_keys(self) -> list[NestedKey]:
"""Returns a list of reset keys.

Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
Expand Down Expand Up @@ -3763,14 +3759,14 @@ class _EnvWrapper(EnvBase):
"""

git_url: str = ""
available_envs: Dict[str, Any] = {}
available_envs: dict[str, Any] = {}
libname: str = ""

def __init__(
self,
*args,
device: DEVICE_TYPING = None,
batch_size: Optional[torch.Size] = None,
batch_size: torch.Size | None = None,
allow_done_after_reset: bool = False,
spec_locked: bool = True,
**kwargs,
Expand Down Expand Up @@ -3819,7 +3815,7 @@ def _sync_device(self):
return sync_func

@abc.abstractmethod
def _check_kwargs(self, kwargs: Dict):
def _check_kwargs(self, kwargs: dict):
raise NotImplementedError

def __getattr__(self, attr: str) -> Any:
Expand All @@ -3845,7 +3841,7 @@ def __getattr__(self, attr: str) -> Any:
)

@abc.abstractmethod
def _init_env(self) -> Optional[int]:
def _init_env(self) -> int | None:
"""Runs all the necessary steps such that the environment is ready to use.

This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
Expand All @@ -3859,7 +3855,7 @@ def _init_env(self) -> Optional[int]:
raise NotImplementedError

@abc.abstractmethod
def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
def _build_env(self, **kwargs) -> gym.Env: # noqa: F821
"""Creates an environment from the target library and stores it with the `_env` attribute.

When overwritten, this function should pass all the required kwargs to the env instantiation method.
Expand All @@ -3868,7 +3864,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
raise NotImplementedError

@abc.abstractmethod
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
def _make_specs(self, env: gym.Env) -> None: # noqa: F821
raise NotImplementedError

def close(self, *, raise_if_closed: bool = True) -> None:
Expand All @@ -3882,7 +3878,7 @@ def close(self, *, raise_if_closed: bool = True) -> None:

def make_tensordict(
env: _EnvWrapper,
policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
) -> TensorDictBase:
"""Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.

Expand Down
Loading