
Commit df6871a

mikaylagawarecki authored and vmoens committed
[Feature] ConditionalPolicySwitch transform
ghstack-source-id: a56a017
Pull Request resolved: #2711
1 parent 1813e8e commit df6871a

File tree: 5 files changed, +451 −7 lines


docs/source/reference/envs.rst (+1)

@@ -1112,6 +1112,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalPolicySwitch
     ConditionalSkip
     Crop
     DataLoadingPrimer
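For orientation before the test diff: the new transform takes a `condition` callable evaluated on the tensordict and, when it returns true, steps the environment with the alternate `policy` instead of the one driving the rollout. A minimal sketch distilled from `test_single_trans_env_check` below; the import paths are assumptions (`CountingEnv` is the test suite's helper env from `mocking_classes`, and `ConditionalPolicySwitch` is assumed to be exported from `torchrl.envs.transforms`):

# Hedged sketch, not part of the diff: alternate a fixed "even" policy
# with the caller's policy on a toy counting env.
from torchrl.envs.transforms import Compose, ConditionalPolicySwitch, StepCounter

from mocking_classes import CountingEnv  # assumed path: test-suite helper env

base_env = CountingEnv(max_steps=15)
transforms = Compose(
    StepCounter(),
    ConditionalPolicySwitch(
        # Fire the alternate policy whenever the step count is even.
        condition=lambda td: ((td.get("step_count") % 2) == 0).all(),
        # `env` is bound late: the name is only looked up when the lambda
        # runs, after env has been created below (same idiom as the tests).
        policy=lambda td: td.set("action", env.action_spec.one()),
    ),
)
env = base_env.append_transform(transforms)
env.check_env_specs()  # sanity-check the transformed env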

test/test_transforms.py (+203 −1)

@@ -24,6 +24,7 @@
 
 import tensordict.tensordict
 import torch
+
 from tensordict import (
     assert_close,
     LazyStackedTensorDict,
@@ -33,7 +34,7 @@
     TensorDictBase,
     unravel_key,
 )
-from tensordict.nn import TensorDictSequential
+from tensordict.nn import TensorDictSequential, WrapModule
 from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
 from torch import multiprocessing as mp, nn, Tensor
 from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
@@ -62,6 +63,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     ConditionalSkip,
     Crop,
     DeviceCastTransform,
@@ -14526,6 +14528,206 @@ def test_can_init_with_fps(self):
         assert recorder is not None
 
 
+class TestConditionalPolicySwitch(TransformBase):
+    def test_single_trans_env_check(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+
+    def _create_policy_odd(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.zero(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_policy_even(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.one(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_transforms(self, condition, policy_even):
+        return Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+
+    def _make_env(self, max_count, env_cls):
+        torch.manual_seed(0)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        base_env = env_cls(max_steps=max_count)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        return base_env.append_transform(transforms)
+
+    def _test_env(self, env, policy_odd):
+        env.check_env_specs()
+        env.set_seed(0)
+        r = env.rollout(100, policy_odd, break_when_any_done=False)
+        # Check results are independent: one reset / step in one env should not impact results in another
+        r0, r1, r2 = r.unbind(0)
+        r0_split = r0.split(6)
+        assert all((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:])
+        r1_split = r1.split(7)
+        assert all((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:])
+        r2_split = r2.split(8)
+        assert all((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:])
+
+    def test_trans_serial_env_check(self):
+        torch.manual_seed(0)
+        base_env = SerialEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_trans_parallel_env_check(self):
+        torch.manual_seed(0)
+        base_env = ParallelEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+            mp_start_method=mp_ctx,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_serial_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
+        self._test_env(env, policy_odd)
+
+    def test_parallel_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = ParallelEnv(
+            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
+        )
+        self._test_env(env, policy_odd)
+
+    def test_transform_no_env(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_compose(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = Compose(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_env(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+        r = env.rollout(1000, policy_odd, break_when_all_done=True)
+        assert r.shape[0] == 15
+        assert (r["action"] == 0).all()
+        assert (
+            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+        # Player 1
+        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_even, break_when_all_done=True)
+        assert r.shape[0] == 16
+        assert (r["action"] == 1).all()
+        assert (
+            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+    def test_transform_model(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = nn.Sequential(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        )
+        rb.extend(TensorDict(batch_size=[2]))
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            rb.sample(2)
+
+    def test_transform_inverse(self):
+        return
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
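A detail worth noting in the tests above: the single-env cases build their condition with `.all()`, a single scalar decision, while the batched `SerialEnv`/`ParallelEnv` cases use `.squeeze(-1)`, which yields one boolean per sub-env; this suggests the switch can fire for a subset of the batch when the env is not batch-locked. A small, self-contained illustration of the difference (plain tensors, no env needed):

import torch
from tensordict import TensorDict

# Three sub-envs at different step counts, shaped like StepCounter's output.
td = TensorDict({"step_count": torch.tensor([[0], [1], [2]])}, batch_size=[3])

# Scalar condition, as in the single-env tests: one decision for everything.
scalar = ((td.get("step_count") % 2) == 0).all()
print(scalar)  # tensor(False) -- not *all* counts are even

# Per-env mask, as in the SerialEnv/ParallelEnv tests: sub-envs 0 and 2
# would switch policies, sub-env 1 would not.
mask = ((td.get("step_count") % 2) == 0).squeeze(-1)
print(mask)  # tensor([ True, False,  True])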

torchrl/envs/batched_envs.py (+5 −3)

@@ -350,6 +350,7 @@ def __init__(
 
         # if share_individual_td is None, we will assess later if the output can be stacked
         self.share_individual_td = share_individual_td
+        # self._batch_locked = batch_locked
         self._share_memory = shared_memory
         self._memmap = memmap
         self.allow_step_when_done = allow_step_when_done

@@ -626,8 +627,8 @@ def map_device(key, value, device_map=device_map):
             self._env_tensordict.named_apply(
                 map_device, nested_keys=True, filter_empty=True
             )
-
-            self._batch_locked = meta_data.batch_locked
+            # if self._batch_locked is None:
+            #     self._batch_locked = meta_data.batch_locked
         else:
             self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
             devices = set()

@@ -668,7 +669,8 @@ def map_device(key, value, device_map=device_map):
             self._env_tensordict = torch.stack(
                 [meta_data.tensordict for meta_data in meta_data], 0
             )
-            self._batch_locked = meta_data[0].batch_locked
+            # if self._batch_locked is None:
+            #     self._batch_locked = meta_data[0].batch_locked
         self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
 
     def state_dict(self) -> OrderedDict:
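The commented-out assignments above stop batched envs from unconditionally inheriting `batch_locked` from worker metadata, which pairs with the explicit `batch_locked=False` argument the new tests pass to `SerialEnv`/`ParallelEnv`. A hedged usage sketch, with the keyword inferred from the tests rather than from a documented signature:

from functools import partial

from torchrl.envs import SerialEnv

from mocking_classes import CountingEnv  # assumed path: test-suite helper env

# batch_locked=False (exercised by the new tests) unlocks the batched env,
# presumably so a transform like ConditionalPolicySwitch can step only the
# sub-envs whose condition fired.
env = SerialEnv(
    3,
    [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
    batch_locked=False,
)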

torchrl/envs/custom/chess.py (+1 −3)

@@ -195,9 +195,7 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
         batch_size=torch.Size([96]),
         device=None,
         is_shared=False)
-
-
-    """
+    """ # noqa: D301
 
     _hash_table: dict[int, str] = {}
     _PGN_RESTART = """[Event "?"]
