Skip to content

Commit 60d8b1a

Browse files
mikaylagawarecki and vmoens
authored and committed
[Feature] ConditionalPolicySwitch transform
ghstack-source-id: a68e46a Pull Request resolved: #2711
1 parent 6ae8d43 commit 60d8b1a

File tree

7 files changed

+447
-7
lines changed

7 files changed

+447
-7
lines changed

docs/source/reference/envs.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,7 @@ to be able to create this other composition:
11121112
CenterCrop
11131113
ClipTransform
11141114
Compose
1115+
ConditionalPolicySwitch
11151116
ConditionalSkip
11161117
Crop
11171118
DataLoadingPrimer

test/test_transforms.py

+201-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import tensordict.tensordict
2626
import torch
27+
2728
from tensordict import (
2829
assert_close,
2930
LazyStackedTensorDict,
@@ -33,7 +34,7 @@
3334
TensorDictBase,
3435
unravel_key,
3536
)
36-
from tensordict.nn import TensorDictSequential
37+
from tensordict.nn import TensorDictSequential, WrapModule
3738
from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
3839
from torch import multiprocessing as mp, nn, Tensor
3940
from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
@@ -62,6 +63,7 @@
6263
CenterCrop,
6364
ClipTransform,
6465
Compose,
66+
ConditionalPolicySwitch,
6567
ConditionalSkip,
6668
Crop,
6769
DeviceCastTransform,
@@ -14526,6 +14528,204 @@ def test_can_init_with_fps(self):
1452614528
assert recorder is not None
1452714529

1452814530

