|
77 | 77 | Hash,
|
78 | 78 | InitTracker,
|
79 | 79 | LineariseRewards,
|
| 80 | + MultiAction, |
80 | 81 | MultiStepTransform,
|
81 | 82 | NoopResetEnv,
|
82 | 83 | ObservationNorm,
|
|
156 | 157 | MultiKeyCountingEnv,
|
157 | 158 | MultiKeyCountingEnvPolicy,
|
158 | 159 | NestedCountingEnv,
|
| 160 | + StateLessCountingEnv, |
159 | 161 | )
|
160 | 162 | else:
|
161 | 163 | from _utils_internal import ( # noqa
|
|
184 | 186 | MultiKeyCountingEnv,
|
185 | 187 | MultiKeyCountingEnvPolicy,
|
186 | 188 | NestedCountingEnv,
|
| 189 | + StateLessCountingEnv, |
187 | 190 | )
|
188 | 191 |
|
189 | 192 | IS_WIN = platform == "win32"
|
@@ -13664,6 +13667,198 @@ def test_transform_inverse(self):
|
13664 | 13667 | return
|
13665 | 13668 |
|
13666 | 13669 |
|
| 13670 | +class TestMultiAction(TransformBase): |
| 13671 | + @pytest.mark.parametrize("bwad", [False, True]) |
| 13672 | + def test_single_trans_env_check(self, bwad): |
| 13673 | + base_env = CountingEnv(max_steps=10) |
| 13674 | + env = TransformedEnv( |
| 13675 | + base_env, |
| 13676 | + Compose( |
| 13677 | + StepCounter(step_count_key="before_count"), |
| 13678 | + MultiAction(), |
| 13679 | + StepCounter(step_count_key="after_count"), |
| 13680 | + ), |
| 13681 | + ) |
| 13682 | + env.check_env_specs() |
| 13683 | + |
| 13684 | + def policy(td): |
| 13685 | + # 3 action per step |
| 13686 | + td["action"] = torch.ones(3, 1) |
| 13687 | + return td |
| 13688 | + |
| 13689 | + r = env.rollout(10, policy) |
| 13690 | + assert r["action"].shape == (4, 3, 1) |
| 13691 | + assert r["next", "done"].any() |
| 13692 | + assert r["next", "done"][-1].all() |
| 13693 | + assert (r["observation"][0] == 0).all() |
| 13694 | + assert (r["next", "observation"][0] == 3).all() |
| 13695 | + assert (r["next", "observation"][-1] == 11).all() |
| 13696 | + # Check that before_count is incremented but not after_count |
| 13697 | + assert r["before_count"].max() == 9 |
| 13698 | + assert r["after_count"].max() == 3 |
| 13699 | + |
| 13700 | + def _batched_trans_env_check(self, cls, bwad, within): |
| 13701 | + if within: |
| 13702 | + |
| 13703 | + def make_env(i): |
| 13704 | + base_env = CountingEnv(max_steps=i) |
| 13705 | + env = TransformedEnv( |
| 13706 | + base_env, |
| 13707 | + Compose( |
| 13708 | + StepCounter(step_count_key="before_count"), |
| 13709 | + MultiAction(), |
| 13710 | + StepCounter(step_count_key="after_count"), |
| 13711 | + ), |
| 13712 | + ) |
| 13713 | + return env |
| 13714 | + |
| 13715 | + env = cls(2, [partial(make_env, i=10), partial(make_env, i=20)]) |
| 13716 | + else: |
| 13717 | + base_env = cls( |
| 13718 | + 2, |
| 13719 | + [ |
| 13720 | + partial(CountingEnv, max_steps=10), |
| 13721 | + partial(CountingEnv, max_steps=20), |
| 13722 | + ], |
| 13723 | + ) |
| 13724 | + env = TransformedEnv( |
| 13725 | + base_env, |
| 13726 | + Compose( |
| 13727 | + StepCounter(step_count_key="before_count"), |
| 13728 | + MultiAction(), |
| 13729 | + StepCounter(step_count_key="after_count"), |
| 13730 | + ), |
| 13731 | + ) |
| 13732 | + |
| 13733 | + try: |
| 13734 | + env.check_env_specs() |
| 13735 | + |
| 13736 | + def policy(td): |
| 13737 | + # 3 action per step |
| 13738 | + td["action"] = torch.ones(2, 3, 1) |
| 13739 | + return td |
| 13740 | + |
| 13741 | + r = env.rollout(10, policy, break_when_any_done=bwad) |
| 13742 | + # r0 |
| 13743 | + r0 = r[0] |
| 13744 | + if bwad: |
| 13745 | + assert r["action"].shape == (2, 4, 3, 1) |
| 13746 | + else: |
| 13747 | + assert r["action"].shape == (2, 10, 3, 1) |
| 13748 | + assert r0["next", "done"].any() |
| 13749 | + if bwad: |
| 13750 | + assert r0["next", "done"][-1].all() |
| 13751 | + else: |
| 13752 | + assert r0["next", "done"].sum() == 2 |
| 13753 | + |
| 13754 | + assert (r0["observation"][0] == 0).all() |
| 13755 | + assert (r0["next", "observation"][0] == 3).all() |
| 13756 | + if bwad: |
| 13757 | + assert (r0["next", "observation"][-1] == 11).all() |
| 13758 | + else: |
| 13759 | + assert (r0["next", "observation"][-1] == 6).all(), r0[ |
| 13760 | + "next", "observation" |
| 13761 | + ] |
| 13762 | + # Check that before_count is incremented but not after_count |
| 13763 | + assert r0["before_count"].max() == 9 |
| 13764 | + assert r0["after_count"].max() == 3 |
| 13765 | + # r1 |
| 13766 | + r1 = r[1] |
| 13767 | + if bwad: |
| 13768 | + assert not r1["next", "done"].any() |
| 13769 | + else: |
| 13770 | + assert r1["next", "done"].any() |
| 13771 | + assert r1["next", "done"].sum() == 1 |
| 13772 | + assert (r1["observation"][0] == 0).all() |
| 13773 | + assert (r1["next", "observation"][0] == 3).all() |
| 13774 | + if bwad: |
| 13775 | + # r0 cannot go above 11 but r1 can - so we see a 12 because one more step was done |
| 13776 | + assert (r1["next", "observation"][-1] == 12).all() |
| 13777 | + else: |
| 13778 | + assert (r1["next", "observation"][-1] == 9).all() |
| 13779 | + # Check that before_count is incremented but not after_count |
| 13780 | + if bwad: |
| 13781 | + assert r1["before_count"].max() == 9 |
| 13782 | + assert r1["after_count"].max() == 3 |
| 13783 | + else: |
| 13784 | + assert r1["before_count"].max() == 18 |
| 13785 | + assert r1["after_count"].max() == 6 |
| 13786 | + finally: |
| 13787 | + env.close() |
| 13788 | + |
| 13789 | + @pytest.mark.parametrize("bwad", [False, True]) |
| 13790 | + def test_serial_trans_env_check(self, bwad): |
| 13791 | + self._batched_trans_env_check(SerialEnv, bwad, within=True) |
| 13792 | + |
| 13793 | + @pytest.mark.parametrize("bwad", [False, True]) |
| 13794 | + def test_parallel_trans_env_check(self, bwad): |
| 13795 | + self._batched_trans_env_check( |
| 13796 | + partial(ParallelEnv, mp_start_method=mp_ctx), bwad, within=True |
| 13797 | + ) |
| 13798 | + |
| 13799 | + @pytest.mark.parametrize("bwad", [False, True]) |
| 13800 | + def test_trans_serial_env_check(self, bwad): |
| 13801 | + self._batched_trans_env_check(SerialEnv, bwad, within=False) |
| 13802 | + |
| 13803 | + @pytest.mark.parametrize("bwad", [True, False]) |
| 13804 | + @pytest.mark.parametrize("buffers", [True, False]) |
| 13805 | + def test_trans_parallel_env_check(self, bwad, buffers): |
| 13806 | + self._batched_trans_env_check( |
| 13807 | + partial(ParallelEnv, use_buffers=buffers, mp_start_method=mp_ctx), |
| 13808 | + bwad, |
| 13809 | + within=False, |
| 13810 | + ) |
| 13811 | + |
| 13812 | + def test_transform_no_env(self): |
| 13813 | + ... |
| 13814 | + |
| 13815 | + def test_transform_compose(self): |
| 13816 | + ... |
| 13817 | + |
| 13818 | + @pytest.mark.parametrize("bwad", [True, False]) |
| 13819 | + def test_transform_env(self, bwad): |
| 13820 | + # tests stateless (batch-unlocked) envs |
| 13821 | + torch.manual_seed(0) |
| 13822 | + env = StateLessCountingEnv() |
| 13823 | + |
| 13824 | + def policy(td): |
| 13825 | + td["action"] = torch.ones(td.shape + (1,)) |
| 13826 | + return td |
| 13827 | + |
| 13828 | + r = env.rollout( |
| 13829 | + 10, |
| 13830 | + tensordict=env.reset().expand(4), |
| 13831 | + auto_reset=False, |
| 13832 | + break_when_any_done=False, |
| 13833 | + policy=policy, |
| 13834 | + ) |
| 13835 | + assert (r["count"] == torch.arange(10).expand(4, 10).view(4, 10, 1)).all() |
| 13836 | + td_reset = env.reset().expand(4).clone() |
| 13837 | + td_reset["max_count"] = torch.arange(4, 8).view(4, 1) |
| 13838 | + env = TransformedEnv(env, MultiAction()) |
| 13839 | + |
| 13840 | + def policy(td): |
| 13841 | + td["action"] = torch.ones(td.shape + (3,) + (1,)) |
| 13842 | + return td |
| 13843 | + |
| 13844 | + r = env.rollout( |
| 13845 | + 20, |
| 13846 | + policy=policy, |
| 13847 | + auto_reset=False, |
| 13848 | + tensordict=td_reset, |
| 13849 | + break_when_any_done=bwad, |
| 13850 | + ) |
| 13851 | + |
| 13852 | + def test_transform_model(self): |
| 13853 | + ... |
| 13854 | + |
| 13855 | + def test_transform_rb(self): |
| 13856 | + return |
| 13857 | + |
| 13858 | + def test_transform_inverse(self): |
| 13859 | + return |
| 13860 | + |
| 13861 | + |
13667 | 13862 | if __name__ == "__main__":
|
13668 | 13863 | args, unknown = argparse.ArgumentParser().parse_known_args()
|
13669 | 13864 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
|
0 commit comments