@@ -9,7 +9,7 @@
 import warnings
 from copy import deepcopy
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Iterator
 
 import numpy as np
 import torch
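Note for reviewers: the trimmed `typing` import works because every use of `Optional`, `List`, `Dict`, and `Tuple` below is rewritten to PEP 604 unions (`X | None`) and PEP 585 builtin generics (`list[...]`, `dict[...]`, `tuple[...]`). On Python < 3.10 those spellings are only safe in annotations when evaluation is deferred, so presumably the module already carries the future import at its top (not visible in this hunk). A minimal sketch of that assumption:

```python
# Sketch, assuming deferred annotation evaluation (PEP 563).
# Without this future import, `torch.Size | None` in an annotation raises
# TypeError at import time on Python 3.9 and earlier.
from __future__ import annotations

import torch

def make_env(batch_size: torch.Size | None = None, keys: list[str] | None = None) -> None:
    ...
```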
@@ -476,7 +476,7 @@ def __init__(
         self,
         *,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         run_type_checks: bool = False,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
@@ -587,10 +587,10 @@ def auto_specs_(
         policy: Callable[[TensorDictBase], TensorDictBase],
         *,
         tensordict: TensorDictBase | None = None,
-        action_key: NestedKey | List[NestedKey] = "action",
-        done_key: NestedKey | List[NestedKey] | None = None,
-        observation_key: NestedKey | List[NestedKey] = "observation",
-        reward_key: NestedKey | List[NestedKey] = "reward",
+        action_key: NestedKey | list[NestedKey] = "action",
+        done_key: NestedKey | list[NestedKey] | None = None,
+        observation_key: NestedKey | list[NestedKey] = "observation",
+        reward_key: NestedKey | list[NestedKey] = "reward",
     ):
         """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
 
@@ -692,7 +692,7 @@ def auto_specs_(
         if full_action_spec is not None:
             self.full_action_spec = full_action_spec
         if full_done_spec is not None:
-            self.full_done_specs = full_done_spec
+            self.full_done_spec = full_done_spec
         if full_observation_spec is not None:
             self.full_observation_spec = full_observation_spec
         if full_reward_spec is not None:
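The `full_done_specs` -> `full_done_spec` change above is a genuine bug fix, not a rename: assigning to a misspelled attribute on a Python object silently creates a brand-new attribute, so the done spec inferred from the rollout was never actually installed. A minimal illustration with a hypothetical class (not the real `EnvBase`):

```python
class Env:
    def __init__(self):
        self.full_done_spec = None  # the real field

env = Env()
env.full_done_specs = "spec from rollout"  # typo: silently creates a NEW attribute
print(env.full_done_spec)  # None -- the intended field was never updated
```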
@@ -704,8 +704,7 @@ def auto_specs_(
 
     @wraps(check_env_specs_func)
     def check_env_specs(self, *args, **kwargs):
-        return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
-        kwargs["return_contiguous"] = return_contiguous
+        kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
         return check_env_specs_func(self, *args, **kwargs)
 
     check_env_specs.__doc__ = check_env_specs_func.__doc__
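The `setdefault` rewrite is behavior-preserving: popping a key with a default and immediately writing the result back is exactly what `dict.setdefault` does in one call. A quick equivalence check:

```python
kwargs = {}
kwargs.setdefault("return_contiguous", True)   # key absent -> default inserted
assert kwargs == {"return_contiguous": True}

kwargs = {"return_contiguous": False}
kwargs.setdefault("return_contiguous", True)   # key present -> caller's value kept
assert kwargs == {"return_contiguous": False}
```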
@@ -850,8 +849,7 @@ def ndim(self):
 
     def append_transform(
         self,
-        transform: "Transform"  # noqa: F821
-        | Callable[[TensorDictBase], TensorDictBase],
+        transform: Transform | Callable[[TensorDictBase], TensorDictBase],  # noqa: F821
     ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.
 
@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:
 
     @property
     @_cache_value
-    def action_keys(self) -> List[NestedKey]:
+    def action_keys(self) -> list[NestedKey]:
         """The action keys of an environment.
 
         By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:
 
     @property
     @_cache_value
-    def state_keys(self) -> List[NestedKey]:
+    def state_keys(self) -> list[NestedKey]:
         """The state keys of an environment.
 
         By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
     # Reward spec
     @property
     @_cache_value
-    def reward_keys(self) -> List[NestedKey]:
+    def reward_keys(self) -> list[NestedKey]:
         """The reward keys of an environment.
 
         By default, there will only be one key named "reward".
@@ -1217,7 +1215,7 @@ def reward_keys(self) -> List[NestedKey]:
 
     @property
     @_cache_value
-    def observation_keys(self) -> List[NestedKey]:
+    def observation_keys(self) -> list[NestedKey]:
         """The observation keys of an environment.
 
         By default, there will only be one key named "observation".
@@ -1416,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
     # done spec
     @property
     @_cache_value
-    def done_keys(self) -> List[NestedKey]:
+    def done_keys(self) -> list[NestedKey]:
         """The done keys of an environment.
 
         By default, there will only be one key named "done".
@@ -2205,8 +2203,8 @@ def register_gym(
         id: str,
         *,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         backend: str = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
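For context, `register_gym` exposes a TorchRL environment under a Gym-style id so it can be built through the Gym registry. A hedged usage sketch; the env class and id are hypothetical, the keyword names follow the signature above, and the backend string is an assumption:

```python
# Hypothetical: MyTorchRLEnv is any EnvBase subclass.
MyTorchRLEnv.register_gym(
    "MyTorchRLEnv-v0",
    entry_point=MyTorchRLEnv,
    backend="gymnasium",  # assumed backend value; "gym" presumably also accepted
    to_numpy=True,
)

import gymnasium as gym

env = gym.make("MyTorchRLEnv-v0")
```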
@@ -2395,8 +2393,8 @@ def _register_gym(
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2437,8 +2435,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2485,8 +2483,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2538,8 +2536,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2594,8 +2592,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2652,8 +2650,8 @@ def _register_gym(  # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
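The six near-identical hunks above touch deliberate redefinitions of `_register_gym`, which is why each carries `# noqa: F811` (flake8's redefinition warning); presumably each body targets a different installed gym/gymnasium version. A minimal sketch of that dispatch-by-redefinition pattern, with a hypothetical `dispatch_on` decorator standing in for whatever registry the codebase actually uses:

```python
_IMPLS: dict[str, object] = {}

def dispatch_on(version: str):
    # Each redefinition registers itself for a version range instead of
    # shadowing the previous one.
    def deco(fn):
        _IMPLS[version] = fn
        return fn
    return deco

@dispatch_on("0.21")
def _register_gym(id, **kwargs):
    ...

@dispatch_on("0.26")  # noqa: F811 -- redefinition is deliberate
def _register_gym(id, **kwargs):
    ...
```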
@@ -2710,7 +2708,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
 
     def reset(
         self,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         **kwargs,
     ) -> TensorDictBase:
         """Resets the environment.
@@ -2819,8 +2817,8 @@ def numel(self) -> int:
         return prod(self.batch_size)
 
     def set_seed(
-        self, seed: Optional[int] = None, static_seed: bool = False
-    ) -> Optional[int]:
+        self, seed: int | None = None, static_seed: bool = False
+    ) -> int | None:
         """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
 
         Args:
@@ -2841,7 +2839,7 @@ def set_seed(
         return seed
 
     @abc.abstractmethod
-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         raise NotImplementedError
 
     def set_state(self):
@@ -2856,9 +2854,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
                 f"got {tensordict.batch_size} and {self.batch_size}"
             )
 
-    def all_actions(
-        self, tensordict: Optional[TensorDictBase] = None
-    ) -> TensorDictBase:
+    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Generates all possible actions from the action spec.
 
         This only works in environments with fully discrete actions.
@@ -2877,7 +2873,7 @@ def all_actions(
 
         return self.full_action_spec.enumerate(use_mask=True)
 
-    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
+    def rand_action(self, tensordict: TensorDictBase | None = None):
         """Performs a random action given the action_spec attribute.
 
         Args:
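As the hunk above shows, `all_actions` delegates to `full_action_spec.enumerate(use_mask=True)`, so it is only meaningful for fully discrete action specs. A hedged sketch of how it would be consumed; `env` is assumed to be an `EnvBase` with discrete actions, and the exact output shape depends on the spec:

```python
actions = env.all_actions()   # enumerates every legal action as a tensordict
print(actions.batch_size)     # leading dimension: one entry per possible action
```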
@@ -2911,7 +2907,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
         tensordict.update(r)
         return tensordict
 
-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Performs a random step in the environment given the action_spec attribute.
 
         Args:
@@ -2947,15 +2943,15 @@ def _has_dynamic_specs(self) -> bool:
     def rollout(
         self,
         max_steps: int,
-        policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        callback: Callable[[TensorDictBase, ...], Any] | None = None,
         *,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
         break_when_any_done: bool | None = None,
         break_when_all_done: bool | None = None,
         return_contiguous: bool | None = False,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
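For reference, a typical call against the `rollout` signature above; `env` is assumed to be any `EnvBase` instance, and passing `policy=None` falls back to random actions:

```python
td = env.rollout(
    max_steps=100,
    policy=None,              # None -> random actions via rand_action/rand_step
    break_when_any_done=True,
    return_contiguous=True,   # stack the steps into one contiguous tensordict
)
print(td["next", "reward"].shape)
```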
@@ -3485,7 +3481,7 @@ def _rollout_nonstop(
 
     def step_and_maybe_reset(
         self, tensordict: TensorDictBase
-    ) -> Tuple[TensorDictBase, TensorDictBase]:
+    ) -> tuple[TensorDictBase, TensorDictBase]:
         """Runs a step in the environment and (partially) resets it if needed.
 
         Args:
@@ -3606,7 +3602,7 @@ def empty_cache(self):
 
     @property
     @_cache_value
-    def reset_keys(self) -> List[NestedKey]:
+    def reset_keys(self) -> list[NestedKey]:
         """Returns a list of reset keys.
 
         Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3763,14 +3759,14 @@ class _EnvWrapper(EnvBase):
     """
 
     git_url: str = ""
-    available_envs: Dict[str, Any] = {}
+    available_envs: dict[str, Any] = {}
     libname: str = ""
 
     def __init__(
         self,
         *args,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
         **kwargs,
@@ -3819,7 +3815,7 @@ def _sync_device(self):
         return sync_func
 
     @abc.abstractmethod
-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         raise NotImplementedError
 
     def __getattr__(self, attr: str) -> Any:
@@ -3845,7 +3841,7 @@ def __getattr__(self, attr: str) -> Any:
         )
 
     @abc.abstractmethod
-    def _init_env(self) -> Optional[int]:
+    def _init_env(self) -> int | None:
         """Runs all the necessary steps such that the environment is ready to use.
 
         This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3859,7 +3855,7 @@ def _init_env(self) -> Optional[int]:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
+    def _build_env(self, **kwargs) -> gym.Env:  # noqa: F821
         """Creates an environment from the target library and stores it with the `_env` attribute.
 
         When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3868,7 +3864,7 @@ def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
         raise NotImplementedError
 
     @abc.abstractmethod
-    def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
+    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
         raise NotImplementedError
 
     def close(self, *, raise_if_closed: bool = True) -> None:
@@ -3882,7 +3878,7 @@ def close(self, *, raise_if_closed: bool = True) -> None:
 
 def make_tensordict(
     env: _EnvWrapper,
-    policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
+    policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
 ) -> TensorDictBase:
     """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
 