
Commit 70ab423

Author: Vincent Moens (committed)

[Feature] UnaryTransform for input entries

ghstack-source-id: a33c5a5
Pull Request resolved: #2700

1 parent: c983dcf
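This commit lets UnaryTransform-derived transforms (Hash, Tokenizer) operate on input entries via in_keys_inv / out_keys_inv, i.e. on values flowing from the policy into the base environment, not only on observations. A condensed sketch of the inverse-keys usage exercised in test_transform_inverse (test/test_transforms.py below); the import paths are assumptions, and CountingEnv is a test helper from test/mocking_classes.py:

from torchrl.envs.transforms import Hash
from mocking_classes import CountingEnv  # test helper, test/ directory

env = CountingEnv().append_transform(
    Hash(
        in_keys=[],                    # nothing transformed on the way out
        out_keys=[],
        in_keys_inv=["action"],        # input entry the transform reads
        out_keys_inv=["action_hash"],  # hashed input entry it writes
    )
)
assert "action_hash" in env.action_keys  # the hashed key joins the action keys
r = env.rollout(3)
assert "action_hash" in r                # and is recorded in rollouts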

File tree

5 files changed: +468 / -13 lines


test/mocking_classes.py

Lines changed: 4 additions & 1 deletion
@@ -1070,17 +1070,20 @@ def _step(

 class CountingEnvWithString(CountingEnv):
     def __init__(self, *args, **kwargs):
+        self.max_size = kwargs.pop("max_size", 30)
+        self.min_size = kwargs.pop("min_size", 4)
         super().__init__(*args, **kwargs)
         self.observation_spec.set(
             "string",
             NonTensor(
                 shape=self.batch_size,
                 device=self.device,
+                example_data=self.get_random_string(),
             ),
         )

     def get_random_string(self):
-        size = random.randint(4, 30)
+        size = random.randint(self.min_size, self.max_size)
         return "".join(random.choice(string.ascii_lowercase) for _ in range(size))

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
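The new min_size/max_size keywords let a test pin the length of the sampled string, and example_data gives the NonTensor spec a concrete sample. A small sketch of how the Tokenizer tests below rely on this (mirroring their constructor call):

# With min_size == max_size, every sampled "string" observation has exactly
# 4 characters, so tokenized outputs line up across stacked/batched envs.
env = CountingEnvWithString(max_size=4, min_size=4)
assert len(env.get_random_string()) == 4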

test/test_transforms.py

Lines changed: 218 additions & 1 deletion
@@ -147,6 +147,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     TransformedEnv,
@@ -2499,7 +2500,223 @@ def test_transform_rb(self, rbclass):
         assert ("next", "observation") in td.keys(True)

     def test_transform_inverse(self):
-        raise pytest.skip("No inverse for Hash")
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)
+
+
+class TestTokenizer(TransformBase):
+    @pytest.mark.parametrize("datatype", ["str", "NonTensorStack"])
+    def test_transform_no_env(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+        elif datatype == "NonTensorStack":
+            obs = torch.stack(
+                [
+                    NonTensorData(data="abcde"),
+                    NonTensorData(data="fghij"),
+                    NonTensorData(data="klmno"),
+                ]
+            )
+        else:
+            raise RuntimeError(f"please add a test case for datatype {datatype}")
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+
+        t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
+        td_tokenized = t(td)
+        t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
+        td_recon = t_inv.inv(td_tokenized.clone().exclude("observation"))
+        assert td_tokenized.get("observation") is td.get("observation")
+        assert td_recon["observation"] == td["observation"]
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_single_trans_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = CountingEnvWithString(max_size=4, min_size=4)
+        env = TransformedEnv(base_env, t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_serial_trans_env_check(self, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+
+            return TransformedEnv(base_env, t)
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+            return TransformedEnv(base_env, t)
+
+        env = maybe_fork_ParallelEnv(2, make_env)
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_serial_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(SerialEnv(2, base_env), t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_transform_compose(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+        t = Tokenizer(
+            in_keys=["observation"],
+            out_keys=["tokens"],
+            max_length=5,
+        )
+        t = Compose(t)
+        td_tokenized = t(td)
+
+        assert td_tokenized["observation"] is td["observation"]
+        assert td_tokenized["tokens"] == t[0].tokenizer(obs, return_tensor="pt")
+
+    # TODO
+    def test_transform_model(self):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=hash,
+        )
+        model = nn.Sequential(t, nn.Identity())
+        td = TensorDict(
+            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
+        )
+        td_out = model(td)
+        assert ("next", "hashing") in td_out.keys(True)
+        assert ("hashing",) in td_out.keys(True)
+        assert td_out["next", "hashing"] == hash(td["next", "observation"])
+        assert td_out["hashing"] == hash(td["observation"])
+
+    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
+    def test_transform_env(self):
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hashing"],
+            hash_fn=hash,
+        )
+        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
+        assert env.observation_spec["hashing"]
+        assert "observation" in env.observation_spec
+        assert "observation" in env.base_env.observation_spec
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+        )
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(t)
+        td = TensorDict(
+            {
+                "observation": torch.randn(3, 4),
+                "next": TensorDict(
+                    {"observation": torch.randn(3, 4)},
+                    [],
+                ),
+            },
+            [],
+        ).expand(10)
+        rb.extend(td)
+        td = rb.sample(2)
+        assert "hashing" in td.keys()
+        assert "observation" in td.keys()
+        assert ("next", "observation") in td.keys(True)
+
+    def test_transform_inverse(self):
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)


 class TestStack(TransformBase):
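Aside from the batched and parallel env-check scaffolding, the core behaviour this new suite pins down is the tokenize/detokenize round trip. A condensed sketch of that round trip, distilled from test_transform_no_env above (it assumes Tokenizer's default tokenizer backend round-trips plain lowercase text, which is what the test asserts):

from tensordict import TensorDict
from torchrl.envs.transforms import Tokenizer

td = TensorDict({"observation": "abcdefg"})
t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
td_tok = t(td)  # forward: adds a "tokens" entry, leaves "observation" intact

# A second Tokenizer configured only with inverse keys decodes the tokens
# back into the original string via .inv().
t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
td_recon = t_inv.inv(td_tok.clone().exclude("observation"))
assert td_recon["observation"] == td["observation"]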

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,
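Both __init__.py hunks simply re-export the new class, so after this commit either public path resolves:

from torchrl.envs import Tokenizer               # top-level re-export (this commit)
# from torchrl.envs.transforms import Tokenizer  # equivalent, defining subpackage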
