16
16
from collections import defaultdict
17
17
from functools import partial
18
18
from sys import platform
19
- from typing import Optional
19
+ from typing import Any , Optional
20
20
21
21
import numpy as np
22
22
import pytest
33
33
TensorDictBase ,
34
34
)
35
35
from tensordict .nn import TensorDictModuleBase
36
- from tensordict .tensorclass import NonTensorStack
36
+ from tensordict .tensorclass import NonTensorStack , TensorClass
37
37
from tensordict .utils import _unravel_key_to_tuple
38
38
from torch import nn
39
39
@@ -340,7 +340,8 @@ def forward(self, values):
340
340
)
341
341
env .rollout (10 , policy )
342
342
343
- def test_make_spec_from_td (self ):
343
+ @pytest .mark .parametrize ("dynamic_shape" , [True , False ])
344
+ def test_make_spec_from_td (self , dynamic_shape ):
344
345
data = TensorDict (
345
346
{
346
347
"obs" : torch .randn (3 ),
@@ -353,10 +354,44 @@ def test_make_spec_from_td(self):
353
354
},
354
355
[],
355
356
)
356
- spec = make_composite_from_td (data )
357
+ spec = make_composite_from_td (data , dynamic_shape = dynamic_shape )
357
358
assert (spec .zero () == data .zero_ ()).all ()
358
359
for key , val in data .items (True , True ):
359
360
assert val .dtype is spec [key ].dtype
361
+ if dynamic_shape :
362
+ assert all (s .shape [- 1 ] == - 1 for s in spec .values (True , True ))
363
+
364
+ def test_make_spec_from_tc (self ):
365
+ class Scratch (TensorClass ):
366
+ obs : torch .Tensor
367
+ string : str
368
+ some_object : Any
369
+
370
+ class Whatever :
371
+ ...
372
+
373
+ td = TensorDict (
374
+ a = Scratch (
375
+ obs = torch .ones (5 , 3 ),
376
+ string = "another string!" ,
377
+ some_object = Whatever (),
378
+ batch_size = (5 ,),
379
+ ),
380
+ b = "a string!" ,
381
+ batch_size = (5 ,),
382
+ )
383
+ spec = make_composite_from_td (td )
384
+ assert isinstance (spec , Composite )
385
+ assert isinstance (spec ["a" ], Composite )
386
+ assert isinstance (spec ["b" ], NonTensor )
387
+ assert spec ["b" ].example_data == "a string!" , spec ["b" ].example_data
388
+ assert spec ["a" , "string" ].example_data == "another string!"
389
+ one = spec .one ()
390
+ assert isinstance (one ["a" ], Scratch )
391
+ assert isinstance (one ["b" ], str )
392
+ assert isinstance (one ["a" ].string , str )
393
+ assert isinstance (one ["a" ].some_object , Whatever )
394
+ assert (one == td ).all ()
360
395
361
396
def test_env_that_does_nothing (self ):
362
397
env = EnvThatDoesNothing ()
0 commit comments