Skip to content

Commit ec2a0ba

Browse files
committed
[Feature] Env with tensorclass attributes
ghstack-source-id: 1d844e1 Pull Request resolved: #2788
1 parent ab6dadd commit ec2a0ba

File tree

4 files changed

+161
-50
lines changed

4 files changed

+161
-50
lines changed

test/mocking_classes.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
import torch
1212
import torch.nn as nn
13-
from tensordict import TensorDict, TensorDictBase
13+
from tensordict import tensorclass, TensorDict, TensorDictBase
1414
from tensordict.nn import TensorDictModuleBase
1515
from tensordict.utils import expand_right, NestedKey
1616

17-
from torchrl.data.tensor_specs import (
17+
from torchrl.data import (
1818
Binary,
1919
Bounded,
2020
Categorical,
@@ -2349,3 +2349,42 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
23492349

23502350
def _set_seed(self, seed: Optional[int]):
23512351
...
2352+
2353+
2354+
@tensorclass()
2355+
class TC:
2356+
field0: str
2357+
field1: torch.Tensor
2358+
2359+
2360+
class EnvWithTensorClass(CountingEnv):
2361+
tc_cls = TC
2362+
2363+
def __init__(self, **kwargs):
2364+
super().__init__(**kwargs)
2365+
self.observation_spec["tc"] = Composite(
2366+
field0=NonTensor(example_data="an observation!", shape=self.batch_size),
2367+
field1=Unbounded(shape=self.batch_size),
2368+
shape=self.batch_size,
2369+
data_cls=TC,
2370+
)
2371+
2372+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2373+
td = super()._reset(tensordict, **kwargs)
2374+
td["tc"] = TC("0", torch.zeros(self.batch_size))
2375+
return td
2376+
2377+
def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2378+
td = super()._step(tensordict, **kwargs)
2379+
default = TC("0", 0)
2380+
f0 = tensordict.get("tc", default).field0
2381+
if f0 is None:
2382+
f0 = "0"
2383+
f1 = tensordict.get("tc", default).field1
2384+
if f1 is None:
2385+
f1 = torch.zeros(self.batch_size)
2386+
td["tc"] = TC(
2387+
str(int(f0) + 1),
2388+
f1 + 1,
2389+
)
2390+
return td

test/test_env.py

+25
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
EnvThatDoesNothing,
128128
EnvWithDynamicSpec,
129129
EnvWithMetadata,
130+
EnvWithTensorClass,
130131
HeterogeneousCountingEnv,
131132
HeterogeneousCountingEnvPolicy,
132133
MockBatchedLockedEnv,
@@ -166,6 +167,7 @@
166167
EnvThatDoesNothing,
167168
EnvWithDynamicSpec,
168169
EnvWithMetadata,
170+
EnvWithTensorClass,
169171
HeterogeneousCountingEnv,
170172
HeterogeneousCountingEnvPolicy,
171173
MockBatchedLockedEnv,
@@ -3707,6 +3709,29 @@ def test_str2str_rb_slicesampler(self):
37073709
else:
37083710
raise RuntimeError("Failed to sample both trajs")
37093711

3712+
def test_env_with_tensorclass(self):
3713+
env = EnvWithTensorClass()
3714+
env.check_env_specs()
3715+
r = env.reset()
3716+
for _ in range(3):
3717+
assert isinstance(r["tc"], env.tc_cls)
3718+
a = env.rand_action(r)
3719+
s = env.step(a)
3720+
assert isinstance(s["tc"], env.tc_cls)
3721+
r = env.step_mdp(s)
3722+
3723+
@pytest.mark.parametrize("cls", [SerialEnv, ParallelEnv])
3724+
def test_env_with_tensorclass_batched(self, cls):
3725+
env = cls(2, EnvWithTensorClass)
3726+
env.check_env_specs()
3727+
r = env.reset()
3728+
for _ in range(3):
3729+
assert isinstance(r["tc"], EnvWithTensorClass.tc_cls)
3730+
a = env.rand_action(r)
3731+
s = env.step(a)
3732+
assert isinstance(s["tc"], EnvWithTensorClass.tc_cls)
3733+
r = env.step_mdp(s)
3734+
37103735

37113736
# fen strings for board positions generated with:
37123737
# https://lichess.org/editor

0 commit comments

Comments
 (0)