@@ -4616,11 +4616,13 @@ def __next__(self):
46164616 @pytest .mark .parametrize ("batch_size" , [0 , 4 ])
46174617 @pytest .mark .parametrize ("device" , [None , "cpu" ])
46184618 def test_llm_env (self , str2str , batched , stack_method , device , batch_size ):
4619- env = LLMEnv (str2str = str2str , device = device )
4619+ env = LLMEnv (
4620+ str2str = str2str , device = device , has_attention = False , no_stack = False
4621+ )
46204622 if str2str :
46214623 primer = DataLoadingPrimer (
46224624 dataloader = self .DummyDataLoader (batch_size = batch_size ),
4623- data_keys = ["observation" ],
4625+ data_keys = [LLMEnv . _DEFAULT_STR_KEY ],
46244626 example_data = "a string!" ,
46254627 )
46264628 else :
@@ -4630,7 +4632,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46304632 dataloader = self .DummyTensorDataLoader (
46314633 batch_size = batch_size , padding = True
46324634 ),
4633- data_keys = ["observation" ],
4635+ data_keys = [LLMEnv . _DEFAULT_TOKEN_KEY ],
46344636 data_specs = [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
46354637 stack_method = stack_method ,
46364638 )
@@ -4640,7 +4642,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46404642 if batched :
46414643 td = env .reset (TensorDict (batch_size = [3 ]))
46424644 env .check_env_specs (break_when_any_done = "both" , tensordict = td )
4643- r = env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
4645+ env .rollout (10 , tensordict = TensorDict (batch_size = [3 ]))
46444646 else :
46454647 env .check_env_specs (break_when_any_done = "both" )
46464648
@@ -4663,7 +4665,7 @@ def test_llm_from_dataloader(
46634665 if str2str :
46644666 kwargs = {
46654667 "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4666- "data_keys" : ["observation" ],
4668+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
46674669 "example_data" : "a string!" ,
46684670 }
46694671 else :
@@ -4673,11 +4675,18 @@ def test_llm_from_dataloader(
46734675 "dataloader" : self .DummyTensorDataLoader (
46744676 padding = True , batch_size = batch_size
46754677 ),
4676- "data_keys" : ["observation" ],
4678+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
46774679 "data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
46784680 "stack_method" : stack_method ,
46794681 }
4680- kwargs .update ({"str2str" : str2str , "device" : device })
4682+ kwargs .update (
4683+ {
4684+ "str2str" : str2str ,
4685+ "device" : device ,
4686+ "has_attention" : False ,
4687+ "no_stack" : False ,
4688+ }
4689+ )
46814690 env = LLMEnv .from_dataloader (** kwargs )
46824691 assert not env .batch_locked
46834692 if batched :
@@ -4690,46 +4699,64 @@ def test_llm_from_dataloader(
46904699 def policy (td ):
46914700 if str2str :
46924701 if not td .shape :
4693- td ["action" ] = "<nothing>"
4702+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = "<nothing>"
46944703 else :
4695- td ["action" ] = NonTensorStack (
4704+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = NonTensorStack (
46964705 * ["<nothing>" for _ in range (td .shape [0 ])]
46974706 )
46984707 else :
4699- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4708+ td [LLMEnv ._DEFAULT_ACTION_KEY ] = torch .ones (
4709+ td .shape + (1 ,), dtype = torch .int64
4710+ )
47004711 return td
47014712
47024713 if batched :
47034714 # Tell the env that we want 3 sub-envs
47044715 r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = [3 ]))
47054716 assert r .ndim == 2
47064717 if str2str :
4707- assert isinstance (r [0 , 0 ]["observation" ], str )
4708- assert isinstance (r [0 , 1 ]["observation" ], str )
4718+ assert isinstance (r [0 , 0 ][LLMEnv . _DEFAULT_STR_KEY ], str )
4719+ assert isinstance (r [0 , 1 ][LLMEnv . _DEFAULT_STR_KEY ], str )
47094720 assert (
4710- r [0 , 0 ]["observation" ]
4711- == r [0 , 1 ]["observation" ][: - len (r [0 , 0 ]["action" ])]
4721+ r [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4722+ == r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4723+ : - len (r [0 , 0 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4724+ ]
47124725 )
47134726 assert (
4714- r [0 , 1 ]["observation" ]
4715- == r [0 , 2 ]["observation" ][: - len (r [0 , 1 ]["action" ])]
4727+ r [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4728+ == r [0 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4729+ : - len (r [0 , 1 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4730+ ]
47164731 )
47174732 assert (
4718- r [- 1 , 0 ]["observation" ]
4719- == r [- 1 , 1 ]["observation" ][: - len (r [- 1 , 0 ]["action" ])]
4733+ r [- 1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4734+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ][
4735+ : - len (r [- 1 , 0 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4736+ ]
47204737 )
47214738 assert (
4722- r [- 1 , 1 ]["observation" ]
4723- == r [- 1 , 2 ]["observation" ][: - len (r [- 1 , 1 ]["action" ])]
4739+ r [- 1 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4740+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_STR_KEY ][
4741+ : - len (r [- 1 , 1 ][LLMEnv ._DEFAULT_ACTION_KEY ])
4742+ ]
47244743 )
47254744 else :
4726- assert (r [0 , 0 ]["observation" ] == r [0 , 1 ]["observation" ][:- 1 ]).all ()
4727- assert (r [0 , 1 ]["observation" ] == r [0 , 2 ]["observation" ][:- 1 ]).all ()
47284745 assert (
4729- r [- 1 , 0 ]["observation" ] == r [- 1 , 1 ]["observation" ][:- 1 ]
4746+ r [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4747+ == r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4748+ ).all ()
4749+ assert (
4750+ r [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4751+ == r [0 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
47304752 ).all ()
47314753 assert (
4732- r [- 1 , 1 ]["observation" ] == r [- 1 , 2 ]["observation" ][:- 1 ]
4754+ r [- 1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4755+ == r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
4756+ ).all ()
4757+ assert (
4758+ r [- 1 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4759+ == r [- 1 , 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ][:- 1 ]
47334760 ).all ()
47344761 else :
47354762 r = env .rollout (10 , policy , tensordict = TensorDict (batch_size = []))
@@ -4755,7 +4782,7 @@ def test_llm_from_dataloader_repeats(
47554782 if str2str :
47564783 kwargs = {
47574784 "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4758- "data_keys" : ["observation" ],
4785+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
47594786 "example_data" : "a string!" ,
47604787 "repeats" : repeats ,
47614788 }
@@ -4766,12 +4793,19 @@ def test_llm_from_dataloader_repeats(
47664793 "dataloader" : self .DummyTensorDataLoader (
47674794 padding = True , batch_size = batch_size
47684795 ),
4769- "data_keys" : ["observation" ],
4796+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
47704797 "data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
47714798 "stack_method" : stack_method ,
47724799 "repeats" : repeats ,
47734800 }
4774- kwargs .update ({"str2str" : str2str , "device" : device })
4801+ kwargs .update (
4802+ {
4803+ "str2str" : str2str ,
4804+ "device" : device ,
4805+ "has_attention" : False ,
4806+ "no_stack" : False ,
4807+ }
4808+ )
47754809 env = LLMEnv .from_dataloader (** kwargs )
47764810 assert env .transform .repeats == repeats
47774811
@@ -4781,13 +4815,15 @@ def test_llm_from_dataloader_repeats(
47814815 def policy (td ):
47824816 if str2str :
47834817 if not td .shape :
4784- td ["action" ] = "<nothing>"
4818+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = "<nothing>"
47854819 else :
4786- td ["action" ] = NonTensorStack (
4820+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = NonTensorStack (
47874821 * ["<nothing>" for _ in range (td .shape [0 ])]
47884822 )
47894823 else :
4790- td ["action" ] = torch .ones (td .shape + (1 ,), dtype = torch .int64 )
4824+ td [LLMEnv ._DEFAULT_ACTION_KEY ] = torch .ones (
4825+ td .shape + (1 ,), dtype = torch .int64
4826+ )
47914827 return td
47924828
47934829 if batched :
@@ -4803,34 +4839,58 @@ def policy(td):
48034839 r_reset = r [..., ::max_steps ]
48044840 if not batched :
48054841 if str2str :
4806- assert r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4807- assert r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4808- assert r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4842+ assert (
4843+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4844+ == r_reset [..., 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4845+ )
4846+ assert (
4847+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4848+ == r_reset [..., 2 ][LLMEnv ._DEFAULT_STR_KEY ]
4849+ )
4850+ assert (
4851+ r_reset [..., 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4852+ != r_reset [..., 3 ][LLMEnv ._DEFAULT_STR_KEY ]
4853+ )
48094854 else :
48104855 assert (
4811- r_reset [..., 0 ]["observation" ] == r_reset [..., 1 ]["observation" ]
4856+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4857+ == r_reset [..., 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48124858 ).all ()
48134859 assert (
4814- r_reset [..., 0 ]["observation" ] == r_reset [..., 2 ]["observation" ]
4860+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4861+ == r_reset [..., 2 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48154862 ).all ()
48164863 assert (
4817- r_reset [..., 0 ]["observation" ] != r_reset [..., 3 ]["observation" ]
4864+ r_reset [..., 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4865+ != r_reset [..., 3 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48184866 ).any ()
48194867 else :
48204868 # When batched, each block contains the 3 reset packs
48214869 if str2str :
4822- assert r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4823- assert r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4824- assert r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4870+ assert (
4871+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4872+ == r_reset [1 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4873+ )
4874+ assert (
4875+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4876+ == r_reset [2 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4877+ )
4878+ assert (
4879+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_STR_KEY ]
4880+ != r_reset [0 , 1 ][LLMEnv ._DEFAULT_STR_KEY ]
4881+ )
48254882 else :
48264883 assert (
4827- r_reset [0 , 0 ]["observation" ] == r_reset [1 , 0 ]["observation" ]
4884+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4885+ == r_reset [1 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48284886 ).all ()
48294887 assert (
4830- r_reset [0 , 0 ]["observation" ] == r_reset [2 , 0 ]["observation" ]
4888+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4889+ == r_reset [2 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48314890 ).all ()
48324891 assert (
4833- r_reset [0 , 0 ]["observation" ] != r_reset [0 , 1 ]["observation" ]
4892+ r_reset [0 , 0 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
4893+ != r_reset [0 , 1 ][LLMEnv ._DEFAULT_TOKEN_KEY ]
48344894 ).any ()
48354895
48364896 @pytest .mark .parametrize (
@@ -4864,7 +4924,7 @@ def test_done_and_reward(
48644924 if str2str :
48654925 kwargs = {
48664926 "dataloader" : self .DummyDataLoader (batch_size = batch_size ),
4867- "data_keys" : ["observation" ],
4927+ "data_keys" : [LLMEnv . _DEFAULT_STR_KEY ],
48684928 "example_data" : "a string!" ,
48694929 "repeats" : repeats ,
48704930 "assign_reward" : assign_reward ,
@@ -4877,20 +4937,27 @@ def test_done_and_reward(
48774937 "dataloader" : self .DummyTensorDataLoader (
48784938 padding = True , batch_size = batch_size
48794939 ),
4880- "data_keys" : ["observation" ],
4940+ "data_keys" : [LLMEnv . _DEFAULT_TOKEN_KEY ],
48814941 "data_specs" : [Unbounded (shape = (- 1 ,), dtype = torch .int64 )],
48824942 "stack_method" : stack_method ,
48834943 "repeats" : repeats ,
48844944 "assign_reward" : assign_reward ,
48854945 "assign_done" : assign_done ,
48864946 }
4887- kwargs .update ({"str2str" : str2str , "device" : device })
4947+ kwargs .update (
4948+ {
4949+ "str2str" : str2str ,
4950+ "device" : device ,
4951+ "has_attention" : False ,
4952+ "no_stack" : False ,
4953+ }
4954+ )
48884955 env = LLMEnv .from_dataloader (** kwargs )
48894956 # We want to make sure that transforms that rely on the done state work appropriately
48904957 env .append_transform (StepCounter (max_steps = 10 ))
48914958
48924959 def policy (td ):
4893- td ["action" ] = torch .ones (
4960+ td [LLMEnv . _DEFAULT_ACTION_KEY ] = torch .ones (
48944961 td .shape + (torch .randint (10 , (1 ,)).item (),), dtype = torch .int64
48954962 )
48964963 return td
0 commit comments