14
14
import re
15
15
import warnings
16
16
from enum import Enum
17
- from typing import Any , Dict , List
17
+ from typing import Any
18
18
19
19
import torch
20
20
@@ -329,9 +329,9 @@ def step_mdp(
329
329
exclude_reward : bool = True ,
330
330
exclude_done : bool = False ,
331
331
exclude_action : bool = True ,
332
- reward_keys : NestedKey | List [NestedKey ] = "reward" ,
333
- done_keys : NestedKey | List [NestedKey ] = "done" ,
334
- action_keys : NestedKey | List [NestedKey ] = "action" ,
332
+ reward_keys : NestedKey | list [NestedKey ] = "reward" ,
333
+ done_keys : NestedKey | list [NestedKey ] = "done" ,
334
+ action_keys : NestedKey | list [NestedKey ] = "action" ,
335
335
) -> TensorDictBase :
336
336
"""Creates a new tensordict that reflects a step in time of the input tensordict.
337
337
@@ -680,8 +680,8 @@ def _per_level_env_check(data0, data1, check_dtype):
680
680
681
681
682
682
def check_env_specs (
683
- env ,
684
- return_contiguous = True ,
683
+ env : torchrl . envs . EnvBase , # noqa
684
+ return_contiguous : bool | None = None ,
685
685
check_dtype = True ,
686
686
seed : int | None = None ,
687
687
tensordict : TensorDictBase | None = None ,
@@ -699,7 +699,7 @@ def check_env_specs(
699
699
env (EnvBase): the env for which the specs have to be checked against data.
700
700
return_contiguous (bool, optional): if ``True``, the random rollout will be called with
701
701
return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes
702
- of inputs/outputs). Defaults to True .
702
+ of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs) .
703
703
check_dtype (bool, optional): if False, dtype checks will be skipped.
704
704
Defaults to True.
705
705
seed (int, optional): for reproducibility, a seed can be set.
@@ -715,6 +715,8 @@ def check_env_specs(
715
715
of an experiment and as such should be kept out of training scripts.
716
716
717
717
"""
718
+ if return_contiguous is None :
719
+ return_contiguous = not env ._has_dynamic_specs
718
720
if seed is not None :
719
721
device = (
720
722
env .device if env .device is not None and env .device .type == "cuda" else None
@@ -726,7 +728,7 @@ def check_env_specs(
726
728
)
727
729
728
730
fake_tensordict = env .fake_tensordict ()
729
- if not env ._batch_locked and tensordict is not None :
731
+ if not env .batch_locked and tensordict is not None :
730
732
shape = torch .broadcast_shapes (fake_tensordict .shape , tensordict .shape )
731
733
fake_tensordict = fake_tensordict .expand (shape )
732
734
tensordict = tensordict .expand (shape )
@@ -765,10 +767,13 @@ def check_env_specs(
765
767
- List of keys present in fake but not in real: { fake_tensordict_keys - real_tensordict_keys } .
766
768
"""
767
769
)
768
- zeroing_err_msg = (
769
- "zeroing the two tensordicts did not make them identical. "
770
- f"Check for discrepancies:\n Fake=\n { fake_tensordict } \n Real=\n { real_tensordict } "
771
- )
770
+
771
+ def zeroing_err_msg ():
772
+ return (
773
+ "zeroing the two tensordicts did not make them identical. "
774
+ f"Check for discrepancies:\n Fake=\n { fake_tensordict } \n Real=\n { real_tensordict } "
775
+ )
776
+
772
777
from torchrl .envs .common import _has_dynamic_specs
773
778
774
779
if _has_dynamic_specs (env .specs ):
@@ -778,7 +783,7 @@ def check_env_specs(
778
783
):
779
784
fake = fake .apply (lambda x , y : x .expand_as (y ), real )
780
785
if (torch .zeros_like (real ) != torch .zeros_like (fake )).any ():
781
- raise AssertionError (zeroing_err_msg )
786
+ raise AssertionError (zeroing_err_msg () )
782
787
783
788
# Checks shapes and, optionally, dtypes of keys at all nesting levels
784
789
_per_level_env_check (fake , real , check_dtype = check_dtype )
@@ -788,7 +793,7 @@ def check_env_specs(
788
793
torch .zeros_like (fake_tensordict_select )
789
794
!= torch .zeros_like (real_tensordict_select )
790
795
).any ():
791
- raise AssertionError (zeroing_err_msg )
796
+ raise AssertionError (zeroing_err_msg () )
792
797
793
798
# Checks shapes and, optionally, dtypes of keys at all nesting levels
794
799
_per_level_env_check (
@@ -1009,14 +1014,14 @@ class MarlGroupMapType(Enum):
1009
1014
ALL_IN_ONE_GROUP = 1
1010
1015
ONE_GROUP_PER_AGENT = 2
1011
1016
1012
- def get_group_map (self , agent_names : List [str ]):
1017
+ def get_group_map (self , agent_names : list [str ]):
1013
1018
if self == MarlGroupMapType .ALL_IN_ONE_GROUP :
1014
1019
return {"agents" : agent_names }
1015
1020
elif self == MarlGroupMapType .ONE_GROUP_PER_AGENT :
1016
1021
return {agent_name : [agent_name ] for agent_name in agent_names }
1017
1022
1018
1023
1019
- def check_marl_grouping (group_map : Dict [str , List [str ]], agent_names : List [str ]):
1024
+ def check_marl_grouping (group_map : dict [str , list [str ]], agent_names : list [str ]):
1020
1025
"""Check MARL group map.
1021
1026
1022
1027
Performs checks on the group map of a marl environment to assess its validity.
@@ -1360,7 +1365,7 @@ def skim_through(td, reset=reset):
1360
1365
def _update_during_reset (
1361
1366
tensordict_reset : TensorDictBase ,
1362
1367
tensordict : TensorDictBase ,
1363
- reset_keys : List [NestedKey ],
1368
+ reset_keys : list [NestedKey ],
1364
1369
):
1365
1370
"""Updates the input tensordict with the reset data, based on the reset keys."""
1366
1371
if not reset_keys :
0 commit comments