Skip to content

Commit a839bc6

Browse files
committed
Update (base update)
[ghstack-poisoned]
2 parents 63fca58 + 413571b commit a839bc6

File tree

2 files changed

+65
-13
lines changed

2 files changed

+65
-13
lines changed

test/test_env.py

+39-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections import defaultdict
1717
from functools import partial
1818
from sys import platform
19-
from typing import Optional
19+
from typing import Any, Optional
2020

2121
import numpy as np
2222
import pytest
@@ -33,7 +33,7 @@
3333
TensorDictBase,
3434
)
3535
from tensordict.nn import TensorDictModuleBase
36-
from tensordict.tensorclass import NonTensorStack
36+
from tensordict.tensorclass import NonTensorStack, TensorClass
3737
from tensordict.utils import _unravel_key_to_tuple
3838
from torch import nn
3939

@@ -340,7 +340,8 @@ def forward(self, values):
340340
)
341341
env.rollout(10, policy)
342342

343-
def test_make_spec_from_td(self):
343+
@pytest.mark.parametrize("dynamic_shape", [True, False])
344+
def test_make_spec_from_td(self, dynamic_shape):
344345
data = TensorDict(
345346
{
346347
"obs": torch.randn(3),
@@ -353,10 +354,44 @@ def test_make_spec_from_td(self):
353354
},
354355
[],
355356
)
356-
spec = make_composite_from_td(data)
357+
spec = make_composite_from_td(data, dynamic_shape=dynamic_shape)
357358
assert (spec.zero() == data.zero_()).all()
358359
for key, val in data.items(True, True):
359360
assert val.dtype is spec[key].dtype
361+
if dynamic_shape:
362+
assert all(s.shape[-1] == -1 for s in spec.values(True, True))
363+
364+
def test_make_spec_from_tc(self):
365+
class Scratch(TensorClass):
366+
obs: torch.Tensor
367+
string: str
368+
some_object: Any
369+
370+
class Whatever:
371+
...
372+
373+
td = TensorDict(
374+
a=Scratch(
375+
obs=torch.ones(5, 3),
376+
string="another string!",
377+
some_object=Whatever(),
378+
batch_size=(5,),
379+
),
380+
b="a string!",
381+
batch_size=(5,),
382+
)
383+
spec = make_composite_from_td(td)
384+
assert isinstance(spec, Composite)
385+
assert isinstance(spec["a"], Composite)
386+
assert isinstance(spec["b"], NonTensor)
387+
assert spec["b"].example_data == "a string!", spec["b"].example_data
388+
assert spec["a", "string"].example_data == "another string!"
389+
one = spec.one()
390+
assert isinstance(one["a"], Scratch)
391+
assert isinstance(one["b"], str)
392+
assert isinstance(one["a"].string, str)
393+
assert isinstance(one["a"].some_object, Whatever)
394+
assert (one == td).all()
360395

361396
def test_env_that_does_nothing(self):
362397
env = EnvThatDoesNothing()

torchrl/envs/utils.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -888,13 +888,19 @@ def _sort_keys(element):
888888
return element
889889

890890

891-
def make_composite_from_td(data, unsqueeze_null_shapes: bool = True):
891+
def make_composite_from_td(
892+
data, *, unsqueeze_null_shapes: bool = True, dynamic_shape: bool = False
893+
):
892894
"""Creates a Composite instance from a tensordict, assuming all values are unbounded.
893895
894896
Args:
895897
data (tensordict.TensorDict): a tensordict to be mapped onto a Composite.
898+
899+
Keyword Args:
896900
unsqueeze_null_shapes (bool, optional): if ``True``, every empty shape will be
897901
unsqueezed to (1,). Defaults to ``True``.
902+
dynamic_shape (bool, optional): if ``True``, all tensors will be assumed to have a dynamic shape
903+
along the last dimension. Defaults to ``False``.
898904
899905
Examples:
900906
>>> from tensordict import TensorDict
@@ -919,22 +925,33 @@ def make_composite_from_td(data, unsqueeze_null_shapes: bool = True):
919925
"""
920926
# custom function to convert a tensordict in a similar spec structure
921927
# of unbounded values.
928+
def make_shape(shape):
929+
if shape or not unsqueeze_null_shapes:
930+
if dynamic_shape:
931+
return shape[:-1] + (-1,)
932+
else:
933+
return shape
934+
return torch.Size([1])
935+
922936
composite = Composite(
923937
{
924-
key: make_composite_from_td(tensor)
925-
if isinstance(tensor, TensorDictBase)
926-
else NonTensor(shape=data.shape, device=tensor.device)
938+
key: make_composite_from_td(
939+
tensor,
940+
unsqueeze_null_shapes=unsqueeze_null_shapes,
941+
dynamic_shape=dynamic_shape,
942+
)
943+
if is_tensor_collection(tensor) and not is_non_tensor(tensor)
944+
else NonTensor(
945+
shape=tensor.shape, example_data=tensor.data, device=tensor.device
946+
)
927947
if is_non_tensor(tensor)
928948
else Unbounded(
929-
dtype=tensor.dtype,
930-
device=tensor.device,
931-
shape=tensor.shape
932-
if tensor.shape or not unsqueeze_null_shapes
933-
else [1],
949+
dtype=tensor.dtype, device=tensor.device, shape=make_shape(tensor.shape)
934950
)
935951
for key, tensor in data.items()
936952
},
937953
shape=data.shape,
954+
data_cls=type(data),
938955
)
939956
return composite
940957

0 commit comments

Comments
 (0)