     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     TransformedEnv,
@@ -2499,7 +2500,223 @@ def test_transform_rb(self, rbclass):
         assert ("next", "observation") in td.keys(True)
 
     def test_transform_inverse(self):
-        raise pytest.skip("No inverse for Hash")
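+        # Hash now supports the inverse path: with in_keys_inv/out_keys_inv set,
+        # the transformed env exposes "action_hash" as its action key and the
+        # rollout carries the hashed actions as tensors.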
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)
+
+
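+# Tests for the new Tokenizer transform: string entries in the tensordict are
+# encoded to token tensors, and the inverse pass decodes tokens back to strings.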
+class TestTokenizer(TransformBase):
+    @pytest.mark.parametrize("datatype", ["str", "NonTensorStack"])
+    def test_transform_no_env(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+        elif datatype == "NonTensorStack":
+            obs = torch.stack(
+                [
+                    NonTensorData(data="abcde"),
+                    NonTensorData(data="fghij"),
+                    NonTensorData(data="klmno"),
+                ]
+            )
+        else:
+            raise RuntimeError(f"please add a test case for datatype {datatype}")
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+
+        t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
+        td_tokenized = t(td)
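+        # Round-trip: a second Tokenizer configured on the inverse keys decodes
+        # the tokens back into the original string observation.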
+        t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
+        td_recon = t_inv.inv(td_tokenized.clone().exclude("observation"))
+        assert td_tokenized.get("observation") is td.get("observation")
+        assert td_recon["observation"] == td["observation"]
+
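+    # The env-based checks below use CountingEnvWithString, a CountingEnv variant
+    # whose observation carries a "string" entry for the Tokenizer to encode.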
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_single_trans_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = CountingEnvWithString(max_size=4, min_size=4)
+        env = TransformedEnv(base_env, t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_serial_trans_env_check(self, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+
+            return TransformedEnv(base_env, t)
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+            return TransformedEnv(base_env, t)
+
+        env = maybe_fork_ParallelEnv(2, make_env)
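+        # Close the parallel env even when the spec check fails; a worker that
+        # already shut down may raise a RuntimeError on close, which is ignored.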
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_serial_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(SerialEnv(2, base_env), t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_transform_compose(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+        t = Tokenizer(
+            in_keys=["observation"],
+            out_keys=["tokens"],
+            max_length=5,
+        )
+        t = Compose(t)
+        td_tokenized = t(td)
+
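+        # Sanity check against the underlying tokenizer called directly (this
+        # assumes the transform stores the encoded "input_ids" under the out_key).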
+        assert td_tokenized["observation"] is td["observation"]
+        assert (
+            td_tokenized["tokens"]
+            == t[0].tokenizer(obs, return_tensors="pt")["input_ids"]
+        ).all()
+
+    # TODO: the tests below are copied from the Hash transform tests and still
+    # exercise Hash rather than Tokenizer; adapt them to the Tokenizer API.
+    def test_transform_model(self):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=hash,
+        )
+        model = nn.Sequential(t, nn.Identity())
+        td = TensorDict(
+            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
+        )
+        td_out = model(td)
+        assert ("next", "hashing") in td_out.keys(True)
+        assert ("hashing",) in td_out.keys(True)
+        assert td_out["next", "hashing"] == hash(td["next", "observation"])
+        assert td_out["hashing"] == hash(td["observation"])
+
+    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
+    def test_transform_env(self):
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hashing"],
+            hash_fn=hash,
+        )
+        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
+        assert env.observation_spec["hashing"]
+        assert "observation" in env.observation_spec
+        assert "observation" in env.base_env.observation_spec
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+        )
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(t)
+        td = TensorDict(
+            {
+                "observation": torch.randn(3, 4),
+                "next": TensorDict(
+                    {"observation": torch.randn(3, 4)},
+                    [],
+                ),
+            },
+            [],
+        ).expand(10)
+        rb.extend(td)
+        td = rb.sample(2)
+        assert "hashing" in td.keys()
+        assert "observation" in td.keys()
+        assert ("next", "observation") in td.keys(True)
+
+    def test_transform_inverse(self):
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)
 
 
 class TestStack(TransformBase):