
Commit 3343b29

[Feature] UnaryTransform for input entries
ghstack-source-id: 3d63a2b1f44cb2ae652def00f35f3a3cfde1756b
Pull Request resolved: #2700
1 parent ddd96fb commit 3343b29
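
In short, unary transforms such as Hash can now target input entries through in_keys_inv/out_keys_inv, so the mapping is applied to data flowing into the environment (e.g. actions) rather than only to observations. A minimal sketch of the pattern, mirroring the new test_transform_inverse below (CountingEnv is a test-suite helper; its import path here is an assumption, not a public API):

import torch

from torchrl.envs.transforms import Hash
from mocking_classes import CountingEnv  # test helper; path assumed

# Hash the "action" entry on its way into the env; the hashed value
# "action_hash" then shows up among the env's action keys and in rollouts.
env = CountingEnv().append_transform(
    Hash(
        in_keys=[],
        out_keys=[],
        in_keys_inv=["action"],
        out_keys_inv=["action_hash"],
    )
)
assert "action_hash" in env.action_keys
rollout = env.rollout(3)
assert isinstance(rollout[0]["action_hash"], torch.Tensor)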

5 files changed: +468 -13


test/mocking_classes.py (+4 -1)
@@ -1070,17 +1070,20 @@ def _step(
 
 class CountingEnvWithString(CountingEnv):
     def __init__(self, *args, **kwargs):
+        self.max_size = kwargs.pop("max_size", 30)
+        self.min_size = kwargs.pop("min_size", 4)
         super().__init__(*args, **kwargs)
         self.observation_spec.set(
             "string",
             NonTensor(
                 shape=self.batch_size,
                 device=self.device,
+                example_data=self.get_random_string(),
             ),
         )
 
     def get_random_string(self):
-        size = random.randint(4, 30)
+        size = random.randint(self.min_size, self.max_size)
         return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
 
     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
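
The new kwargs make the random-string length configurable; pinning min_size == max_size yields fixed-length strings, which is what the Tokenizer tests below rely on. A hypothetical snippet:

env = CountingEnvWithString(max_size=4, min_size=4)
s = env.get_random_string()
assert len(s) == 4  # randint(min_size, max_size) with min == max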

test/test_transforms.py (+218 -1)
@@ -147,6 +147,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     TransformedEnv,
@@ -2420,7 +2421,223 @@ def test_transform_rb(self, rbclass):
         assert ("next", "observation") in td.keys(True)
 
     def test_transform_inverse(self):
-        raise pytest.skip("No inverse for Hash")
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)
+
+
+class TestTokenizer(TransformBase):
+    @pytest.mark.parametrize("datatype", ["str", "NonTensorStack"])
+    def test_transform_no_env(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+        elif datatype == "NonTensorStack":
+            obs = torch.stack(
+                [
+                    NonTensorData(data="abcde"),
+                    NonTensorData(data="fghij"),
+                    NonTensorData(data="klmno"),
+                ]
+            )
+        else:
+            raise RuntimeError(f"please add a test case for datatype {datatype}")
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+
+        t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
+        td_tokenized = t(td)
+        t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
+        td_recon = t_inv.inv(td_tokenized.clone().exclude("observation"))
+        assert td_tokenized.get("observation") is td.get("observation")
+        assert td_recon["observation"] == td["observation"]
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_single_trans_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = CountingEnvWithString(max_size=4, min_size=4)
+        env = TransformedEnv(base_env, t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_serial_trans_env_check(self, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+
+            return TransformedEnv(base_env, t)
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
+        def make_env():
+            if datatype == "str":
+                t = Tokenizer(
+                    in_keys=["string"],
+                    out_keys=["tokens"],
+                    max_length=5,
+                )
+                base_env = CountingEnvWithString(max_size=4, min_size=4)
+            return TransformedEnv(base_env, t)
+
+        env = maybe_fork_ParallelEnv(2, make_env)
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_serial_env_check(self, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(SerialEnv(2, base_env), t)
+        check_env_specs(env, return_contiguous=False)
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
+        if datatype == "str":
+            t = Tokenizer(
+                in_keys=["string"],
+                out_keys=["tokens"],
+                max_length=5,
+            )
+            base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
+
+        env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
+        try:
+            check_env_specs(env, return_contiguous=False)
+        finally:
+            try:
+                env.close()
+            except RuntimeError:
+                pass
+
+    @pytest.mark.parametrize("datatype", ["str"])
+    def test_transform_compose(self, datatype):
+        if datatype == "str":
+            obs = "abcdefg"
+
+        td = TensorDict(
+            {
+                "observation": obs,
+            }
+        )
+        t = Tokenizer(
+            in_keys=["observation"],
+            out_keys=["tokens"],
+            max_length=5,
+        )
+        t = Compose(t)
+        td_tokenized = t(td)
+
+        assert td_tokenized["observation"] is td["observation"]
+        assert td_tokenized["tokens"] == t[0].tokenizer(obs, return_tensor="pt")
+
+    # TODO
+    def test_transform_model(self):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=hash,
+        )
+        model = nn.Sequential(t, nn.Identity())
+        td = TensorDict(
+            {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
+        )
+        td_out = model(td)
+        assert ("next", "hashing") in td_out.keys(True)
+        assert ("hashing",) in td_out.keys(True)
+        assert td_out["next", "hashing"] == hash(td["next", "observation"])
+        assert td_out["hashing"] == hash(td["observation"])
+
+    @pytest.mark.skipif(not _has_gym, reason="Gym not found")
+    def test_transform_env(self):
+        t = Hash(
+            in_keys=["observation"],
+            out_keys=["hashing"],
+            hash_fn=hash,
+        )
+        env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
+        assert env.observation_spec["hashing"]
+        assert "observation" in env.observation_spec
+        assert "observation" in env.base_env.observation_spec
+        check_env_specs(env)
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        t = Hash(
+            in_keys=[("next", "observation"), ("observation",)],
+            out_keys=[("next", "hashing"), ("hashing",)],
+            hash_fn=lambda x: [hash(x[0]), hash(x[1])],
+        )
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(t)
+        td = TensorDict(
+            {
+                "observation": torch.randn(3, 4),
+                "next": TensorDict(
+                    {"observation": torch.randn(3, 4)},
+                    [],
+                ),
+            },
+            [],
+        ).expand(10)
+        rb.extend(td)
+        td = rb.sample(2)
+        assert "hashing" in td.keys()
+        assert "observation" in td.keys()
+        assert ("next", "observation") in td.keys(True)
+
+    def test_transform_inverse(self):
+        env = CountingEnv()
+        env = env.append_transform(
+            Hash(
+                in_keys=[],
+                out_keys=[],
+                in_keys_inv=["action"],
+                out_keys_inv=["action_hash"],
+            )
+        )
+        assert "action_hash" in env.action_keys
+        r = env.rollout(3)
+        env.check_env_specs()
+        assert "action_hash" in r
+        assert isinstance(r[0]["action_hash"], torch.Tensor)
 
 
 class TestStack(TransformBase):
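
Outside the test harness, the Tokenizer round trip exercised by test_transform_no_env boils down to the sketch below (imports assumed from this branch). Note that the env-level tests pass return_contiguous=False to check_env_specs, presumably because rollouts carrying non-tensor string data are stacked lazily rather than into a contiguous tensordict.

from tensordict import TensorDict

from torchrl.envs.transforms import Tokenizer

# Forward pass: encode the string entry into a "tokens" tensor.
t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
td = t(TensorDict({"observation": "abcdefg"}))

# Inverse pass: a Tokenizer wired through in_keys_inv/out_keys_inv
# decodes the tokens back into the original string.
t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
td_recon = t_inv.inv(td.clone().exclude("observation"))
assert td_recon["observation"] == "abcdefg"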

torchrl/envs/__init__.py (+1)
@@ -94,6 +94,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,

torchrl/envs/transforms/__init__.py (+1)
@@ -55,6 +55,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,
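
With the two re-exports above, the transform should resolve from either namespace:

from torchrl.envs import Tokenizer
from torchrl.envs.transforms import Tokenizer as _Tokenizer

assert Tokenizer is _Tokenizer  # one class, re-exported at both levels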
