Skip to content

Commit c797855

Browse files
committed
[BugFix] Fix batch_locked check in check_env_specs + error message callable
ghstack-source-id: c722b16
Pull Request resolved: #2817
(cherry picked from commit 9c98b82)
1 parent d556726 commit c797855

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

torchrl/envs/utils.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import re
1515
import warnings
1616
from enum import Enum
17-
from typing import Any, Dict, List
17+
from typing import Any
1818

1919
import torch
2020

@@ -329,9 +329,9 @@ def step_mdp(
329329
exclude_reward: bool = True,
330330
exclude_done: bool = False,
331331
exclude_action: bool = True,
332-
reward_keys: NestedKey | List[NestedKey] = "reward",
333-
done_keys: NestedKey | List[NestedKey] = "done",
334-
action_keys: NestedKey | List[NestedKey] = "action",
332+
reward_keys: NestedKey | list[NestedKey] = "reward",
333+
done_keys: NestedKey | list[NestedKey] = "done",
334+
action_keys: NestedKey | list[NestedKey] = "action",
335335
) -> TensorDictBase:
336336
"""Creates a new tensordict that reflects a step in time of the input tensordict.
337337
@@ -680,8 +680,8 @@ def _per_level_env_check(data0, data1, check_dtype):
680680

681681

682682
def check_env_specs(
683-
env,
684-
return_contiguous=True,
683+
env: torchrl.envs.EnvBase, # noqa
684+
return_contiguous: bool | None = None,
685685
check_dtype=True,
686686
seed: int | None = None,
687687
tensordict: TensorDictBase | None = None,
@@ -699,7 +699,7 @@ def check_env_specs(
699699
env (EnvBase): the env for which the specs have to be checked against data.
700700
return_contiguous (bool, optional): if ``True``, the random rollout will be called with
701701
return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes
702-
of inputs/outputs). Defaults to True.
702+
of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs).
703703
check_dtype (bool, optional): if False, dtype checks will be skipped.
704704
Defaults to True.
705705
seed (int, optional): for reproducibility, a seed can be set.
@@ -715,6 +715,8 @@ def check_env_specs(
715715
of an experiment and as such should be kept out of training scripts.
716716
717717
"""
718+
if return_contiguous is None:
719+
return_contiguous = not env._has_dynamic_specs
718720
if seed is not None:
719721
device = (
720722
env.device if env.device is not None and env.device.type == "cuda" else None
@@ -726,7 +728,7 @@ def check_env_specs(
726728
)
727729

728730
fake_tensordict = env.fake_tensordict()
729-
if not env._batch_locked and tensordict is not None:
731+
if not env.batch_locked and tensordict is not None:
730732
shape = torch.broadcast_shapes(fake_tensordict.shape, tensordict.shape)
731733
fake_tensordict = fake_tensordict.expand(shape)
732734
tensordict = tensordict.expand(shape)
@@ -765,10 +767,13 @@ def check_env_specs(
765767
- List of keys present in fake but not in real: {fake_tensordict_keys-real_tensordict_keys}.
766768
"""
767769
)
768-
zeroing_err_msg = (
769-
"zeroing the two tensordicts did not make them identical. "
770-
f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
771-
)
770+
771+
def zeroing_err_msg():
772+
return (
773+
"zeroing the two tensordicts did not make them identical. "
774+
f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
775+
)
776+
772777
from torchrl.envs.common import _has_dynamic_specs
773778

774779
if _has_dynamic_specs(env.specs):
@@ -778,7 +783,7 @@ def check_env_specs(
778783
):
779784
fake = fake.apply(lambda x, y: x.expand_as(y), real)
780785
if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
781-
raise AssertionError(zeroing_err_msg)
786+
raise AssertionError(zeroing_err_msg())
782787

783788
# Checks shapes and eventually dtypes of keys at all nesting levels
784789
_per_level_env_check(fake, real, check_dtype=check_dtype)
@@ -788,7 +793,7 @@ def check_env_specs(
788793
torch.zeros_like(fake_tensordict_select)
789794
!= torch.zeros_like(real_tensordict_select)
790795
).any():
791-
raise AssertionError(zeroing_err_msg)
796+
raise AssertionError(zeroing_err_msg())
792797

793798
# Checks shapes and eventually dtypes of keys at all nesting levels
794799
_per_level_env_check(
@@ -1009,14 +1014,14 @@ class MarlGroupMapType(Enum):
10091014
ALL_IN_ONE_GROUP = 1
10101015
ONE_GROUP_PER_AGENT = 2
10111016

1012-
def get_group_map(self, agent_names: List[str]):
1017+
def get_group_map(self, agent_names: list[str]):
10131018
if self == MarlGroupMapType.ALL_IN_ONE_GROUP:
10141019
return {"agents": agent_names}
10151020
elif self == MarlGroupMapType.ONE_GROUP_PER_AGENT:
10161021
return {agent_name: [agent_name] for agent_name in agent_names}
10171022

10181023

1019-
def check_marl_grouping(group_map: Dict[str, List[str]], agent_names: List[str]):
1024+
def check_marl_grouping(group_map: dict[str, list[str]], agent_names: list[str]):
10201025
"""Check MARL group map.
10211026
10221027
Performs checks on the group map of a marl environment to assess its validity.
@@ -1360,7 +1365,7 @@ def skim_through(td, reset=reset):
13601365
def _update_during_reset(
13611366
tensordict_reset: TensorDictBase,
13621367
tensordict: TensorDictBase,
1363-
reset_keys: List[NestedKey],
1368+
reset_keys: list[NestedKey],
13641369
):
13651370
"""Updates the input tensordict with the reset data, based on the reset keys."""
13661371
if not reset_keys:

0 commit comments

Comments
 (0)