Skip to content

Commit a2f3a68

Browse files
committed
[BugFix] Fix batch_locked check in check_env_specs + error message callable
ghstack-source-id: c722b16
Pull Request resolved: #2817
(cherry picked from commit 9c98b82)
1 parent 90f37ed commit a2f3a68

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

torchrl/envs/utils.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
import re
1515
import warnings
1616
from enum import Enum
17-
from typing import Any, Dict, List
17+
from typing import Any
1818

1919
import torch
2020

@@ -329,9 +329,9 @@ def step_mdp(
329329
exclude_reward: bool = True,
330330
exclude_done: bool = False,
331331
exclude_action: bool = True,
332-
reward_keys: NestedKey | List[NestedKey] = "reward",
333-
done_keys: NestedKey | List[NestedKey] = "done",
334-
action_keys: NestedKey | List[NestedKey] = "action",
332+
reward_keys: NestedKey | list[NestedKey] = "reward",
333+
done_keys: NestedKey | list[NestedKey] = "done",
334+
action_keys: NestedKey | list[NestedKey] = "action",
335335
) -> TensorDictBase:
336336
"""Creates a new tensordict that reflects a step in time of the input tensordict.
337337
@@ -680,8 +680,8 @@ def _per_level_env_check(data0, data1, check_dtype):
680680

681681

682682
def check_env_specs(
683-
env,
684-
return_contiguous=True,
683+
env: torchrl.envs.EnvBase, # noqa
684+
return_contiguous: bool | None = None,
685685
check_dtype=True,
686686
seed: int | None = None,
687687
tensordict: TensorDictBase | None = None,
@@ -700,7 +700,7 @@ def check_env_specs(
700700
env (EnvBase): the env for which the specs have to be checked against data.
701701
return_contiguous (bool, optional): if ``True``, the random rollout will be called with
702702
return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes
703-
of inputs/outputs). Defaults to True.
703+
of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs).
704704
check_dtype (bool, optional): if False, dtype checks will be skipped.
705705
Defaults to True.
706706
seed (int, optional): for reproducibility, a seed can be set.
@@ -718,6 +718,8 @@ def check_env_specs(
718718
of an experiment and as such should be kept out of training scripts.
719719
720720
"""
721+
if return_contiguous is None:
722+
return_contiguous = not env._has_dynamic_specs
721723
if break_when_any_done == "both":
722724
check_env_specs(
723725
env,
@@ -746,7 +748,7 @@ def check_env_specs(
746748
)
747749

748750
fake_tensordict = env.fake_tensordict()
749-
if not env._batch_locked and tensordict is not None:
751+
if not env.batch_locked and tensordict is not None:
750752
shape = torch.broadcast_shapes(fake_tensordict.shape, tensordict.shape)
751753
fake_tensordict = fake_tensordict.expand(shape)
752754
tensordict = tensordict.expand(shape)
@@ -786,10 +788,13 @@ def check_env_specs(
786788
- List of keys present in fake but not in real: {fake_tensordict_keys-real_tensordict_keys}.
787789
"""
788790
)
789-
zeroing_err_msg = (
790-
"zeroing the two tensordicts did not make them identical. "
791-
f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
792-
)
791+
792+
def zeroing_err_msg():
793+
return (
794+
"zeroing the two tensordicts did not make them identical. "
795+
f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
796+
)
797+
793798
from torchrl.envs.common import _has_dynamic_specs
794799

795800
if _has_dynamic_specs(env.specs):
@@ -799,7 +804,7 @@ def check_env_specs(
799804
):
800805
fake = fake.apply(lambda x, y: x.expand_as(y), real)
801806
if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
802-
raise AssertionError(zeroing_err_msg)
807+
raise AssertionError(zeroing_err_msg())
803808

804809
# Checks shapes and eventually dtypes of keys at all nesting levels
805810
_per_level_env_check(fake, real, check_dtype=check_dtype)
@@ -809,7 +814,7 @@ def check_env_specs(
809814
torch.zeros_like(fake_tensordict_select)
810815
!= torch.zeros_like(real_tensordict_select)
811816
).any():
812-
raise AssertionError(zeroing_err_msg)
817+
raise AssertionError(zeroing_err_msg())
813818

814819
# Checks shapes and eventually dtypes of keys at all nesting levels
815820
_per_level_env_check(
@@ -1030,14 +1035,14 @@ class MarlGroupMapType(Enum):
10301035
ALL_IN_ONE_GROUP = 1
10311036
ONE_GROUP_PER_AGENT = 2
10321037

1033-
def get_group_map(self, agent_names: List[str]):
1038+
def get_group_map(self, agent_names: list[str]):
10341039
if self == MarlGroupMapType.ALL_IN_ONE_GROUP:
10351040
return {"agents": agent_names}
10361041
elif self == MarlGroupMapType.ONE_GROUP_PER_AGENT:
10371042
return {agent_name: [agent_name] for agent_name in agent_names}
10381043

10391044

1040-
def check_marl_grouping(group_map: Dict[str, List[str]], agent_names: List[str]):
1045+
def check_marl_grouping(group_map: dict[str, list[str]], agent_names: list[str]):
10411046
"""Check MARL group map.
10421047
10431048
Performs checks on the group map of a marl environment to assess its validity.
@@ -1381,7 +1386,7 @@ def skim_through(td, reset=reset):
13811386
def _update_during_reset(
13821387
tensordict_reset: TensorDictBase,
13831388
tensordict: TensorDictBase,
1384-
reset_keys: List[NestedKey],
1389+
reset_keys: list[NestedKey],
13851390
):
13861391
"""Updates the input tensordict with the reset data, based on the reset keys."""
13871392
if not reset_keys:

0 commit comments

Comments
 (0)