From c20a4f1b410d052d69a549a7f1fad788877f829c Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 28 Feb 2025 14:38:16 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchrl/envs/common.py | 96 ++++++++++++++++++++----------------------
 1 file changed, 46 insertions(+), 50 deletions(-)

diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index c8027ecd316..262335d064d 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -476,7 +476,7 @@ def __init__(
         self,
         *,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         run_type_checks: bool = False,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
@@ -587,10 +587,10 @@ def auto_specs_(
         policy: Callable[[TensorDictBase], TensorDictBase],
         *,
         tensordict: TensorDictBase | None = None,
-        action_key: NestedKey | List[NestedKey] = "action",
-        done_key: NestedKey | List[NestedKey] | None = None,
-        observation_key: NestedKey | List[NestedKey] = "observation",
-        reward_key: NestedKey | List[NestedKey] = "reward",
+        action_key: NestedKey | list[NestedKey] = "action",
+        done_key: NestedKey | list[NestedKey] | None = None,
+        observation_key: NestedKey | list[NestedKey] = "observation",
+        reward_key: NestedKey | list[NestedKey] = "reward",
     ):
         """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
 
@@ -692,7 +692,7 @@ def auto_specs_(
         if full_action_spec is not None:
             self.full_action_spec = full_action_spec
         if full_done_spec is not None:
-            self.full_done_specs = full_done_spec
+            self.full_done_spec = full_done_spec
         if full_observation_spec is not None:
             self.full_observation_spec = full_observation_spec
         if full_reward_spec is not None:
@@ -704,8 +704,7 @@ def auto_specs_(
 
     @wraps(check_env_specs_func)
     def check_env_specs(self, *args, **kwargs):
-        return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
-        kwargs["return_contiguous"] = return_contiguous
+        kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
         return check_env_specs_func(self, *args, **kwargs)
 
     check_env_specs.__doc__ = check_env_specs_func.__doc__
@@ -850,8 +849,7 @@ def ndim(self):
 
     def append_transform(
         self,
-        transform: "Transform"  # noqa: F821
-        | Callable[[TensorDictBase], TensorDictBase],
+        transform: Transform | Callable[[TensorDictBase], TensorDictBase],  # noqa: F821
     ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.
 
@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:
 
     @property
     @_cache_value
-    def action_keys(self) -> List[NestedKey]:
+    def action_keys(self) -> list[NestedKey]:
         """The action keys of an environment.
 
         By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:
 
     @property
     @_cache_value
-    def state_keys(self) -> List[NestedKey]:
+    def state_keys(self) -> list[NestedKey]:
         """The state keys of an environment.
 
         By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
     # Reward spec
     @property
     @_cache_value
-    def reward_keys(self) -> List[NestedKey]:
+    def reward_keys(self) -> list[NestedKey]:
         """The reward keys of an environment.
 
         By default, there will only be one key named "reward".
@@ -1217,7 +1215,7 @@ def reward_keys(self) -> List[NestedKey]:
 
     @property
     @_cache_value
-    def observation_keys(self) -> List[NestedKey]:
+    def observation_keys(self) -> list[NestedKey]:
         """The observation keys of an environment.
 
         By default, there will only be one key named "observation".
@@ -1416,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
     # done spec
     @property
     @_cache_value
-    def done_keys(self) -> List[NestedKey]:
+    def done_keys(self) -> list[NestedKey]:
         """The done keys of an environment.
 
         By default, there will only be one key named "done".
@@ -2202,8 +2200,8 @@ def register_gym(
         id: str,
         *,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         backend: str = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
@@ -2392,8 +2390,8 @@ def _register_gym(
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2434,8 +2432,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2482,8 +2480,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2535,8 +2533,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2591,8 +2589,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2649,8 +2647,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2707,7 +2705,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
 
     def reset(
         self,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         **kwargs,
     ) -> TensorDictBase:
         """Resets the environment.
@@ -2816,8 +2814,8 @@ def numel(self) -> int:
         return prod(self.batch_size)
 
     def set_seed(
-        self, seed: Optional[int] = None, static_seed: bool = False
-    ) -> Optional[int]:
+        self, seed: int | None = None, static_seed: bool = False
+    ) -> int | None:
         """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
 
         Args:
@@ -2838,7 +2836,7 @@ def set_seed(
         return seed
 
     @abc.abstractmethod
-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         raise NotImplementedError
 
     def set_state(self):
@@ -2853,9 +2851,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
                 f"got {tensordict.batch_size} and {self.batch_size}"
             )
 
-    def all_actions(
-        self, tensordict: Optional[TensorDictBase] = None
-    ) -> TensorDictBase:
+    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Generates all possible actions from the action spec.
 
         This only works in environments with fully discrete actions.
@@ -2874,7 +2870,7 @@ def all_actions(
 
         return self.full_action_spec.enumerate(use_mask=True)
 
-    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
+    def rand_action(self, tensordict: TensorDictBase | None = None):
         """Performs a random action given the action_spec attribute.
 
         Args:
@@ -2908,7 +2904,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
             tensordict.update(r)
         return tensordict
 
-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Performs a random step in the environment given the action_spec attribute.
 
         Args:
@@ -2944,15 +2940,15 @@ def _has_dynamic_specs(self) -> bool:
     def rollout(
         self,
         max_steps: int,
-        policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        callback: Callable[[TensorDictBase, ...], Any] | None = None,
         *,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
         break_when_any_done: bool | None = None,
         break_when_all_done: bool | None = None,
         return_contiguous: bool | None = False,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
@@ -3479,7 +3475,7 @@ def _rollout_nonstop(
 
     def step_and_maybe_reset(
         self, tensordict: TensorDictBase
-    ) -> Tuple[TensorDictBase, TensorDictBase]:
+    ) -> tuple[TensorDictBase, TensorDictBase]:
         """Runs a step in the environment and (partially) resets it if needed.
 
         Args:
@@ -3600,7 +3596,7 @@ def empty_cache(self):
 
     @property
     @_cache_value
-    def reset_keys(self) -> List[NestedKey]:
+    def reset_keys(self) -> list[NestedKey]:
         """Returns a list of reset keys.
 
         Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3757,14 +3753,14 @@ class _EnvWrapper(EnvBase):
     """
 
     git_url: str = ""
-    available_envs: Dict[str, Any] = {}
+    available_envs: dict[str, Any] = {}
     libname: str = ""
 
     def __init__(
         self,
         *args,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
         **kwargs,
@@ -3813,7 +3809,7 @@ def _sync_device(self):
         return sync_func
 
     @abc.abstractmethod
-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
        raise NotImplementedError
 
     def __getattr__(self, attr: str) -> Any:
@@ -3839,7 +3835,7 @@ def __getattr__(self, attr: str) -> Any:
         )
 
     @abc.abstractmethod
-    def _init_env(self) -> Optional[int]:
+    def _init_env(self) -> int | None:
         """Runs all the necessary steps such that the environment is ready to use.
 
         This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3853,7 +3849,7 @@ def _init_env(self) -> Optional[int]:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
+    def _build_env(self, **kwargs) -> gym.Env:  # noqa: F821
         """Creates an environment from the target library and stores it with the `_env` attribute.
 
         When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3862,7 +3858,7 @@ def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
         raise NotImplementedError
 
     @abc.abstractmethod
-    def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
+    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
         raise NotImplementedError
 
     def close(self, *, raise_if_closed: bool = True) -> None:
@@ -3876,7 +3872,7 @@ def close(self, *, raise_if_closed: bool = True) -> None:
 
 def make_tensordict(
     env: _EnvWrapper,
-    policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
+    policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
 ) -> TensorDictBase:
     """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
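
Note on the changes above: besides the mechanical move from `typing.Optional`/`List`/`Dict`/`Tuple` to PEP 604/585 syntax (`X | None`, `list[...]`, `tuple[...]`), this patch fixes `self.full_done_specs` to `self.full_done_spec` in `auto_specs_`, so the done spec inferred from the rollout is assigned through the intended property, and it collapses the pop/re-assign pair in `check_env_specs` into a single `kwargs.setdefault` call with identical behavior: the caller's value wins when present, otherwise `not self._has_dynamic_specs` is used. Unquoting the `"Transform"` and `"gym.Env"` forward references is only safe when annotations are evaluated lazily (PEP 563, `from __future__ import annotations`), which the retained `# noqa: F821` markers indicate is the case in this module. The sketch below illustrates both patterns; it is a minimal stand-in, not torchrl code, and `hypothetical_lib`, `Transform`, and the simplified signatures are invented for the example.

    # Minimal sketch of the two patterns this patch relies on. All names here
    # are hypothetical stand-ins, not torchrl APIs.
    from __future__ import annotations  # PEP 563: annotations are stored as strings

    from typing import TYPE_CHECKING, Callable

    if TYPE_CHECKING:
        # Visible to type checkers only; never imported at runtime, hence the
        # retained `# noqa: F821` on the unquoted reference below.
        from hypothetical_lib import Transform


    def append_transform(
        transform: Transform | Callable[[dict], dict],  # noqa: F821
    ) -> None:
        """With lazy annotations, `Transform` is never evaluated at runtime."""


    def check_env_specs(**kwargs) -> dict:
        # Equivalent to the removed pop/re-assign pair: fill in the default
        # only when the caller did not pass the key explicitly.
        kwargs.setdefault("return_contiguous", True)
        return kwargs


    assert check_env_specs()["return_contiguous"] is True
    assert check_env_specs(return_contiguous=False)["return_contiguous"] is False

Run as a script, the sketch should exit silently; both asserts exercise the `setdefault` path, mirroring how `check_env_specs` now defaults `return_contiguous` from the environment's dynamic-spec status.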