14
14
import re
15
15
import warnings
16
16
from enum import Enum
17
- from typing import Any , Dict , List
17
+ from typing import Any
18
18
19
19
import torch
20
20
@@ -329,9 +329,9 @@ def step_mdp(
329
329
exclude_reward : bool = True ,
330
330
exclude_done : bool = False ,
331
331
exclude_action : bool = True ,
332
- reward_keys : NestedKey | List [NestedKey ] = "reward" ,
333
- done_keys : NestedKey | List [NestedKey ] = "done" ,
334
- action_keys : NestedKey | List [NestedKey ] = "action" ,
332
+ reward_keys : NestedKey | list [NestedKey ] = "reward" ,
333
+ done_keys : NestedKey | list [NestedKey ] = "done" ,
334
+ action_keys : NestedKey | list [NestedKey ] = "action" ,
335
335
) -> TensorDictBase :
336
336
"""Creates a new tensordict that reflects a step in time of the input tensordict.
337
337
@@ -680,8 +680,8 @@ def _per_level_env_check(data0, data1, check_dtype):
680
680
681
681
682
682
def check_env_specs (
683
- env ,
684
- return_contiguous = True ,
683
+ env : torchrl . envs . EnvBase , # noqa
684
+ return_contiguous : bool | None = None ,
685
685
check_dtype = True ,
686
686
seed : int | None = None ,
687
687
tensordict : TensorDictBase | None = None ,
@@ -700,7 +700,7 @@ def check_env_specs(
700
700
env (EnvBase): the env for which the specs have to be checked against data.
701
701
return_contiguous (bool, optional): if ``True``, the random rollout will be called with
702
702
return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes
703
- of inputs/outputs). Defaults to True .
703
+ of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs) .
704
704
check_dtype (bool, optional): if False, dtype checks will be skipped.
705
705
Defaults to True.
706
706
seed (int, optional): for reproducibility, a seed can be set.
@@ -718,6 +718,8 @@ def check_env_specs(
718
718
of an experiment and as such should be kept out of training scripts.
719
719
720
720
"""
721
+ if return_contiguous is None :
722
+ return_contiguous = not env ._has_dynamic_specs
721
723
if break_when_any_done == "both" :
722
724
check_env_specs (
723
725
env ,
@@ -746,7 +748,7 @@ def check_env_specs(
746
748
)
747
749
748
750
fake_tensordict = env .fake_tensordict ()
749
- if not env ._batch_locked and tensordict is not None :
751
+ if not env .batch_locked and tensordict is not None :
750
752
shape = torch .broadcast_shapes (fake_tensordict .shape , tensordict .shape )
751
753
fake_tensordict = fake_tensordict .expand (shape )
752
754
tensordict = tensordict .expand (shape )
@@ -786,10 +788,13 @@ def check_env_specs(
786
788
- List of keys present in fake but not in real: { fake_tensordict_keys - real_tensordict_keys } .
787
789
"""
788
790
)
789
- zeroing_err_msg = (
790
- "zeroing the two tensordicts did not make them identical. "
791
- f"Check for discrepancies:\n Fake=\n { fake_tensordict } \n Real=\n { real_tensordict } "
792
- )
791
+
792
+ def zeroing_err_msg ():
793
+ return (
794
+ "zeroing the two tensordicts did not make them identical. "
795
+ f"Check for discrepancies:\n Fake=\n { fake_tensordict } \n Real=\n { real_tensordict } "
796
+ )
797
+
793
798
from torchrl .envs .common import _has_dynamic_specs
794
799
795
800
if _has_dynamic_specs (env .specs ):
@@ -799,7 +804,7 @@ def check_env_specs(
799
804
):
800
805
fake = fake .apply (lambda x , y : x .expand_as (y ), real )
801
806
if (torch .zeros_like (real ) != torch .zeros_like (fake )).any ():
802
- raise AssertionError (zeroing_err_msg )
807
+ raise AssertionError (zeroing_err_msg () )
803
808
804
809
# Checks shapes and eventually dtypes of keys at all nesting levels
805
810
_per_level_env_check (fake , real , check_dtype = check_dtype )
@@ -809,7 +814,7 @@ def check_env_specs(
809
814
torch .zeros_like (fake_tensordict_select )
810
815
!= torch .zeros_like (real_tensordict_select )
811
816
).any ():
812
- raise AssertionError (zeroing_err_msg )
817
+ raise AssertionError (zeroing_err_msg () )
813
818
814
819
# Checks shapes and eventually dtypes of keys at all nesting levels
815
820
_per_level_env_check (
@@ -1030,14 +1035,14 @@ class MarlGroupMapType(Enum):
1030
1035
ALL_IN_ONE_GROUP = 1
1031
1036
ONE_GROUP_PER_AGENT = 2
1032
1037
1033
- def get_group_map (self , agent_names : List [str ]):
1038
+ def get_group_map (self , agent_names : list [str ]):
1034
1039
if self == MarlGroupMapType .ALL_IN_ONE_GROUP :
1035
1040
return {"agents" : agent_names }
1036
1041
elif self == MarlGroupMapType .ONE_GROUP_PER_AGENT :
1037
1042
return {agent_name : [agent_name ] for agent_name in agent_names }
1038
1043
1039
1044
1040
- def check_marl_grouping (group_map : Dict [str , List [str ]], agent_names : List [str ]):
1045
+ def check_marl_grouping (group_map : dict [str , list [str ]], agent_names : list [str ]):
1041
1046
"""Check MARL group map.
1042
1047
1043
1048
Performs checks on the group map of a marl environment to assess its validity.
@@ -1381,7 +1386,7 @@ def skim_through(td, reset=reset):
1381
1386
def _update_during_reset (
1382
1387
tensordict_reset : TensorDictBase ,
1383
1388
tensordict : TensorDictBase ,
1384
- reset_keys : List [NestedKey ],
1389
+ reset_keys : list [NestedKey ],
1385
1390
):
1386
1391
"""Updates the input tensordict with the reset data, based on the reset keys."""
1387
1392
if not reset_keys :
0 commit comments