@@ -4700,6 +4700,104 @@ def policy(td):
4700
4700
r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = []))
4701
4701
assert r .ndim == 1
4702
4702
4703
@pytest.mark.parametrize(
    "str2str,stack_method",
    [
        [True, None],
        [False, "as_padded_tensor"],
        # TODO: a bit experimental, fails with check_env_specs
        # [False, "as_nested_tensor"],
        [False, None],
    ],
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("device", [None, "cpu"])
@pytest.mark.parametrize("batch_size", [0, 4])
@pytest.mark.parametrize("repeats", [3])
def test_llm_from_dataloader_repeats(
    self, str2str, batched, stack_method, device, batch_size, repeats
):
    """Check that ``LLMEnv.from_dataloader(..., repeats=N)`` repeats reset data.

    The env is built either in string mode (``str2str=True``) or in token
    mode, wrapped with a ``StepCounter(max_steps=3)`` and rolled out for 100
    steps with ``break_when_any_done=False``. The observations found at every
    ``max_steps``-th step (the steps the test treats as resets) are then
    compared:

    * the first ``repeats`` reset observations must be equal;
    * the next one must differ from the first.

    All indices are derived from ``repeats`` so the test stays meaningful if
    more values are added to the ``repeats`` parametrization.
    """
    if str2str:
        kwargs = {
            "dataloader": self.DummyDataLoader(batch_size=batch_size),
            "data_keys": ["observation"],
            "example_data": "a string!",
            "repeats": repeats,
        }
    else:
        if stack_method is None:
            # Fall back on the callable (rather than its string alias).
            stack_method = as_padded_tensor
        kwargs = {
            "dataloader": self.DummyTensorDataLoader(
                padding=True, batch_size=batch_size
            ),
            "data_keys": ["observation"],
            "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
            "stack_method": stack_method,
            "repeats": repeats,
        }
    kwargs.update({"str2str": str2str, "device": device})
    env = LLMEnv.from_dataloader(**kwargs)
    assert env.transform.repeats == repeats

    max_steps = 3
    env.append_transform(StepCounter(max_steps=max_steps))

    def policy(td):
        # Dummy policy: a constant string (or stack of strings) in string
        # mode, a tensor of ones with one trailing dim in token mode.
        if str2str:
            if not td.shape:
                td["action"] = "<nothing>"
            else:
                td["action"] = NonTensorStack(
                    *["<nothing>" for _ in range(td.shape[0])]
                )
        else:
            td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
        return td

    if batched:
        # The batch holds exactly one "pack" of ``repeats`` copies, so that
        # a single reset is expected to fill it with the same sample.
        r = env.rollout(
            100,
            policy,
            tensordict=TensorDict(batch_size=[repeats]),
            break_when_any_done=False,
        )
    else:
        r = env.rollout(100, policy, break_when_any_done=False)

    # Keep one step out of ``max_steps`` along the last (time) dim.
    r_reset = r[..., ::max_steps]

    if batched:
        # When batched, each reset block contains the ``repeats`` copies
        # along the batch dim; the next block (time index 1) is new data.
        first = r_reset[0, 0]["observation"]
        repeated = [r_reset[i, 0]["observation"] for i in range(1, repeats)]
        following = r_reset[0, 1]["observation"]
    else:
        # Unbatched: copies show up on consecutive resets along time.
        first = r_reset[..., 0]["observation"]
        repeated = [r_reset[..., i]["observation"] for i in range(1, repeats)]
        following = r_reset[..., repeats]["observation"]

    for obs in repeated:
        if str2str:
            assert first == obs
        else:
            assert (first == obs).all()
    if str2str:
        assert first != following
    else:
        assert (first != following).any()
4800
+
4703
4801
4704
4802
if __name__ == "__main__" :
4705
4803
args , unknown = argparse .ArgumentParser ().parse_known_args ()
0 commit comments