14531+
class TestConditionalPolicySwitch(TransformBase):
14532+
def test_single_trans_env_check(self):
14533+
base_env = CountingEnv(max_steps=15)
14534+
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
14535+
# Player 0
14536+
policy_odd = lambda td: td.set("action", env.action_spec.zero())
14537+
policy_even = lambda td: td.set("action", env.action_spec.one())
14538+
transforms = Compose(
14539+
StepCounter(),
14540+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
14541+
)
14542+
env = base_env.append_transform(transforms)
14543+
env.check_env_specs()
14544+
14545+
def _create_policy_odd(self, base_env):
14546+
return WrapModule(
14547+
lambda td, base_env=base_env: td.set(
14548+
"action", base_env.action_spec_unbatched.zero(td.shape)
14549+
),
14550+
out_keys=["action"],
14551+
)
14552+
14553+
def _create_policy_even(self, base_env):
14554+
return WrapModule(
14555+
lambda td, base_env=base_env: td.set(
14556+
"action", base_env.action_spec_unbatched.one(td.shape)
14557+
),
14558+
out_keys=["action"],
14559+
)
14560+
14561+
def _create_transforms(self, condition, policy_even):
14562+
return Compose(
14563+
StepCounter(),
14564+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
14565+
)
14566+
14567+
def _make_env(self, max_count, env_cls):
14568+
torch.manual_seed(0)
14569+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
14570+
base_env = env_cls(max_steps=max_count)
14571+
policy_even = self._create_policy_even(base_env)
14572+
transforms = self._create_transforms(condition, policy_even)
14573+
return base_env.append_transform(transforms)
14574+
14575+
def _test_env(self, env, policy_odd):
14576+
env.check_env_specs()
14577+
env.set_seed(0)
14578+
r = env.rollout(100, policy_odd, break_when_any_done=False)
14579+
# Check results are independent: one reset / step in one env should not impact results in another
14580+
r0, r1, r2 = r.unbind(0)
14581+
r0_split = r0.split(6)
14582+
assert all((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:])
14583+
r1_split = r1.split(7)
14584+
assert all((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:])
14585+
r2_split = r2.split(8)
14586+
assert all((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:])
14587+
14588+
def test_trans_serial_env_check(self):
14589+
torch.manual_seed(0)
14590+
base_env = SerialEnv(
14591+
3,
14592+
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
14593+
)
14594+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
14595+
policy_odd = self._create_policy_odd(base_env)
14596+
policy_even = self._create_policy_even(base_env)
14597+
transforms = self._create_transforms(condition, policy_even)
14598+
env = base_env.append_transform(transforms)
14599+
self._test_env(env, policy_odd)
14600+
14601+
def test_trans_parallel_env_check(self):
14602+
torch.manual_seed(0)
14603+
base_env = ParallelEnv(
14604+
3,
14605+
[partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
14606+
mp_start_method=mp_ctx,
14607+
)
14608+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
14609+
policy_odd = self._create_policy_odd(base_env)
14610+
policy_even = self._create_policy_even(base_env)
14611+
transforms = self._create_transforms(condition, policy_even)
14612+
env = base_env.append_transform(transforms)
14613+
self._test_env(env, policy_odd)
14614+
14615+
def test_serial_trans_env_check(self):
14616+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
14617+
policy_odd = self._create_policy_odd(CountingEnv())
14618+
14619+
def make_env(max_count):
14620+
return partial(self._make_env, max_count, CountingEnv)
14621+
14622+
env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
14623+
self._test_env(env, policy_odd)
14624+
14625+
def test_parallel_trans_env_check(self):
14626+
condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
14627+
policy_odd = self._create_policy_odd(CountingEnv())
14628+
14629+
def make_env(max_count):
14630+
return partial(self._make_env, max_count, CountingEnv)
14631+
14632+
env = ParallelEnv(
14633+
3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
14634+
)
14635+
self._test_env(env, policy_odd)
14636+
14637+
def test_transform_no_env(self):
14638+
policy_odd = lambda td: td
14639+
policy_even = lambda td: td
14640+
condition = lambda td: True
14641+
transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
14642+
with pytest.raises(
14643+
RuntimeError,
14644+
match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
14645+
):
14646+
transforms(TensorDict())
14647+
14648+
def test_transform_compose(self):
14649+
policy_odd = lambda td: td
14650+
policy_even = lambda td: td
14651+
condition = lambda td: True
14652+
transforms = Compose(
14653+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
14654+
)
14655+
with pytest.raises(
14656+
RuntimeError,
14657+
match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
14658+
):
14659+
transforms(TensorDict())
14660+
14661+
def test_transform_env(self):
14662+
base_env = CountingEnv(max_steps=15)
14663+
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
14664+
# Player 0
14665+
policy_odd = lambda td: td.set("action", env.action_spec.zero())
14666+
policy_even = lambda td: td.set("action", env.action_spec.one())
14667+
transforms = Compose(
14668+
StepCounter(),
14669+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
14670+
)
14671+
env = base_env.append_transform(transforms)
14672+
env.check_env_specs()
14673+
r = env.rollout(1000, policy_odd, break_when_all_done=True)
14674+
assert r.shape[0] == 15
14675+
assert (r["action"] == 0).all()
14676+
assert (
14677+
r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
14678+
).all()
14679+
assert r["next", "done"].any()
14680+
14681+
# Player 1
14682+
condition = lambda td: ((td.get("step_count") % 2) == 1).all()
14683+
transforms = Compose(
14684+
StepCounter(),
14685+
ConditionalPolicySwitch(condition=condition, policy=policy_odd),
14686+
)
14687+
env = base_env.append_transform(transforms)
14688+
r = env.rollout(1000, policy_even, break_when_all_done=True)
14689+
assert r.shape[0] == 16
14690+
assert (r["action"] == 1).all()
14691+
assert (
14692+
r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
14693+
).all()
14694+
assert r["next", "done"].any()
14695+
14696+
def test_transform_model(self):
14697+
policy_odd = lambda td: td
14698+
policy_even = lambda td: td
14699+
condition = lambda td: True
14700+
transforms = nn.Sequential(
14701+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
14702+
)
14703+
with pytest.raises(
14704+
RuntimeError,
14705+
match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
14706+
):
14707+
transforms(TensorDict())
14708+
14709+
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
14710+
def test_transform_rb(self, rbclass):
14711+
policy_odd = lambda td: td
14712+
policy_even = lambda td: td
14713+
condition = lambda td: True
14714+
rb = rbclass(storage=LazyTensorStorage(10))
14715+
rb.append_transform(
14716+
ConditionalPolicySwitch(condition=condition, policy=policy_even)
14717+
)
14718+
rb.extend(TensorDict(batch_size=[2]))
14719+
with pytest.raises(
14720+
RuntimeError,
14721+
match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
14722+
):
14723+
rb.sample(2)
14724+
14725+
def test_transform_inverse(self):
14726+
return
14727+
14728+
1452914729
if __name__ == "__main__":
1453014730
args, unknown = argparse.ArgumentParser().parse_known_args()
1453114731
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
CenterCrop,
5959
ClipTransform,
6060
Compose,
61+
ConditionalPolicySwitch,
6162
ConditionalSkip,
6263
Crop,
6364
DeviceCastTransform,
@@ -137,6 +138,7 @@
137138
"AutoResetTransform",
138139
"AsyncEnvPool",
139140
"ProcessorAsyncEnvPool",
141+
"ConditionalPolicySwitch",
140142
"ThreadingAsyncEnvPool",
141143
"BatchSizeTransform",
142144
"BinarizeReward",

torchrl/envs/batched_envs.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def __init__(
350350

351351
# if share_individual_td is None, we will assess later if the output can be stacked
352352
self.share_individual_td = share_individual_td
353+
# self._batch_locked = batch_locked
353354
self._share_memory = shared_memory
354355
self._memmap = memmap
355356
self.allow_step_when_done = allow_step_when_done
@@ -626,8 +627,8 @@ def map_device(key, value, device_map=device_map):
626627
self._env_tensordict.named_apply(
627628
map_device, nested_keys=True, filter_empty=True
628629
)
629-
630-
self._batch_locked = meta_data.batch_locked
630+
# if self._batch_locked is None:
631+
# self._batch_locked = meta_data.batch_locked
631632
else:
632633
self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
633634
devices = set()
@@ -668,7 +669,8 @@ def map_device(key, value, device_map=device_map):
668669
self._env_tensordict = torch.stack(
669670
[meta_data.tensordict for meta_data in meta_data], 0
670671
)
671-
self._batch_locked = meta_data[0].batch_locked
672+
# if self._batch_locked is None:
673+
# self._batch_locked = meta_data[0].batch_locked
672674
self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
673675

674676
def state_dict(self) -> OrderedDict:

torchrl/envs/custom/chess.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
195195
batch_size=torch.Size([96]),
196196
device=None,
197197
is_shared=False)
198-
199-
200-
"""
198+
""" # noqa: D301
201199

202200
_hash_table: dict[int, str] = {}
203201
_PGN_RESTART = """[Event "?"]

torchrl/envs/transforms/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
CenterCrop,
2121
ClipTransform,
2222
Compose,
23+
ConditionalPolicySwitch,
2324
ConditionalSkip,
2425
Crop,
2526
DeviceCastTransform,
@@ -83,6 +84,7 @@
8384
"CatFrames",
8485
"CatTensors",
8586
"CenterCrop",
87+
"ConditionalPolicySwitch",
8688
"ClipTransform",
8789
"Compose",
8890
"ConditionalSkip",

0 commit comments

Comments
 (0)