Skip to content

Commit b67b3ef

Browse files
committed
[Feature] Env with tensorclass attributes
ghstack-source-id: a3b2a95 Pull Request resolved: #2788
1 parent 1e548a7 commit b67b3ef

File tree

5 files changed

+169
-58
lines changed

5 files changed

+169
-58
lines changed

test/mocking_classes.py

+41-2
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,11 @@
1010

1111
import torch
1212
import torch.nn as nn
13-
from tensordict import TensorDict, TensorDictBase
13+
from tensordict import tensorclass, TensorDict, TensorDictBase
1414
from tensordict.nn import TensorDictModuleBase
1515
from tensordict.utils import expand_right, NestedKey
1616

17-
from torchrl.data.tensor_specs import (
17+
from torchrl.data import (
1818
Binary,
1919
Bounded,
2020
Categorical,
@@ -2356,3 +2356,42 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
23562356

23572357
def _set_seed(self, seed: Optional[int]):
    # Stub: the body is a bare ``...``, so ``seed`` is not used here and the
    # method implicitly returns None.
    ...
2359+
2360+
2361+
@tensorclass()
class TC:
    # Tensorclass with one string field and one tensor field. It is used as
    # the ``data_cls`` of the "tc" observation entry of EnvWithTensorClass.
    field0: str  # EnvWithTensorClass stores a stringified integer counter here
    field1: torch.Tensor  # EnvWithTensorClass stores a tensor counter here
2365+
2366+
2367+
class EnvWithTensorClass(CountingEnv):
    """CountingEnv variant whose observations carry a tensorclass under ``"tc"``.

    On top of what ``CountingEnv`` produces, every reset/step output holds a
    ``tc_cls`` instance with two counters: ``field0`` (an integer stored as a
    string) and ``field1`` (a tensor). Both start at zero on reset and are
    incremented by one on each step.
    """

    # Tensorclass type stored under "tc"; tests read it through ``env.tc_cls``.
    tc_cls = TC

    def __init__(self, **kwargs):
        """Build the parent env and register the spec of the ``"tc"`` entry.

        Args:
            **kwargs: forwarded unchanged to ``CountingEnv.__init__``.
        """
        super().__init__(**kwargs)
        self.observation_spec["tc"] = Composite(
            field0=NonTensor(example_data="an observation!", shape=self.batch_size),
            field1=Unbounded(shape=self.batch_size),
            shape=self.batch_size,
            # Use the class attribute so the spec and the data produced by
            # _reset/_step always agree on the tensorclass type.
            data_cls=self.tc_cls,
        )

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Reset the parent env and attach a zero-initialised ``"tc"`` entry.

        Args:
            tensordict: reset input, forwarded to the parent ``_reset``.
            **kwargs: forwarded to the parent ``_reset``.

        Returns:
            The parent's reset output with ``"tc"`` set to
            ``tc_cls("0", zeros(batch_size))``.
        """
        td = super()._reset(tensordict, **kwargs)
        td["tc"] = self.tc_cls("0", torch.zeros(self.batch_size))
        return td

    def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Step the parent env and increment both counters of ``"tc"``.

        Args:
            tensordict: step input; its ``"tc"`` entry (if any) provides the
                current counter values.
            **kwargs: forwarded to the parent ``_step``.

        Returns:
            The parent's step output with ``"tc"`` holding the incremented
            counters.
        """
        td = super()._step(tensordict, **kwargs)
        # Fallback for inputs without a "tc" entry. field1 is built exactly as
        # in _reset (a tensor of shape ``batch_size``) instead of a bare Python
        # int, so the fallback path yields the same kind of value as the
        # regular path.
        default = self.tc_cls("0", torch.zeros(self.batch_size))
        # Look the entry up once and read both fields from the same object.
        tc = tensordict.get("tc", default)
        f0 = tc.field0
        if f0 is None:
            f0 = "0"
        f1 = tc.field1
        if f1 is None:
            f1 = torch.zeros(self.batch_size)
        td["tc"] = self.tc_cls(
            str(int(f0) + 1),
            f1 + 1,
        )
        return td

test/test_env.py

+25
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,7 @@
127127
EnvThatDoesNothing,
128128
EnvWithDynamicSpec,
129129
EnvWithMetadata,
130+
EnvWithTensorClass,
130131
HeterogeneousCountingEnv,
131132
HeterogeneousCountingEnvPolicy,
132133
MockBatchedLockedEnv,
@@ -166,6 +167,7 @@
166167
EnvThatDoesNothing,
167168
EnvWithDynamicSpec,
168169
EnvWithMetadata,
170+
EnvWithTensorClass,
169171
HeterogeneousCountingEnv,
170172
HeterogeneousCountingEnvPolicy,
171173
MockBatchedLockedEnv,
@@ -3708,6 +3710,29 @@ def test_str2str_rb_slicesampler(self):
37083710
else:
37093711
raise RuntimeError("Failed to sample both trajs")
37103712

3713+
def test_env_with_tensorclass(self):
    """Roll out EnvWithTensorClass for three steps, checking the type of "tc".

    The "tc" entry must be an instance of ``env.tc_cls`` in the reset output,
    in every step output, and in every ``step_mdp`` output fed to the next
    iteration.
    """
    env = EnvWithTensorClass()
    env.check_env_specs()
    tc_type = env.tc_cls
    state = env.reset()
    remaining = 3
    while remaining > 0:
        # Type is checked both before and after stepping.
        assert isinstance(state["tc"], tc_type)
        stepped = env.step(env.rand_action(state))
        assert isinstance(stepped["tc"], tc_type)
        state = env.step_mdp(stepped)
        remaining -= 1
3723+
3724+
@pytest.mark.parametrize("cls", [SerialEnv, ParallelEnv])
def test_env_with_tensorclass_batched(self, cls):
    """Same rollout check as the single-env test, through a batched env.

    Args:
        cls: batched env class under test (``SerialEnv`` or ``ParallelEnv``),
            instantiated with two ``EnvWithTensorClass`` sub-environments.
    """
    env = cls(2, EnvWithTensorClass)
    try:
        env.check_env_specs()
        r = env.reset()
        for _ in range(3):
            assert isinstance(r["tc"], EnvWithTensorClass.tc_cls)
            a = env.rand_action(r)
            s = env.step(a)
            assert isinstance(s["tc"], EnvWithTensorClass.tc_cls)
            r = env.step_mdp(s)
    finally:
        # Shut the batched env down explicitly, even when an assertion fails,
        # rather than relying on implicit cleanup of its sub-environments.
        # The guard avoids calling close() on an env that was never started.
        if not env.is_closed:
            env.close()
3735+
37113736

37123737
# fen strings for board positions generated with:
37133738
# https://lichess.org/editor

0 commit comments

Comments
 (0)