@@ -4644,11 +4644,13 @@ def __next__(self):
4644
4644
@pytest .mark .parametrize ("batch_size" , [0 , 4 ])
4645
4645
@pytest .mark .parametrize ("device" , [None , "cpu" ])
4646
4646
def test_llm_env (self , str2str , batched , stack_method , device , batch_size ):
4647
- env = LLMEnv (str2str = str2str , device = device )
4647
+ env = LLMEnv (
4648
+ str2str = str2str , device = device , has_attention = False , no_stack = False
4649
+ )
4648
4650
if str2str :
4649
4651
primer = DataLoadingPrimer (
4650
4652
dataloader = self .DummyDataLoader (batch_size = batch_size ),
4651
- data_keys = ["observation" ],
4653
+ data_keys = [LLMEnv . _DEFAULT_STR_KEY ],
4652
4654
example_data = "a string!" ,
4653
4655
)
4654
4656
else :
@@ -4658,7 +4660,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
4658
4660
dataloader = self .DummyTensorDataLoader (
4659
4661
batch_size = batch_size , padding = True
4660
4662
),
4661
- data_keys = ["observation" ],
4663
+ data_keys = [LLMEnv . _DEFAULT_TOKEN_KEY ],
4662
4664
data_specs = [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4663
4665
stack_method = stack_method ,
4664
4666
)
@@ -4668,7 +4670,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
4668
4670
if batched :
4669
4671
td = env .reset (TensorDict (batch_size = [3 ]))
4670
4672
env .check_env_specs (break_when_any_done = "both" , tensordict = td )
4671
- r = env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
4673
+ env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
4672
4674
else :
4673
4675
env .check_env_specs (break_when_any_done = "both" )
4674
4676
@@ -4691,7 +4693,7 @@ def test_llm_from_dataloader(
4691
4693
if str2str :
4692
4694
kwargs = {
4693
4695
"dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4694
- "data_keys" : ["observation" ],
4696
+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
4695
4697
"example_data" : "a string!" ,
4696
4698
}
4697
4699
else :
@@ -4701,11 +4703,18 @@ def test_llm_from_dataloader(
4701
4703
"dataloader" : self .DummyTensorDataLoader (
4702
4704
padding = True , batch_size = batch_size
4703
4705
),
4704
- "data_keys" : ["observation" ],
4706
+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
4705
4707
"data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4706
4708
"stack_method" : stack_method ,
4707
4709
}
4708
- kwargs .update ({"str2str" : str2str , "device" : device })
4710
+ kwargs .update (
4711
+ {
4712
+ "str2str" : str2str ,
4713
+ "device" : device ,
4714
+ "has_attention" : False ,
4715
+ "no_stack" : False ,
4716
+ }
4717
+ )
4709
4718
env = LLMEnv .from_dataloader (** kwargs )
4710
4719
assert not env .batch_locked
4711
4720
if batched :
@@ -4718,46 +4727,64 @@ def test_llm_from_dataloader(
4718
4727
def policy (td ):
4719
4728
if str2str :
4720
4729
if not td .shape :
4721
- td ["action" ] = "<nothing>"
4730
+ td [LLMEnv . _DEFAULT_ACTION_STR_KEY ] = "<nothing>"
4722
4731
else :
4723
- td ["action" ] = NonTensorStack (
4732
+ td [LLMEnv . _DEFAULT_ACTION_STR_KEY ] = NonTensorStack (
4724
4733
* ["<nothing>" for _ in range (td .shape [0 ])]
4725
4734
)
4726
4735
else :
4727
- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4736
+ td [LLMEnv ._DEFAULT_ACTION_TOKENS_KEY ] = torch .ones (
4737
+ td .shape + (1 ,), dtype = torch .int64
4738
+ )
4728
4739
return td
4729
4740
4730
4741
if batched :
4731
4742
# Tell the env that we want 3 sub-envs
4732
4743
r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = [3 ]))
4733
4744
assert r .ndim == 2
4734
4745
if str2str :
4735
- assert isinstance (r [0 , 0 ]["observation" ], str )
4736
- assert isinstance (r [0 , 1 ]["observation" ], str )
4746
+ assert isinstance (r [0 , 0 ][LLMEnv . _DEFAULT_STR_KEY ], str )
4747
+ assert isinstance (r [0 , 1 ][LLMEnv . _DEFAULT_STR_KEY ], str )
4737
4748
assert (
4738
- r [0 , 0 ]["observation" ]
4739
- == r [0 , 1 ]["observation" ][: - len (r [0 , 0 ]["action" ])]
4749
+ r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4750
+ == r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4751
+ : - len (r [0 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4752
+ ]
4740
4753
)
4741
4754
assert (
4742
- r [0 , 1 ]["observation" ]
4743
- == r [0 , 2 ]["observation" ][: - len (r [0 , 1 ]["action" ])]
4755
+ r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4756
+ == r [0 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4757
+ : - len (r [0 , 1 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4758
+ ]
4744
4759
)
4745
4760
assert (
4746
- r [- 1 , 0 ]["observation" ]
4747
- == r [- 1 , 1 ]["observation" ][: - len (r [- 1 , 0 ]["action" ])]
4761
+ r [- 1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4762
+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4763
+ : - len (r [- 1 , 0 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4764
+ ]
4748
4765
)
4749
4766
assert (
4750
- r [- 1 , 1 ]["observation" ]
4751
- == r [- 1 , 2 ]["observation" ][: - len (r [- 1 , 1 ]["action" ])]
4767
+ r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4768
+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4769
+ : - len (r [- 1 , 1 ][LLMEnv ._DEFAULT_ACTION_STR_KEY ])
4770
+ ]
4752
4771
)
4753
4772
else :
4754
- assert (r [0 , 0 ]["observation" ] == r [0 , 1 ]["observation" ][:- 1 ]).all ()
4755
- assert (r [0 , 1 ]["observation" ] == r [0 , 2 ]["observation" ][:- 1 ]).all ()
4756
4773
assert (
4757
- r [- 1 , 0 ]["observation" ] == r [- 1 , 1 ]["observation" ][:- 1 ]
4774
+ r [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4775
+ == r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4776
+ ).all ()
4777
+ assert (
4778
+ r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4779
+ == r [0 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4758
4780
).all ()
4759
4781
assert (
4760
- r [- 1 , 1 ]["observation" ] == r [- 1 , 2 ]["observation" ][:- 1 ]
4782
+ r [- 1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4783
+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4784
+ ).all ()
4785
+ assert (
4786
+ r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4787
+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4761
4788
).all ()
4762
4789
else :
4763
4790
r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = []))
@@ -4783,7 +4810,7 @@ def test_llm_from_dataloader_repeats(
4783
4810
if str2str :
4784
4811
kwargs = {
4785
4812
"dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4786
- "data_keys" : ["observation" ],
4813
+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
4787
4814
"example_data" : "a string!" ,
4788
4815
"repeats" : repeats ,
4789
4816
}
@@ -4794,12 +4821,19 @@ def test_llm_from_dataloader_repeats(
4794
4821
"dataloader" : self .DummyTensorDataLoader (
4795
4822
padding = True , batch_size = batch_size
4796
4823
),
4797
- "data_keys" : ["observation" ],
4824
+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
4798
4825
"data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4799
4826
"stack_method" : stack_method ,
4800
4827
"repeats" : repeats ,
4801
4828
}
4802
- kwargs .update ({"str2str" : str2str , "device" : device })
4829
+ kwargs .update (
4830
+ {
4831
+ "str2str" : str2str ,
4832
+ "device" : device ,
4833
+ "has_attention" : False ,
4834
+ "no_stack" : False ,
4835
+ }
4836
+ )
4803
4837
env = LLMEnv .from_dataloader (** kwargs )
4804
4838
assert env .transform .repeats == repeats
4805
4839
@@ -4809,13 +4843,15 @@ def test_llm_from_dataloader_repeats(
4809
4843
def policy (td ):
4810
4844
if str2str :
4811
4845
if not td .shape :
4812
- td ["action" ] = "<nothing>"
4846
+ td [LLMEnv . _DEFAULT_ACTION_STR_KEY ] = "<nothing>"
4813
4847
else :
4814
- td ["action" ] = NonTensorStack (
4848
+ td [LLMEnv . _DEFAULT_ACTION_STR_KEY ] = NonTensorStack (
4815
4849
* ["<nothing>" for _ in range (td .shape [0 ])]
4816
4850
)
4817
4851
else :
4818
- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4852
+ td [LLMEnv ._DEFAULT_ACTION_TOKENS_KEY ] = torch .ones (
4853
+ td .shape + (1 ,), dtype = torch .int64
4854
+ )
4819
4855
return td
4820
4856
4821
4857
if batched :
@@ -4831,34 +4867,58 @@ def policy(td):
4831
4867
r_reset = r [..., ::max_steps ]
4832
4868
if not batched :
4833
4869
if str2str :
4834
- assert r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4835
- assert r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4836
- assert r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4870
+ assert (
4871
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4872
+ == r_reset [..., 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4873
+ )
4874
+ assert (
4875
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4876
+ == r_reset [..., 2 ][LLMEnv ._DEFAULT_STR_KEY ]
4877
+ )
4878
+ assert (
4879
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4880
+ != r_reset [..., 3 ][LLMEnv ._DEFAULT_STR_KEY ]
4881
+ )
4837
4882
else :
4838
4883
assert (
4839
- r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4884
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4885
+ == r_reset [..., 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4840
4886
).all ()
4841
4887
assert (
4842
- r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4888
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4889
+ == r_reset [..., 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4843
4890
).all ()
4844
4891
assert (
4845
- r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4892
+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4893
+ != r_reset [..., 3 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4846
4894
).any ()
4847
4895
else :
4848
4896
# When batched, each block contains the 3 reset packs
4849
4897
if str2str :
4850
- assert r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4851
- assert r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4852
- assert r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4898
+ assert (
4899
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4900
+ == r_reset [1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4901
+ )
4902
+ assert (
4903
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4904
+ == r_reset [2 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4905
+ )
4906
+ assert (
4907
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4908
+ != r_reset [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4909
+ )
4853
4910
else :
4854
4911
assert (
4855
- r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4912
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4913
+ == r_reset [1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4856
4914
).all ()
4857
4915
assert (
4858
- r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4916
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4917
+ == r_reset [2 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4859
4918
).all ()
4860
4919
assert (
4861
- r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4920
+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4921
+ != r_reset [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4862
4922
).any ()
4863
4923
4864
4924
@pytest .mark .parametrize (
@@ -4892,7 +4952,7 @@ def test_done_and_reward(
4892
4952
if str2str :
4893
4953
kwargs = {
4894
4954
"dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4895
- "data_keys" : ["observation" ],
4955
+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
4896
4956
"example_data" : "a string!" ,
4897
4957
"repeats" : repeats ,
4898
4958
"assign_reward" : assign_reward ,
@@ -4905,20 +4965,27 @@ def test_done_and_reward(
4905
4965
"dataloader" : self .DummyTensorDataLoader (
4906
4966
padding = True , batch_size = batch_size
4907
4967
),
4908
- "data_keys" : ["observation" ],
4968
+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
4909
4969
"data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
4910
4970
"stack_method" : stack_method ,
4911
4971
"repeats" : repeats ,
4912
4972
"assign_reward" : assign_reward ,
4913
4973
"assign_done" : assign_done ,
4914
4974
}
4915
- kwargs .update ({"str2str" : str2str , "device" : device })
4975
+ kwargs .update (
4976
+ {
4977
+ "str2str" : str2str ,
4978
+ "device" : device ,
4979
+ "has_attention" : False ,
4980
+ "no_stack" : False ,
4981
+ }
4982
+ )
4916
4983
env = LLMEnv .from_dataloader (** kwargs )
4917
4984
# We want to make sure that transforms that rely on the done state work appropriately
4918
4985
env .append_transform (StepCounter (max_steps = 10 ))
4919
4986
4920
4987
def policy (td ):
4921
- td ["action" ] = torch .ones (
4988
+ td [LLMEnv . _DEFAULT_ACTION_TOKENS_KEY ] = torch .ones (
4922
4989
td .shape + (torch .randint (10 , (1 ,)).item (),), dtype = torch .int64
4923
4990
)
4924
4991
return td
0 commit comments