@@ -147,6 +147,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     TransformedEnv,
@@ -2420,7 +2421,223 @@ def test_transform_rb(self, rbclass):
         assert ("next", "observation") in td.keys(True)
 
     def test_transform_inverse(self):
-        raise pytest.skip("No inverse for Hash")
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
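+        # The inverse-only Hash rewires the action entry: the hashed key is
+        # registered among the env's action keys.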
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)
+
+
+class TestTokenizer(TransformBase):
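+    # Tests for the Tokenizer transform, which encodes string observations
+    # into token tensors and can decode them back in inverse mode.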
+    @pytest.mark.parametrize("datatype", ["str", "NonTensorStack"])
+    def test_transform_no_env(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+        elif datatype == "NonTensorStack":
+            obs = torch.stack(
+                [
+                    NonTensorData(data="abcde"),
+                    NonTensorData(data="fghij"),
+                    NonTensorData(data="klmno"),
+                ]
+            )
+        else:
+            raise RuntimeError(f"please add a test case for datatype {datatype}")
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+
+        t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
+        td_tokenized = t(td)
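+        # Round-trip check: a second Tokenizer used in inverse mode should
+        # decode the tokens back into the original string.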
+        t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
+        td_recon = t_inv.inv(td_tokenized.clone().exclude("observation"))
+        assert td_tokenized.get("observation") is td.get("observation")
+        assert td_recon["observation"] == td["observation"]
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_single_trans_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = CountingEnvWithString(max_size=4, min_size=4)
+        env = TransformedEnv(base_env, t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_serial_trans_env_check(self, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+
+            return TransformedEnv(base_env, t)
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+            return TransformedEnv(base_env, t)
+
+        env = maybe_fork_ParallelEnv(2, make_env)
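+        # Tear the parallel env down even if the spec check fails; close()
+        # can raise a RuntimeError here (e.g. workers already shut down),
+        # which we ignore.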
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_serial_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(SerialEnv(2, base_env), t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_transform_compose(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+        t = Tokenizer(
+            in_keys=["observation"],
+            out_keys=["tokens"],
+            max_length=5,
+        )
+        t = Compose(t)
+        td_tokenized = t(td)
+
+        assert td_tokenized["observation"] is td["observation"]
+        assert td_tokenized["tokens"] == t[0].tokenizer(obs, return_tensors="pt")
+
+    # TODO: the tests below still exercise the Hash transform and need to
+    # be adapted to Tokenizer.
+    def test_transform_model(self):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=hash,
+        )
+        model = nn.Sequential(t, nn.Identity())
+        td = TensorDict(
+            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
+        )
+        td_out = model(td)
+        assert ("next", "hashing") in td_out.keys(True)
+        assert ("hashing",) in td_out.keys(True)
+        assert td_out["next", "hashing"] == hash(td["next", "observation"])
+        assert td_out["hashing"] == hash(td["observation"])
+
+    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
+    def test_transform_env(self):
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hashing"],
+            hash_fn=hash,
+        )
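+        # Appending the transform should add a "hashing" entry to the
+        # observation spec while keeping the original observation.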
+        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
+        assert env.observation_spec["hashing"]
+        assert "observation" in env.observation_spec
+        assert "observation" in env.base_env.observation_spec
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
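+        # Replay-buffer transforms are applied when sampling, so the hashed
+        # keys should appear in the sampled batch.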
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+        )
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(t)
+        td = TensorDict(
+            {
+                "observation": torch.randn(3, 4),
+                "next": TensorDict(
+                    {"observation": torch.randn(3, 4)},
+                    [],
+                ),
+            },
+            [],
+        ).expand(10)
+        rb.extend(td)
+        td = rb.sample(2)
+        assert "hashing" in td.keys()
+        assert "observation" in td.keys()
+        assert ("next", "observation") in td.keys(True)
+
+    def test_transform_inverse(self):
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)
 
 
 class TestStack(TransformBase):