1616from  collections  import  defaultdict 
1717from  functools  import  partial 
1818from  sys  import  platform 
19- from  typing  import  Optional 
19+ from  typing  import  Any ,  Optional 
2020
2121import  numpy  as  np 
2222import  pytest 
3333    TensorDictBase ,
3434)
3535from  tensordict .nn  import  TensorDictModuleBase 
36- from  tensordict .tensorclass  import  NonTensorStack 
36+ from  tensordict .tensorclass  import  NonTensorStack ,  TensorClass 
3737from  tensordict .utils  import  _unravel_key_to_tuple 
3838from  torch  import  nn 
3939
@@ -340,7 +340,8 @@ def forward(self, values):
340340        )
341341        env .rollout (10 , policy )
342342
343-     def  test_make_spec_from_td (self ):
343+     @pytest .mark .parametrize ("dynamic_shape" , [True , False ]) 
344+     def  test_make_spec_from_td (self , dynamic_shape ):
344345        data  =  TensorDict (
345346            {
346347                "obs" : torch .randn (3 ),
@@ -353,10 +354,44 @@ def test_make_spec_from_td(self):
353354            },
354355            [],
355356        )
356-         spec  =  make_composite_from_td (data )
357+         spec  =  make_composite_from_td (data ,  dynamic_shape = dynamic_shape )
357358        assert  (spec .zero () ==  data .zero_ ()).all ()
358359        for  key , val  in  data .items (True , True ):
359360            assert  val .dtype  is  spec [key ].dtype 
361+         if  dynamic_shape :
362+             assert  all (s .shape [- 1 ] ==  - 1  for  s  in  spec .values (True , True ))
363+ 
364+     def  test_make_spec_from_tc (self ):
365+         class  Scratch (TensorClass ):
366+             obs : torch .Tensor 
367+             string : str 
368+             some_object : Any 
369+ 
370+         class  Whatever :
371+             ...
372+ 
373+         td  =  TensorDict (
374+             a = Scratch (
375+                 obs = torch .ones (5 , 3 ),
376+                 string = "another string!" ,
377+                 some_object = Whatever (),
378+                 batch_size = (5 ,),
379+             ),
380+             b = "a string!" ,
381+             batch_size = (5 ,),
382+         )
383+         spec  =  make_composite_from_td (td )
384+         assert  isinstance (spec , Composite )
385+         assert  isinstance (spec ["a" ], Composite )
386+         assert  isinstance (spec ["b" ], NonTensor )
387+         assert  spec ["b" ].example_data  ==  "a string!" , spec ["b" ].example_data 
388+         assert  spec ["a" , "string" ].example_data  ==  "another string!" 
389+         one  =  spec .one ()
390+         assert  isinstance (one ["a" ], Scratch )
391+         assert  isinstance (one ["b" ], str )
392+         assert  isinstance (one ["a" ].string , str )
393+         assert  isinstance (one ["a" ].some_object , Whatever )
394+         assert  (one  ==  td ).all ()
360395
361396    def  test_env_that_does_nothing (self ):
362397        env  =  EnvThatDoesNothing ()
0 commit comments