Skip to content

Commit 6b5501a

Browse files
committed
[Feature] example_data for NonTensor spec
ghstack-source-id: 694fb3e559b51cb7260894018d6e5b0214050690 Pull Request resolved: #2698
1 parent 256a700 commit 6b5501a

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

test/test_specs.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
14021402
assert spec2.zero().shape == spec2.shape
14031403

14041404
def test_non_tensor(self):
1405-
spec = NonTensor((3, 4), device="cpu")
1405+
spec = NonTensor((3, 4), device="cpu", example_data="example_data")
14061406
assert (
14071407
spec.expand(2, 3, 4)
14081408
== spec.expand((2, 3, 4))
1409-
== NonTensor((2, 3, 4), device="cpu")
1409+
== NonTensor((2, 3, 4), device="cpu", example_data="example_data")
14101410
)
1411+
assert spec.expand(2, 3, 4).example_data == "example_data"
14111412

14121413
@pytest.mark.parametrize("shape1", [None, (), (5,)])
14131414
@pytest.mark.parametrize("shape2", [(), (10,)])
@@ -1607,9 +1608,10 @@ def test_multionehot(
16071608
assert spec is not spec.clone()
16081609

16091610
def test_non_tensor(self):
1610-
spec = NonTensor(shape=(3, 4), device="cpu")
1611+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
16111612
assert spec.clone() == spec
16121613
assert spec.clone() is not spec
1614+
assert spec.clone().example_data == "example_data"
16131615

16141616
@pytest.mark.parametrize("shape1", [None, (), (5,)])
16151617
def test_onehot(
@@ -1840,9 +1842,10 @@ def test_multionehot(
18401842
spec.unbind(-1)
18411843

18421844
def test_non_tensor(self):
1843-
spec = NonTensor(shape=(3, 4), device="cpu")
1845+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
18441846
assert spec.unbind(1)[0] == spec[:, 0]
18451847
assert spec.unbind(1)[0] is not spec[:, 0]
1848+
assert spec.unbind(1)[0].example_data == "example_data"
18461849

18471850
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
18481851
def test_onehot(
@@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
20012004
assert spec.to(device).device == device
20022005

20032006
def test_non_tensor(self, device):
2004-
spec = NonTensor(shape=(3, 4), device="cpu")
2007+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
20052008
assert spec.to(device).device == device
2009+
assert spec.to(device).example_data == "example_data"
20062010

20072011
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
20082012
def test_onehot(self, shape1, device):
@@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
22622266
assert r.shape == c.shape
22632267

22642268
def test_stack_non_tensor(self, shape, stack_dim):
2265-
spec0 = NonTensor(shape=shape, device="cpu")
2266-
spec1 = NonTensor(shape=shape, device="cpu")
2269+
spec0 = NonTensor(shape=shape, device="cpu", example_data="example_data")
2270+
spec1 = NonTensor(shape=shape, device="cpu", example_data="example_data")
22672271
new_spec = torch.stack([spec0, spec1], stack_dim)
22682272
shape_insert = list(shape)
22692273
shape_insert.insert(stack_dim, 2)
22702274
assert new_spec.shape == torch.Size(shape_insert)
22712275
assert new_spec.device == torch.device("cpu")
2276+
assert new_spec.example_data == "example_data"
22722277

22732278
def test_stack_onehot(self, shape, stack_dim):
22742279
n = 5
@@ -3642,10 +3647,18 @@ def test_expand(self):
36423647

36433648
class TestNonTensorSpec:
36443649
def test_sample(self):
3645-
nts = NonTensor(shape=(3, 4))
3650+
nts = NonTensor(shape=(3, 4), example_data="example_data")
36463651
assert nts.one((2,)).shape == (2, 3, 4)
36473652
assert nts.rand((2,)).shape == (2, 3, 4)
36483653
assert nts.zero((2,)).shape == (2, 3, 4)
3654+
assert nts.one((2,)).data == "example_data"
3655+
assert nts.rand((2,)).data == "example_data"
3656+
assert nts.zero((2,)).data == "example_data"
3657+
3658+
def test_example_data_ineq(self):
3659+
nts0 = NonTensor(shape=(3, 4), example_data="example_data")
3660+
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
3661+
assert nts0 != nts1
36493662

36503663

36513664
@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")

torchrl/data/tensor_specs.py

+57-8
Original file line numberDiff line numberDiff line change
@@ -2452,11 +2452,14 @@ class NonTensor(TensorSpec):
24522452
(same will go for :meth:`.zero` and :meth:`.one`).
24532453
"""
24542454

2455+
example_data: Any = None
2456+
24552457
def __init__(
24562458
self,
24572459
shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
24582460
device: Optional[DEVICE_TYPING] = None,
24592461
dtype: torch.dtype | None = None,
2462+
example_data: Any = None,
24602463
**kwargs,
24612464
):
24622465
if isinstance(shape, int):
@@ -2467,6 +2470,12 @@ def __init__(
24672470
super().__init__(
24682471
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
24692472
)
2473+
self.example_data = example_data
2474+
2475+
def __eq__(self, other):
2476+
eq = super().__eq__(other)
2477+
eq = eq & (self.example_data == getattr(other, "example_data", None))
2478+
return eq
24702479

24712480
def cardinality(self) -> Any:
24722481
raise RuntimeError("Cannot enumerate a NonTensorSpec.")
@@ -2485,30 +2494,46 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
24852494
dest_device = torch.device(dest)
24862495
if dest_device == self.device and dest_dtype == self.dtype:
24872496
return self
2488-
return self.__class__(shape=self.shape, device=dest_device, dtype=None)
2497+
return self.__class__(
2498+
shape=self.shape,
2499+
device=dest_device,
2500+
dtype=None,
2501+
example_data=self.example_data,
2502+
)
24892503

24902504
def clone(self) -> NonTensor:
2491-
return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
2505+
return self.__class__(
2506+
shape=self.shape,
2507+
device=self.device,
2508+
dtype=self.dtype,
2509+
example_data=self.example_data,
2510+
)
24922511

24932512
def rand(self, shape=None):
24942513
if shape is None:
24952514
shape = ()
24962515
return NonTensorData(
2497-
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
2516+
data=self.example_data,
2517+
batch_size=(*shape, *self._safe_shape),
2518+
device=self.device,
24982519
)
24992520

25002521
def zero(self, shape=None):
25012522
if shape is None:
25022523
shape = ()
25032524
return NonTensorData(
2504-
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
2525+
data=self.example_data,
2526+
batch_size=(*shape, *self._safe_shape),
2527+
device=self.device,
25052528
)
25062529

25072530
def one(self, shape=None):
25082531
if shape is None:
25092532
shape = ()
25102533
return NonTensorData(
2511-
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
2534+
data=self.example_data,
2535+
batch_size=(*shape, *self._safe_shape),
2536+
device=self.device,
25122537
)
25132538

25142539
def is_in(self, val: Any) -> bool:
@@ -2533,23 +2558,46 @@ def expand(self, *shape):
25332558
raise ValueError(
25342559
f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}."
25352560
)
2536-
return self.__class__(shape=shape, device=self.device, dtype=None)
2561+
return self.__class__(
2562+
shape=shape, device=self.device, dtype=None, example_data=self.example_data
2563+
)
2564+
2565+
def unsqueeze(self, dim: int) -> NonTensor:
2566+
unsq = super().unsqueeze(dim=dim)
2567+
unsq.example_data = self.example_data
2568+
return unsq
2569+
2570+
def squeeze(self, dim: int | None = None) -> NonTensor:
2571+
sq = super().squeeze(dim=dim)
2572+
sq.example_data = self.example_data
2573+
return sq
25372574

25382575
def _reshape(self, shape):
2539-
return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
2576+
return self.__class__(
2577+
shape=shape,
2578+
device=self.device,
2579+
dtype=self.dtype,
2580+
example_data=self.example_data,
2581+
)
25402582

25412583
def _unflatten(self, dim, sizes):
25422584
shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
25432585
return self.__class__(
25442586
shape=shape,
25452587
device=self.device,
25462588
dtype=self.dtype,
2589+
example_data=self.example_data,
25472590
)
25482591

25492592
def __getitem__(self, idx: SHAPE_INDEX_TYPING):
25502593
"""Indexes the current TensorSpec based on the provided index."""
25512594
indexed_shape = _size(_shape_indexing(self.shape, idx))
2552-
return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
2595+
return self.__class__(
2596+
shape=indexed_shape,
2597+
device=self.device,
2598+
dtype=self.dtype,
2599+
example_data=self.example_data,
2600+
)
25532601

25542602
def unbind(self, dim: int = 0):
25552603
orig_dim = dim
@@ -2565,6 +2613,7 @@ def unbind(self, dim: int = 0):
25652613
shape=shape,
25662614
device=self.device,
25672615
dtype=self.dtype,
2616+
example_data=self.example_data,
25682617
)
25692618
for i in range(self.shape[dim])
25702619
)

0 commit comments

Comments
 (0)