9
9
import warnings
10
10
from copy import deepcopy
11
11
from functools import partial , wraps
12
- from typing import Any , Callable , Dict , Iterator , List , Optional , Tuple
12
+ from typing import Any , Callable , Iterator
13
13
14
14
import numpy as np
15
15
import torch
@@ -476,7 +476,7 @@ def __init__(
476
476
self ,
477
477
* ,
478
478
device : DEVICE_TYPING = None ,
479
- batch_size : Optional [ torch .Size ] = None ,
479
+ batch_size : torch .Size | None = None ,
480
480
run_type_checks : bool = False ,
481
481
allow_done_after_reset : bool = False ,
482
482
spec_locked : bool = True ,
@@ -587,10 +587,10 @@ def auto_specs_(
587
587
policy : Callable [[TensorDictBase ], TensorDictBase ],
588
588
* ,
589
589
tensordict : TensorDictBase | None = None ,
590
- action_key : NestedKey | List [NestedKey ] = "action" ,
591
- done_key : NestedKey | List [NestedKey ] | None = None ,
592
- observation_key : NestedKey | List [NestedKey ] = "observation" ,
593
- reward_key : NestedKey | List [NestedKey ] = "reward" ,
590
+ action_key : NestedKey | list [NestedKey ] = "action" ,
591
+ done_key : NestedKey | list [NestedKey ] | None = None ,
592
+ observation_key : NestedKey | list [NestedKey ] = "observation" ,
593
+ reward_key : NestedKey | list [NestedKey ] = "reward" ,
594
594
):
595
595
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
596
596
@@ -692,7 +692,7 @@ def auto_specs_(
692
692
if full_action_spec is not None :
693
693
self .full_action_spec = full_action_spec
694
694
if full_done_spec is not None :
695
- self .full_done_specs = full_done_spec
695
+ self .full_done_spec = full_done_spec
696
696
if full_observation_spec is not None :
697
697
self .full_observation_spec = full_observation_spec
698
698
if full_reward_spec is not None :
@@ -704,8 +704,7 @@ def auto_specs_(
704
704
705
705
@wraps (check_env_specs_func )
706
706
def check_env_specs (self , * args , ** kwargs ):
707
- return_contiguous = kwargs .pop ("return_contiguous" , not self ._has_dynamic_specs )
708
- kwargs ["return_contiguous" ] = return_contiguous
707
+ kwargs .setdefault ("return_contiguous" , not self ._has_dynamic_specs )
709
708
return check_env_specs_func (self , * args , ** kwargs )
710
709
711
710
check_env_specs .__doc__ = check_env_specs_func .__doc__
@@ -850,8 +849,7 @@ def ndim(self):
850
849
851
850
def append_transform (
852
851
self ,
853
- transform : "Transform" # noqa: F821
854
- | Callable [[TensorDictBase ], TensorDictBase ],
852
+ transform : Transform | Callable [[TensorDictBase ], TensorDictBase ], # noqa: F821
855
853
) -> EnvBase :
856
854
"""Returns a transformed environment where the callable/transform passed is applied.
857
855
@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:
995
993
996
994
@property
997
995
@_cache_value
998
- def action_keys (self ) -> List [NestedKey ]:
996
+ def action_keys (self ) -> list [NestedKey ]:
999
997
"""The action keys of an environment.
1000
998
1001
999
By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:
1008
1006
1009
1007
@property
1010
1008
@_cache_value
1011
- def state_keys (self ) -> List [NestedKey ]:
1009
+ def state_keys (self ) -> list [NestedKey ]:
1012
1010
"""The state keys of an environment.
1013
1011
1014
1012
By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
1205
1203
# Reward spec
1206
1204
@property
1207
1205
@_cache_value
1208
- def reward_keys (self ) -> List [NestedKey ]:
1206
+ def reward_keys (self ) -> list [NestedKey ]:
1209
1207
"""The reward keys of an environment.
1210
1208
1211
1209
By default, there will only be one key named "reward".
@@ -1215,6 +1213,20 @@ def reward_keys(self) -> List[NestedKey]:
1215
1213
reward_keys = sorted (self .full_reward_spec .keys (True , True ), key = _repr_by_depth )
1216
1214
return reward_keys
1217
1215
1216
+ @property
1217
+ @_cache_value
1218
+ def observation_keys (self ) -> list [NestedKey ]:
1219
+ """The observation keys of an environment.
1220
+
1221
+ By default, there will only be one key named "observation".
1222
+
1223
+ Keys are sorted by depth in the data tree.
1224
+ """
1225
+ observation_keys = sorted (
1226
+ self .full_observation_spec .keys (True , True ), key = _repr_by_depth
1227
+ )
1228
+ return observation_keys
1229
+
1218
1230
@property
1219
1231
def reward_key (self ):
1220
1232
"""The reward key of an environment.
@@ -1402,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
1402
1414
# done spec
1403
1415
@property
1404
1416
@_cache_value
1405
- def done_keys (self ) -> List [NestedKey ]:
1417
+ def done_keys (self ) -> list [NestedKey ]:
1406
1418
"""The done keys of an environment.
1407
1419
1408
1420
By default, there will only be one key named "done".
@@ -2190,8 +2202,8 @@ def register_gym(
2190
2202
id : str ,
2191
2203
* ,
2192
2204
entry_point : Callable | None = None ,
2193
- transform : " Transform" | None = None , # noqa: F821
2194
- info_keys : List [NestedKey ] | None = None ,
2205
+ transform : Transform | None = None , # noqa: F821
2206
+ info_keys : list [NestedKey ] | None = None ,
2195
2207
backend : str = None ,
2196
2208
to_numpy : bool = False ,
2197
2209
reward_threshold : float | None = None ,
@@ -2380,8 +2392,8 @@ def _register_gym(
2380
2392
cls ,
2381
2393
id ,
2382
2394
entry_point : Callable | None = None ,
2383
- transform : " Transform" | None = None , # noqa: F821
2384
- info_keys : List [NestedKey ] | None = None ,
2395
+ transform : Transform | None = None , # noqa: F821
2396
+ info_keys : list [NestedKey ] | None = None ,
2385
2397
to_numpy : bool = False ,
2386
2398
reward_threshold : float | None = None ,
2387
2399
nondeterministic : bool = False ,
@@ -2422,8 +2434,8 @@ def _register_gym( # noqa: F811
2422
2434
cls ,
2423
2435
id ,
2424
2436
entry_point : Callable | None = None ,
2425
- transform : " Transform" | None = None , # noqa: F821
2426
- info_keys : List [NestedKey ] | None = None ,
2437
+ transform : Transform | None = None , # noqa: F821
2438
+ info_keys : list [NestedKey ] | None = None ,
2427
2439
to_numpy : bool = False ,
2428
2440
reward_threshold : float | None = None ,
2429
2441
nondeterministic : bool = False ,
@@ -2470,8 +2482,8 @@ def _register_gym( # noqa: F811
2470
2482
cls ,
2471
2483
id ,
2472
2484
entry_point : Callable | None = None ,
2473
- transform : " Transform" | None = None , # noqa: F821
2474
- info_keys : List [NestedKey ] | None = None ,
2485
+ transform : Transform | None = None , # noqa: F821
2486
+ info_keys : list [NestedKey ] | None = None ,
2475
2487
to_numpy : bool = False ,
2476
2488
reward_threshold : float | None = None ,
2477
2489
nondeterministic : bool = False ,
@@ -2523,8 +2535,8 @@ def _register_gym( # noqa: F811
2523
2535
cls ,
2524
2536
id ,
2525
2537
entry_point : Callable | None = None ,
2526
- transform : " Transform" | None = None , # noqa: F821
2527
- info_keys : List [NestedKey ] | None = None ,
2538
+ transform : Transform | None = None , # noqa: F821
2539
+ info_keys : list [NestedKey ] | None = None ,
2528
2540
to_numpy : bool = False ,
2529
2541
reward_threshold : float | None = None ,
2530
2542
nondeterministic : bool = False ,
@@ -2579,8 +2591,8 @@ def _register_gym( # noqa: F811
2579
2591
cls ,
2580
2592
id ,
2581
2593
entry_point : Callable | None = None ,
2582
- transform : " Transform" | None = None , # noqa: F821
2583
- info_keys : List [NestedKey ] | None = None ,
2594
+ transform : Transform | None = None , # noqa: F821
2595
+ info_keys : list [NestedKey ] | None = None ,
2584
2596
to_numpy : bool = False ,
2585
2597
reward_threshold : float | None = None ,
2586
2598
nondeterministic : bool = False ,
@@ -2637,8 +2649,8 @@ def _register_gym( # noqa: F811
2637
2649
cls ,
2638
2650
id ,
2639
2651
entry_point : Callable | None = None ,
2640
- transform : " Transform" | None = None , # noqa: F821
2641
- info_keys : List [NestedKey ] | None = None ,
2652
+ transform : Transform | None = None , # noqa: F821
2653
+ info_keys : list [NestedKey ] | None = None ,
2642
2654
to_numpy : bool = False ,
2643
2655
reward_threshold : float | None = None ,
2644
2656
nondeterministic : bool = False ,
@@ -2695,7 +2707,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2695
2707
2696
2708
def reset (
2697
2709
self ,
2698
- tensordict : Optional [ TensorDictBase ] = None ,
2710
+ tensordict : TensorDictBase | None = None ,
2699
2711
** kwargs ,
2700
2712
) -> TensorDictBase :
2701
2713
"""Resets the environment.
@@ -2804,8 +2816,8 @@ def numel(self) -> int:
2804
2816
return prod (self .batch_size )
2805
2817
2806
2818
def set_seed (
2807
- self , seed : Optional [ int ] = None , static_seed : bool = False
2808
- ) -> Optional [ int ] :
2819
+ self , seed : int | None = None , static_seed : bool = False
2820
+ ) -> int | None :
2809
2821
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
2810
2822
2811
2823
Args:
@@ -2826,7 +2838,7 @@ def set_seed(
2826
2838
return seed
2827
2839
2828
2840
@abc .abstractmethod
2829
- def _set_seed (self , seed : Optional [ int ] ):
2841
+ def _set_seed (self , seed : int | None ):
2830
2842
raise NotImplementedError
2831
2843
2832
2844
def set_state (self ):
@@ -2841,7 +2853,26 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
2841
2853
f"got { tensordict .batch_size } and { self .batch_size } "
2842
2854
)
2843
2855
2844
- def rand_action (self , tensordict : Optional [TensorDictBase ] = None ):
2856
+ def all_actions (self , tensordict : TensorDictBase | None = None ) -> TensorDictBase :
2857
+ """Generates all possible actions from the action spec.
2858
+
2859
+ This only works in environments with fully discrete actions.
2860
+
2861
+ Args:
2862
+ tensordict (TensorDictBase, optional): If given, :meth:`~.reset`
2863
+ is called with this tensordict.
2864
+
2865
+ Returns:
2866
+ a tensordict object with the "action" entry updated with a batch of
2867
+ all possible actions. The actions are stacked together in the
2868
+ leading dimension.
2869
+ """
2870
+ if tensordict is not None :
2871
+ self .reset (tensordict )
2872
+
2873
+ return self .full_action_spec .enumerate (use_mask = True )
2874
+
2875
+ def rand_action (self , tensordict : TensorDictBase | None = None ):
2845
2876
"""Performs a random action given the action_spec attribute.
2846
2877
2847
2878
Args:
@@ -2875,7 +2906,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
2875
2906
tensordict .update (r )
2876
2907
return tensordict
2877
2908
2878
- def rand_step (self , tensordict : Optional [ TensorDictBase ] = None ) -> TensorDictBase :
2909
+ def rand_step (self , tensordict : TensorDictBase | None = None ) -> TensorDictBase :
2879
2910
"""Performs a random step in the environment given the action_spec attribute.
2880
2911
2881
2912
Args:
@@ -2911,15 +2942,15 @@ def _has_dynamic_specs(self) -> bool:
2911
2942
def rollout (
2912
2943
self ,
2913
2944
max_steps : int ,
2914
- policy : Optional [ Callable [[TensorDictBase ], TensorDictBase ]] = None ,
2915
- callback : Optional [ Callable [[TensorDictBase , ...], Any ]] = None ,
2945
+ policy : Callable [[TensorDictBase ], TensorDictBase ] | None = None ,
2946
+ callback : Callable [[TensorDictBase , ...], Any ] | None = None ,
2916
2947
* ,
2917
2948
auto_reset : bool = True ,
2918
2949
auto_cast_to_device : bool = False ,
2919
2950
break_when_any_done : bool | None = None ,
2920
2951
break_when_all_done : bool | None = None ,
2921
2952
return_contiguous : bool | None = False ,
2922
- tensordict : Optional [ TensorDictBase ] = None ,
2953
+ tensordict : TensorDictBase | None = None ,
2923
2954
set_truncated : bool = False ,
2924
2955
out = None ,
2925
2956
trust_policy : bool = False ,
@@ -3441,7 +3472,7 @@ def _rollout_nonstop(
3441
3472
3442
3473
def step_and_maybe_reset (
3443
3474
self , tensordict : TensorDictBase
3444
- ) -> Tuple [TensorDictBase , TensorDictBase ]:
3475
+ ) -> tuple [TensorDictBase , TensorDictBase ]:
3445
3476
"""Runs a step in the environment and (partially) resets it if needed.
3446
3477
3447
3478
Args:
@@ -3549,7 +3580,7 @@ def empty_cache(self):
3549
3580
3550
3581
@property
3551
3582
@_cache_value
3552
- def reset_keys (self ) -> List [NestedKey ]:
3583
+ def reset_keys (self ) -> list [NestedKey ]:
3553
3584
"""Returns a list of reset keys.
3554
3585
3555
3586
Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3706,14 +3737,14 @@ class _EnvWrapper(EnvBase):
3706
3737
"""
3707
3738
3708
3739
git_url : str = ""
3709
- available_envs : Dict [str , Any ] = {}
3740
+ available_envs : dict [str , Any ] = {}
3710
3741
libname : str = ""
3711
3742
3712
3743
def __init__ (
3713
3744
self ,
3714
3745
* args ,
3715
3746
device : DEVICE_TYPING = None ,
3716
- batch_size : Optional [ torch .Size ] = None ,
3747
+ batch_size : torch .Size | None = None ,
3717
3748
allow_done_after_reset : bool = False ,
3718
3749
spec_locked : bool = True ,
3719
3750
** kwargs ,
@@ -3762,7 +3793,7 @@ def _sync_device(self):
3762
3793
return sync_func
3763
3794
3764
3795
@abc .abstractmethod
3765
- def _check_kwargs (self , kwargs : Dict ):
3796
+ def _check_kwargs (self , kwargs : dict ):
3766
3797
raise NotImplementedError
3767
3798
3768
3799
def __getattr__ (self , attr : str ) -> Any :
@@ -3788,7 +3819,7 @@ def __getattr__(self, attr: str) -> Any:
3788
3819
)
3789
3820
3790
3821
@abc .abstractmethod
3791
- def _init_env (self ) -> Optional [ int ] :
3822
+ def _init_env (self ) -> int | None :
3792
3823
"""Runs all the necessary steps such that the environment is ready to use.
3793
3824
3794
3825
This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3802,7 +3833,7 @@ def _init_env(self) -> Optional[int]:
3802
3833
raise NotImplementedError
3803
3834
3804
3835
@abc .abstractmethod
3805
- def _build_env (self , ** kwargs ) -> " gym.Env" : # noqa: F821
3836
+ def _build_env (self , ** kwargs ) -> gym .Env : # noqa: F821
3806
3837
"""Creates an environment from the target library and stores it with the `_env` attribute.
3807
3838
3808
3839
When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3811,7 +3842,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
3811
3842
raise NotImplementedError
3812
3843
3813
3844
@abc .abstractmethod
3814
- def _make_specs (self , env : " gym.Env" ) -> None : # noqa: F821
3845
+ def _make_specs (self , env : gym .Env ) -> None : # noqa: F821
3815
3846
raise NotImplementedError
3816
3847
3817
3848
def close (self ) -> None :
@@ -3825,7 +3856,7 @@ def close(self) -> None:
3825
3856
3826
3857
def make_tensordict (
3827
3858
env : _EnvWrapper ,
3828
- policy : Optional [ Callable [[TensorDictBase , ...], TensorDictBase ]] = None ,
3859
+ policy : Callable [[TensorDictBase , ...], TensorDictBase ] | None = None ,
3829
3860
) -> TensorDictBase :
3830
3861
"""Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
3831
3862
0 commit comments