Skip to content

Commit

Permalink
[Feature] Env with tensorclass attributes
Browse files Browse the repository at this point in the history
ghstack-source-id: dc00ea3d23e015756974cd5c2ce638b55e5f6f92
Pull Request resolved: #2788
  • Loading branch information
vmoens committed Feb 13, 2025
1 parent e084c02 commit ab76027
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 58 deletions.
43 changes: 41 additions & 2 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

import torch
import torch.nn as nn
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import TensorDictModuleBase
from tensordict.utils import expand_right, NestedKey

from torchrl.data.tensor_specs import (
from torchrl.data import (
Binary,
Bounded,
Categorical,
Expand Down Expand Up @@ -2356,3 +2356,42 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:

def _set_seed(self, seed: Optional[int]):
...


# Minimal two-field tensorclass used as the "tc" entry produced by
# EnvWithTensorClass. Field order is load-bearing: instances in this file are
# built positionally as TC(<field0>, <field1>).
@tensorclass()
class TC:
    # String field; EnvWithTensorClass stores a decimal counter string here.
    field0: str
    # Tensor field; EnvWithTensorClass stores a numeric counter here.
    field1: torch.Tensor


class EnvWithTensorClass(CountingEnv):
    """CountingEnv variant that also emits a tensorclass under the ``"tc"`` key.

    The ``"tc"`` entry is a :class:`TC` instance whose ``field0`` is a decimal
    string and whose ``field1`` is a tensor. Both start at zero on reset and
    are incremented by one on every step.

    Attributes:
        tc_cls: the tensorclass type used for the ``"tc"`` entry, exposed so
            that callers can run ``isinstance`` checks against it.
    """

    tc_cls = TC

    def __init__(self, **kwargs):
        """Build the parent env, then register the spec of the ``"tc"`` entry.

        Args:
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        batch_size = self.batch_size
        # ``data_cls`` is set to TC for this composite spec, matching the
        # TC instances that ``_reset`` and ``_step`` write under "tc".
        tc_spec = Composite(
            field0=NonTensor(example_data="an observation!", shape=batch_size),
            field1=Unbounded(shape=batch_size),
            shape=batch_size,
            data_cls=TC,
        )
        self.observation_spec["tc"] = tc_spec

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Reset the parent env and attach a zero-initialised ``"tc"`` entry.

        Args:
            tensordict: reset input, passed through to the parent ``_reset``.
            **kwargs: forwarded unchanged to the parent ``_reset``.

        Returns:
            The parent's reset output with ``"tc"`` set to ``TC("0", zeros)``.
        """
        out = super()._reset(tensordict, **kwargs)
        initial_tc = TC("0", torch.zeros(self.batch_size))
        out["tc"] = initial_tc
        return out

    def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Step the parent env and write an incremented ``"tc"`` entry.

        Args:
            tensordict: step input; its ``"tc"`` entry (if any) provides the
                previous counters.
            **kwargs: forwarded unchanged to the parent ``_step``.

        Returns:
            The parent's step output with ``"tc"`` holding both counters
            incremented by one.
        """
        out = super()._step(tensordict, **kwargs)
        # Fall back to a zero-valued TC when the input carries no "tc" entry.
        previous_tc = tensordict.get("tc", TC("0", 0))
        text_counter = previous_tc.field0
        if text_counter is None:
            text_counter = "0"
        tensor_counter = previous_tc.field1
        if tensor_counter is None:
            tensor_counter = torch.zeros(self.batch_size)
        next_text = str(int(text_counter) + 1)
        next_tensor = tensor_counter + 1
        out["tc"] = TC(next_text, next_tensor)
        return out
25 changes: 25 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
EnvThatDoesNothing,
EnvWithDynamicSpec,
EnvWithMetadata,
EnvWithTensorClass,
HeterogeneousCountingEnv,
HeterogeneousCountingEnvPolicy,
MockBatchedLockedEnv,
Expand Down Expand Up @@ -166,6 +167,7 @@
EnvThatDoesNothing,
EnvWithDynamicSpec,
EnvWithMetadata,
EnvWithTensorClass,
HeterogeneousCountingEnv,
HeterogeneousCountingEnvPolicy,
MockBatchedLockedEnv,
Expand Down Expand Up @@ -3708,6 +3710,29 @@ def test_str2str_rb_slicesampler(self):
else:
raise RuntimeError("Failed to sample both trajs")

def test_env_with_tensorclass(self):
    """Check that a single EnvWithTensorClass keeps ``"tc"`` as a tensorclass.

    Validates the env specs, then runs three reset/step transitions and
    asserts that the ``"tc"`` entry is an instance of ``env.tc_cls`` both
    before and after each step.
    """
    env = EnvWithTensorClass()
    env.check_env_specs()
    current = env.reset()
    num_steps = 3
    for _ in range(num_steps):
        assert isinstance(current["tc"], env.tc_cls)
        with_action = env.rand_action(current)
        stepped = env.step(with_action)
        assert isinstance(stepped["tc"], env.tc_cls)
        # Move the "next" data to the root for the following iteration.
        current = env.step_mdp(stepped)

@pytest.mark.parametrize("cls", [SerialEnv, ParallelEnv])
def test_env_with_tensorclass_batched(self, cls):
    """Check that batched envs keep the ``"tc"`` entry as a tensorclass.

    Builds a batched env of two ``EnvWithTensorClass`` workers, validates its
    specs, then runs three transitions and asserts that ``"tc"`` is an
    instance of ``EnvWithTensorClass.tc_cls`` before and after each step.

    Args:
        cls: the batched-env class under test (``SerialEnv`` or
            ``ParallelEnv``).
    """
    env = cls(2, EnvWithTensorClass)
    try:
        env.check_env_specs()
        r = env.reset()
        for _ in range(3):
            assert isinstance(r["tc"], EnvWithTensorClass.tc_cls)
            a = env.rand_action(r)
            s = env.step(a)
            assert isinstance(s["tc"], EnvWithTensorClass.tc_cls)
            r = env.step_mdp(s)
    finally:
        # Release the batched env's workers even when an assertion fails, so
        # a failing test does not leave ParallelEnv processes behind. Only
        # close an env that was actually started, to avoid masking the
        # original failure with a "closing a closed env" error.
        if not env.is_closed:
            env.close()


# fen strings for board positions generated with:
# https://lichess.org/editor
Expand Down
Loading

0 comments on commit ab76027

Please sign in to comment.