@@ -4616,11 +4616,13 @@ def __next__(self):
4616
4616
@pytest .mark .parametrize ("batch_size" , [0 , 4 ])
4617
4617
@pytest .mark .parametrize ("device" , [None , "cpu" ])
4618
4618
def test_llm_env (self , str2str , batched , stack_method , device , batch_size ):
4619
- env = LLMEnv (str2str = str2str , device = device )
4619
+ env = LLMEnv (
4620
+ str2str = str2str , device = device , has_attention = False , no_stack = False
4621
+ )
4620
4622
if str2str :
4621
4623
primer = DataLoadingPrimer (
4622
4624
dataloader = self .DummyDataLoader (batch_size = batch_size ),
4623
- data_keys = ["observation" ],
4625
+ data_keys = [LLMEnv . _DEFAULT_STR_KEY ],
4624
4626
example_data = "a string!" ,
4625
4627
)
4626
4628
else :
@@ -4630,7 +4632,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
4630
4632
dataloader = self .DummyTensorDataLoader (
4631
4633
batch_size = batch_size , padding = True
4632
4634
),
4633
- data_keys = ["observation" ],
4635
+ data_keys = [LLMEnv . _DEFAULT_TOKEN_KEY ],
4634
4636
data_specs = [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4635
4637
stack_method = stack_method ,
4636
4638
)
@@ -4640,7 +4642,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
4640
4642
if batched :
4641
4643
td = env .reset (TensorDict (batch_size = [3 ]))
4642
4644
env .check_env_specs (break_when_any_done = "both" , tensordict = td )
4643
- r = env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
4645
+ env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
4644
4646
else :
4645
4647
env .check_env_specs (break_when_any_done = "both" )
4646
4648
@@ -4663,7 +4665,7 @@ def test_llm_from_dataloader(
4663
4665
if str2str :
4664
4666
kwargs = {
4665
4667
"dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4666
- "data_keys" : ["observation" ],
4668
+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
4667
4669
"example_data" : "a string!" ,
4668
4670
}
4669
4671
else :
@@ -4673,11 +4675,18 @@ def test_llm_from_dataloader(
4673
4675
"dataloader" : self .DummyTensorDataLoader (
4674
4676
padding = True , batch_size = batch_size
4675
4677
),
4676
- "data_keys" : ["observation" ],
4678
+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
4677
4679
"data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4678
4680
"stack_method" : stack_method ,
4679
4681
}
4680
- kwargs .update ({"str2str" : str2str , "device" : device })
4682
+ kwargs .update (
4683
+ {
4684
+ "str2str" : str2str ,
4685
+ "device" : device ,
4686
+ "has_attention" : False ,
4687
+ "no_stack" : False ,
4688
+ }
4689
+ )
4681
4690
env = LLMEnv .from_dataloader (** kwargs )
4682
4691
assert not env .batch_locked
4683
4692
if batched :
@@ -4690,46 +4699,64 @@ def test_llm_from_dataloader(
4690
4699
def policy (td ):
4691
4700
if str2str :
4692
4701
if not td .shape :
4693
- td ["action" ] = "<nothing>"
4702
+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = "<nothing>"
4694
4703
else :
4695
- td ["action" ] = NonTensorStack (
4704
+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = NonTensorStack (
4696
4705
* ["<nothing>" for _ in range (td .shape [0 ])]
4697
4706
)
4698
4707
else :
4699
- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4708
+ td [LLMEnv ._DEFAULT_ACTION_KEY ] = torch .ones (
4709
+ td .shape + (1 ,), dtype = torch .int64
4710
+ )
4700
4711
return td
4701
4712
4702
4713
if batched :
4703
4714
# Tell the env that we want 3 sub-envs
4704
4715
r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = [3 ]))
4705
4716
assert r .ndim == 2
4706
4717
if str2str :
4707
- assert isinstance (r [0 , 0 ]["observation" ], str )
4708
- assert isinstance (r [0 , 1 ]["observation" ], str )
4718
+ assert isinstance (r [0 , 0 ][LLMEnv . _DEFAULT_STR_KEY ], str )
4719
+ assert isinstance (r [0 , 1 ][LLMEnv . _DEFAULT_STR_KEY ], str )
4709
4720
assert (
4710
- r [0 , 0 ]["observation" ]
4711
- == r [0 , 1 ]["observation" ][: - len (r [0 , 0 ]["action" ])]
4721
+ r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4722
+ == r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4723
+ : - len (r [0 , 0 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4724
+ ]
4712
4725
)
4713
4726
assert (
4714
- r [0 , 1 ]["observation" ]
4715
- == r [0 , 2 ]["observation" ][: - len (r [0 , 1 ]["action" ])]
4727
+ r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4728
+ == r [0 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4729
+ : - len (r [0 , 1 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4730
+ ]
4716
4731
)
4717
4732
assert (
4718
- r [- 1 , 0 ]["observation" ]
4719
- == r [- 1 , 1 ]["observation" ][: - len (r [- 1 , 0 ]["action" ])]
4733
+ r [- 1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4734
+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4735
+ : - len (r [- 1 , 0 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4736
+ ]
4720
4737
)
4721
4738
assert (
4722
- r [- 1 , 1 ]["observation" ]
4723
- == r [- 1 , 2 ]["observation" ][: - len (r [- 1 , 1 ]["action" ])]
4739
+ r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4740
+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4741
+ : - len (r [- 1 , 1 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4742
+ ]
4724
4743
)
4725
4744
else :
4726
- assert (r [0 , 0 ]["observation" ] == r [0 , 1 ]["observation" ][:- 1 ]).all ()
4727
- assert (r [0 , 1 ]["observation" ] == r [0 , 2 ]["observation" ][:- 1 ]).all ()
4728
4745
assert (
4729
- r [- 1 , 0 ]["observation" ] == r [- 1 , 1 ]["observation" ][:- 1 ]
4746
+ r [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4747
+ == r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4748
+ ).all ()
4749
+ assert (
4750
+ r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4751
+ == r [0 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4730
4752
).all ()
4731
4753
assert (
4732
- r [- 1 , 1 ]["observation" ] == r [- 1 , 2 ]["observation" ][:- 1 ]
4754
+ r [- 1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4755
+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4756
+ ).all ()
4757
+ assert (
4758
+ r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4759
+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4733
4760
).all ()
4734
4761
else :
4735
4762
r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = []))
@@ -4755,7 +4782,7 @@ def test_llm_from_dataloader_repeats(
4755
4782
if str2str :
4756
4783
kwargs = {
4757
4784
"dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4758
- "data_keys" : ["observation" ],
4785
+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
4759
4786
"example_data" : "a string!" ,
4760
4787
"repeats" : repeats ,
4761
4788
}
@@ -4766,12 +4793,19 @@ def test_llm_from_dataloader_repeats(
4766
4793
"dataloader" : self .DummyTensorDataLoader (
4767
4794
padding = True , batch_size = batch_size
4768
4795
),
4769
- "data_keys" : ["observation" ],
4796
+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
4770
4797
"data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4771
4798
"stack_method" : stack_method ,
4772
4799
"repeats" : repeats ,
4773
4800
}
4774
- kwargs .update ({"str2str" : str2str , "device" : device })
4801
+ kwargs .update (
4802
+ {
4803
+ "str2str" : str2str ,
4804
+ "device" : device ,
4805
+ "has_attention" : False ,
4806
+ "no_stack" : False ,
4807
+ }
4808
+ )
4775
4809
env = LLMEnv .from_dataloader (** kwargs )
4776
4810
assert env .transform .repeats == repeats
4777
4811
@@ -4781,13 +4815,15 @@ def test_llm_from_dataloader_repeats(
4781
4815
def policy (td ):
4782
4816
if str2str :
4783
4817
if not td .shape :
4784
- td ["action" ] = "<nothing>"
4818
+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = "<nothing>"
4785
4819
else :
4786
- td ["action" ] = NonTensorStack (
4820
+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = NonTensorStack (
4787
4821
* ["<nothing>" for _ in range (td .shape [0 ])]
4788
4822
)
4789
4823
else :
4790
- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4824
+ td [LLMEnv ._DEFAULT_ACTION_KEY ] = torch .ones (
4825
+ td .shape + (1 ,), dtype = torch .int64
4826
+ )
4791
4827
return td
4792
4828
4793
4829
if batched :
@@ -4803,34 +4839,58 @@ def policy(td):
4803
4839
r_reset = r [..., ::max_steps ]
4804
4840
if not batched :
4805
4841
if str2str :
4806
- assert r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4807
- assert r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4808
- assert r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4842
+ assert (
4843
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4844
+ == r_reset [..., 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4845
+ )
4846
+ assert (
4847
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4848
+ == r_reset [..., 2 ][LLMEnv ._DEFAULT_STR_KEY ]
4849
+ )
4850
+ assert (
4851
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4852
+ != r_reset [..., 3 ][LLMEnv ._DEFAULT_STR_KEY ]
4853
+ )
4809
4854
else :
4810
4855
assert (
4811
- r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4856
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4857
+ == r_reset [..., 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4812
4858
).all ()
4813
4859
assert (
4814
- r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4860
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4861
+ == r_reset [..., 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4815
4862
).all ()
4816
4863
assert (
4817
- r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4864
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4865
+ != r_reset [..., 3 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4818
4866
).any ()
4819
4867
else :
4820
4868
# When batched, each block contains the 3 reset packs
4821
4869
if str2str :
4822
- assert r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4823
- assert r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4824
- assert r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4870
+ assert (
4871
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4872
+ == r_reset [1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4873
+ )
4874
+ assert (
4875
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4876
+ == r_reset [2 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4877
+ )
4878
+ assert (
4879
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4880
+ != r_reset [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4881
+ )
4825
4882
else :
4826
4883
assert (
4827
- r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4884
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4885
+ == r_reset [1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4828
4886
).all ()
4829
4887
assert (
4830
- r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4888
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4889
+ == r_reset [2 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4831
4890
).all ()
4832
4891
assert (
4833
- r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4892
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4893
+ != r_reset [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4834
4894
).any ()
4835
4895
4836
4896
@pytest .mark .parametrize (
@@ -4864,7 +4924,7 @@ def test_done_and_reward(
4864
4924
if str2str :
4865
4925
kwargs = {
4866
4926
"dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4867
- "data_keys" : ["observation" ],
4927
+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
4868
4928
"example_data" : "a string!" ,
4869
4929
"repeats" : repeats ,
4870
4930
"assign_reward" : assign_reward ,
@@ -4877,20 +4937,27 @@ def test_done_and_reward(
4877
4937
"dataloader" : self .DummyTensorDataLoader (
4878
4938
padding = True , batch_size = batch_size
4879
4939
),
4880
- "data_keys" : ["observation" ],
4940
+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
4881
4941
"data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4882
4942
"stack_method" : stack_method ,
4883
4943
"repeats" : repeats ,
4884
4944
"assign_reward" : assign_reward ,
4885
4945
"assign_done" : assign_done ,
4886
4946
}
4887
- kwargs .update ({"str2str" : str2str , "device" : device })
4947
+ kwargs .update (
4948
+ {
4949
+ "str2str" : str2str ,
4950
+ "device" : device ,
4951
+ "has_attention" : False ,
4952
+ "no_stack" : False ,
4953
+ }
4954
+ )
4888
4955
env = LLMEnv .from_dataloader (** kwargs )
4889
4956
# We want to make sure that transforms that rely on the done state work appropriately
4890
4957
env .append_transform (StepCounter (max_steps = 10 ))
4891
4958
4892
4959
def policy (td ):
4893
- td ["action" ] = torch .ones (
4960
+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = torch .ones (
4894
4961
td .shape + (torch .randint (10 , (1 ,)).item (),), dtype = torch .int64
4895
4962
)
4896
4963
return td
0 commit comments