|
24 | 24 |
|
25 | 25 | import tensordict.tensordict
|
26 | 26 | import torch
|
| 27 | + |
27 | 28 | from tensordict import (
|
28 | 29 | assert_close,
|
29 | 30 | LazyStackedTensorDict,
|
|
33 | 34 | TensorDictBase,
|
34 | 35 | unravel_key,
|
35 | 36 | )
|
36 |
| -from tensordict.nn import TensorDictSequential |
| 37 | +from tensordict.nn import TensorDictSequential, WrapModule |
37 | 38 | from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
|
38 | 39 | from torch import multiprocessing as mp, nn, Tensor
|
39 | 40 | from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
|
|
62 | 63 | CenterCrop,
|
63 | 64 | ClipTransform,
|
64 | 65 | Compose,
|
| 66 | + ConditionalPolicySwitch, |
65 | 67 | ConditionalSkip,
|
66 | 68 | Crop,
|
67 | 69 | DeviceCastTransform,
|
@@ -14526,6 +14528,206 @@ def test_can_init_with_fps(self):
|
14526 | 14528 | assert recorder is not None
|
14527 | 14529 |
|
14528 | 14530 |
|
class TestConditionalPolicySwitch(TransformBase):
    """Tests for ``ConditionalPolicySwitch``.

    The transform intercepts env steps for which ``condition(td)`` holds and
    runs an alternate policy on them, so two policies can alternate turns
    within a single rollout (a two-player "even/odd step" game here).
    """

    def test_single_trans_env_check(self):
        base_env = CountingEnv(max_steps=15)
        # Hand control to the switched-in policy whenever step_count is even.
        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
        # Player 1: always plays the "one" action. The lambda closes over
        # ``env`` (late binding), which is assigned below before any call.
        policy_even = lambda td: td.set("action", env.action_spec.one())
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        env = base_env.append_transform(transforms)
        env.check_env_specs()

    def _create_condition(self):
        # Per-batch-element predicate: True when step_count is even.
        return lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)

    def _create_policy_odd(self, base_env):
        # Player 0: always emits the "zero" action for the td's batch shape.
        return WrapModule(
            lambda td, base_env=base_env: td.set(
                "action", base_env.action_spec_unbatched.zero(td.shape)
            ),
            out_keys=["action"],
        )

    def _create_policy_even(self, base_env):
        # Player 1: always emits the "one" action for the td's batch shape.
        return WrapModule(
            lambda td, base_env=base_env: td.set(
                "action", base_env.action_spec_unbatched.one(td.shape)
            ),
            out_keys=["action"],
        )

    def _create_transforms(self, condition, policy_even):
        # StepCounter must precede the switch so "step_count" is available
        # when ``condition`` is evaluated.
        return Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )

    def _make_env(self, max_count, env_cls):
        torch.manual_seed(0)
        condition = self._create_condition()
        base_env = env_cls(max_steps=max_count)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        return base_env.append_transform(transforms)

    def _test_env(self, env, policy_odd):
        env.check_env_specs()
        env.set_seed(0)
        r = env.rollout(100, policy_odd, break_when_any_done=False)
        # Check results are independent: one reset / step in one env should
        # not impact results in another. Each sub-env's trajectory should
        # repeat with its own period (6 / 7 / 8 steps per episode here).
        r0, r1, r2 = r.unbind(0)
        r0_split = r0.split(6)
        assert all((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:])
        r1_split = r1.split(7)
        assert all((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:])
        r2_split = r2.split(8)
        assert all((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:])

    def test_trans_serial_env_check(self):
        torch.manual_seed(0)
        base_env = SerialEnv(
            3,
            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
            batch_locked=False,
        )
        condition = self._create_condition()
        policy_odd = self._create_policy_odd(base_env)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        env = base_env.append_transform(transforms)
        self._test_env(env, policy_odd)

    def test_trans_parallel_env_check(self):
        torch.manual_seed(0)
        base_env = ParallelEnv(
            3,
            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
            batch_locked=False,
            mp_start_method=mp_ctx,
        )
        condition = self._create_condition()
        policy_odd = self._create_policy_odd(base_env)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        env = base_env.append_transform(transforms)
        self._test_env(env, policy_odd)

    def test_serial_trans_env_check(self):
        # The per-worker envs built by _make_env carry their own transforms.
        policy_odd = self._create_policy_odd(CountingEnv())

        def make_env(max_count):
            return partial(self._make_env, max_count, CountingEnv)

        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
        self._test_env(env, policy_odd)

    def test_parallel_trans_env_check(self):
        policy_odd = self._create_policy_odd(CountingEnv())

        def make_env(max_count):
            return partial(self._make_env, max_count, CountingEnv)

        env = ParallelEnv(
            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
        )
        self._test_env(env, policy_odd)

    def test_transform_no_env(self):
        # The transform only works through an env's step/reset machinery;
        # calling it directly must raise.
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    def test_transform_compose(self):
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = Compose(
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    def test_transform_env(self):
        base_env = CountingEnv(max_steps=15)
        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
        # Player 0 drives the rollout; player 1 is injected by the switch.
        policy_odd = lambda td: td.set("action", env.action_spec.zero())
        policy_even = lambda td: td.set("action", env.action_spec.one())
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        env = base_env.append_transform(transforms)
        env.check_env_specs()
        r = env.rollout(1000, policy_odd, break_when_all_done=True)
        # Only player 0's (odd) steps are recorded in the rollout.
        assert r.shape[0] == 15
        assert (r["action"] == 0).all()
        assert (
            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
        ).all()
        assert r["next", "done"].any()

        # Player 1 drives the rollout; the switch injects player 0's moves.
        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
        )
        env = base_env.append_transform(transforms)
        r = env.rollout(1000, policy_even, break_when_all_done=True)
        assert r.shape[0] == 16
        assert (r["action"] == 1).all()
        assert (
            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
        ).all()
        assert r["next", "done"].any()

    def test_transform_model(self):
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = nn.Sequential(
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    def test_transform_rb(self, rbclass):
        policy_even = lambda td: td
        condition = lambda td: True
        rb = rbclass(storage=LazyTensorStorage(10))
        rb.append_transform(
            ConditionalPolicySwitch(condition=condition, policy=policy_even)
        )
        rb.extend(TensorDict(batch_size=[2]))
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            rb.sample(2)

    def test_transform_inverse(self):
        # ConditionalPolicySwitch defines no inverse transform; nothing to test.
        return
| 14730 | + |
if __name__ == "__main__":
    # Forward any CLI arguments we do not recognize straight to pytest.
    _, pytest_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *pytest_args])
|
0 commit comments