|
58 | 58 | CenterCrop,
|
59 | 59 | ClipTransform,
|
60 | 60 | Compose,
|
| 61 | + ConditionalSkip, |
61 | 62 | Crop,
|
62 | 63 | DeviceCastTransform,
|
63 | 64 | DiscreteActionProjection,
|
@@ -13451,6 +13452,159 @@ def test_composite_reward_spec(self) -> None:
|
13451 | 13452 | assert transform.transform_reward_spec(reward_spec) == expected_reward_spec
|
13452 | 13453 |
|
13453 | 13454 |
|
| 13455 | +class TestConditionalSkip(TransformBase): |
| 13456 | + @pytest.mark.parametrize("bwad", [False, True]) |
| 13457 | + def test_single_trans_env_check(self, bwad): |
| 13458 | + env = CountingEnv() |
| 13459 | + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) |
| 13460 | + env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False) |
| 13461 | + env = env.append_transform( |
| 13462 | + ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1) |
| 13463 | + ) |
| 13464 | + env.set_seed(0) |
| 13465 | + env.check_env_specs() |
| 13466 | + policy = lambda td: td.set("action", torch.ones((1,))) |
| 13467 | + r = env.rollout(10, policy, break_when_any_done=bwad) |
| 13468 | + assert (r["step_count"] == torch.arange(10).view(10, 1)).all() |
| 13469 | + assert (r["other_count"] == torch.arange(1, 11).view(10, 1) // 2).all() |
| 13470 | + |
| 13471 | + @pytest.mark.parametrize("bwad", [False, True]) |
| 13472 | + def test_serial_trans_env_check(self, bwad): |
| 13473 | + def make_env(i): |
| 13474 | + env = CountingEnv() |
| 13475 | + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) |
| 13476 | + return TransformedEnv( |
| 13477 | + base_env, |
| 13478 | + Compose( |
| 13479 | + StepCounter(), |
| 13480 | + ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)), |
| 13481 | + ), |
| 13482 | + auto_unwrap=False, |
| 13483 | + ) |
| 13484 | + |
| 13485 | + env = SerialEnv(2, [partial(make_env, i=0), partial(make_env, i=1)]) |
| 13486 | + env.check_env_specs() |
| 13487 | + policy = lambda td: td.set("action", torch.ones((2, 1))) |
| 13488 | + r = env.rollout(10, policy, break_when_any_done=bwad) |
| 13489 | + assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all() |
| 13490 | + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() |
| 13491 | + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() |
| 13492 | + |
| 13493 | + @pytest.mark.parametrize("bwad", [False, True]) |
| 13494 | + def test_parallel_trans_env_check(self, bwad): |
| 13495 | + def make_env(i): |
| 13496 | + env = CountingEnv() |
| 13497 | + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) |
| 13498 | + return TransformedEnv( |
| 13499 | + base_env, |
| 13500 | + Compose( |
| 13501 | + StepCounter(), |
| 13502 | + ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)), |
| 13503 | + ), |
| 13504 | + auto_unwrap=False, |
| 13505 | + ) |
| 13506 | + |
| 13507 | + env = ParallelEnv( |
| 13508 | + 2, [partial(make_env, i=0), partial(make_env, i=1)], mp_start_method=mp_ctx |
| 13509 | + ) |
| 13510 | + try: |
| 13511 | + env.check_env_specs() |
| 13512 | + policy = lambda td: td.set("action", torch.ones((2, 1))) |
| 13513 | + r = env.rollout(10, policy, break_when_any_done=bwad) |
| 13514 | + assert ( |
| 13515 | + r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1) |
| 13516 | + ).all() |
| 13517 | + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() |
| 13518 | + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() |
| 13519 | + finally: |
| 13520 | + env.close() |
| 13521 | + del env |
| 13522 | + |
| 13523 | + @pytest.mark.parametrize("bwad", [False, True]) |
| 13524 | + def test_trans_serial_env_check(self, bwad): |
| 13525 | + def make_env(): |
| 13526 | + env = CountingEnv(max_steps=100) |
| 13527 | + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) |
| 13528 | + return base_env |
| 13529 | + |
| 13530 | + base_env = SerialEnv(2, [make_env, make_env]) |
| 13531 | + |
| 13532 | + def cond(td): |
| 13533 | + sc = td["step_count"] + torch.tensor([[0], [1]]) |
| 13534 | + return sc.squeeze() % 2 == 0 |
| 13535 | + |
| 13536 | + env = TransformedEnv(base_env, Compose(StepCounter(), ConditionalSkip(cond))) |
| 13537 | + env.check_env_specs() |
| 13538 | + policy = lambda td: td.set("action", torch.ones((2, 1))) |
| 13539 | + r = env.rollout(10, policy, break_when_any_done=bwad) |
| 13540 | + assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all() |
| 13541 | + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() |
| 13542 | + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() |
| 13543 | + |
| 13544 | + @pytest.mark.parametrize("bwad", [True, False]) |
| 13545 | + @pytest.mark.parametrize("buffers", [True, False]) |
| 13546 | + def test_trans_parallel_env_check(self, bwad, buffers): |
| 13547 | + def make_env(): |
| 13548 | + env = CountingEnv(max_steps=100) |
| 13549 | + base_env = TransformedEnv(env, StepCounter(step_count_key="other_count")) |
| 13550 | + return base_env |
| 13551 | + |
| 13552 | + base_env = ParallelEnv( |
| 13553 | + 2, [make_env, make_env], mp_start_method=mp_ctx, use_buffers=buffers |
| 13554 | + ) |
| 13555 | + try: |
| 13556 | + |
| 13557 | + def cond(td): |
| 13558 | + sc = td["step_count"] + torch.tensor([[0], [1]]) |
| 13559 | + return sc.squeeze() % 2 == 0 |
| 13560 | + |
| 13561 | + env = TransformedEnv( |
| 13562 | + base_env, Compose(StepCounter(), ConditionalSkip(cond)) |
| 13563 | + ) |
| 13564 | + env.check_env_specs() |
| 13565 | + policy = lambda td: td.set("action", torch.ones((2, 1))) |
| 13566 | + r = env.rollout(10, policy, break_when_any_done=bwad) |
| 13567 | + assert ( |
| 13568 | + r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1) |
| 13569 | + ).all() |
| 13570 | + assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all() |
| 13571 | + assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all() |
| 13572 | + finally: |
| 13573 | + base_env.close() |
| 13574 | + del base_env |
| 13575 | + |
| 13576 | + def test_transform_no_env(self): |
| 13577 | + t = ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) |
| 13578 | + assert not t._inv_call(TensorDict())["_step"] |
| 13579 | + assert t._inv_call(TensorDict())["_step"].shape == () |
| 13580 | + assert t._inv_call(TensorDict(batch_size=(2, 3)))["_step"].shape == (2, 3) |
| 13581 | + |
| 13582 | + def test_transform_compose(self): |
| 13583 | + t = Compose( |
| 13584 | + ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) |
| 13585 | + ) |
| 13586 | + assert not t._inv_call(TensorDict())["_step"] |
| 13587 | + assert t._inv_call(TensorDict())["_step"].shape == () |
| 13588 | + assert t._inv_call(TensorDict(batch_size=(2, 3)))["_step"].shape == (2, 3) |
| 13589 | + |
| 13590 | + def test_transform_env(self): |
| 13591 | + # tested above |
| 13592 | + return |
| 13593 | + |
| 13594 | + def test_transform_model(self): |
| 13595 | + t = Compose( |
| 13596 | + ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0) |
| 13597 | + ) |
| 13598 | + with pytest.raises(NotImplementedError): |
| 13599 | + t(TensorDict())["_step"] |
| 13600 | + |
| 13601 | + def test_transform_rb(self): |
| 13602 | + return |
| 13603 | + |
| 13604 | + def test_transform_inverse(self): |
| 13605 | + return |
| 13606 | + |
| 13607 | + |
13454 | 13608 | if __name__ == "__main__":
|
13455 | 13609 | args, unknown = argparse.ArgumentParser().parse_known_args()
|
13456 | 13610 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
|
0 commit comments