Skip to content

Commit eea932c

Browse files
committed
[Feature] transformers policy
ghstack-source-id: 870c221b4ebae132a44944f0be0ee78da540d115
Pull Request resolved: #2825
1 parent 528986b commit eea932c

File tree

12 files changed

+737
-185
lines changed

12 files changed

+737
-185
lines changed

test/test_actors.py

+51 -2
Original file line number | Diff line number | Diff line change
@@ -8,14 +8,14 @@
88
import pytest
99
import torch
1010

11-
from tensordict import TensorDict
11+
from tensordict import NonTensorStack, TensorDict
1212
from tensordict.nn import CompositeDistribution, TensorDictModule
1313
from tensordict.nn.distributions import NormalParamExtractor
1414

1515
from torch import distributions as dist, nn
1616
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
1717
from torchrl.data.llm.dataset import _has_transformers
18-
from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
18+
from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
1919
from torchrl.modules.tensordict_module.actors import (
2020
_process_action_space_spec,
2121
ActorValueOperator,
@@ -907,6 +907,55 @@ def test_lmhead_actorvalueoperator(device):
907907
) == len(policy_params)
908908

909909

910+
@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
911+
class TestTransformerActor:
912+
@pytest.mark.parametrize(
913+
"from_text, generate, tokens, attention_mask",
914+
[
915+
(True, True, None, None),
916+
(True, False, None, None),
917+
(
918+
False,
919+
True,
920+
torch.randint(1024, (1, 10)),
921+
torch.ones(1, 10, dtype=torch.int64),
922+
),
923+
(False, True, torch.randint(1024, (1, 10)), None),
924+
],
925+
)
926+
def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
927+
from torchrl.data.llm import LLMData
928+
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
929+
930+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
931+
tokenizer.pad_token = tokenizer.eos_token
932+
model = GPT2LMHeadModel(GPT2Config())
933+
tokenizer.padding_side = "left"
934+
m = from_hf_transformers(
935+
model, tokenizer=tokenizer, from_text=from_text, generate=generate
936+
)
937+
if from_text:
938+
tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
939+
else:
940+
tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
941+
td = m(tdin)
942+
assert td is tdin
943+
assert isinstance(td, LLMData)
944+
if from_text and generate:
945+
assert td.text_response is not None
946+
else:
947+
assert td.text_response is None
948+
if attention_mask is not None or from_text:
949+
assert td.attention_mask is not None
950+
else:
951+
assert td.attention_mask is None
952+
if not generate:
953+
assert td.text_response is None
954+
assert td.tokens_response is None
955+
assert td.log_probs is not None
956+
assert td.logits is not None
957+
958+
910959
if __name__ == "__main__":
911960
args, unknown = argparse.ArgumentParser().parse_known_args()
912961
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_env.py

+113 -46
Original file line number | Diff line number | Diff line change
@@ -4644,11 +4644,13 @@ def __next__(self):
46444644
@pytest.mark.parametrize("batch_size", [0, 4])
46454645
@pytest.mark.parametrize("device", [None, "cpu"])
46464646
def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
4647-
env = LLMEnv(str2str=str2str, device=device)
4647+
env = LLMEnv(
4648+
str2str=str2str, device=device, has_attention=False, no_stack=False
4649+
)
46484650
if str2str:
46494651
primer = DataLoadingPrimer(
46504652
dataloader=self.DummyDataLoader(batch_size=batch_size),
4651-
data_keys=["observation"],
4653+
data_keys=[LLMEnv._DEFAULT_STR_KEY],
46524654
example_data="a string!",
46534655
)
46544656
else:
@@ -4658,7 +4660,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46584660
dataloader=self.DummyTensorDataLoader(
46594661
batch_size=batch_size, padding=True
46604662
),
4661-
data_keys=["observation"],
4663+
data_keys=[LLMEnv._DEFAULT_TOKEN_KEY],
46624664
data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
46634665
stack_method=stack_method,
46644666
)
@@ -4668,7 +4670,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46684670
if batched:
46694671
td = env.reset(TensorDict(batch_size=[3]))
46704672
env.check_env_specs(break_when_any_done="both", tensordict=td)
4671-
r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
4673+
env.rollout(10, tensordict=TensorDict(batch_size=[3]))
46724674
else:
46734675
env.check_env_specs(break_when_any_done="both")
46744676

@@ -4691,7 +4693,7 @@ def test_llm_from_dataloader(
46914693
if str2str:
46924694
kwargs = {
46934695
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4694-
"data_keys": ["observation"],
4696+
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
46954697
"example_data": "a string!",
46964698
}
46974699
else:
@@ -4701,11 +4703,18 @@ def test_llm_from_dataloader(
47014703
"dataloader": self.DummyTensorDataLoader(
47024704
padding=True, batch_size=batch_size
47034705
),
4704-
"data_keys": ["observation"],
4706+
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
47054707
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
47064708
"stack_method": stack_method,
47074709
}
4708-
kwargs.update({"str2str": str2str, "device": device})
4710+
kwargs.update(
4711+
{
4712+
"str2str": str2str,
4713+
"device": device,
4714+
"has_attention": False,
4715+
"no_stack": False,
4716+
}
4717+
)
47094718
env = LLMEnv.from_dataloader(**kwargs)
47104719
assert not env.batch_locked
47114720
if batched:
@@ -4718,46 +4727,64 @@ def test_llm_from_dataloader(
47184727
def policy(td):
47194728
if str2str:
47204729
if not td.shape:
4721-
td["action"] = "<nothing>"
4730+
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
47224731
else:
4723-
td["action"] = NonTensorStack(
4732+
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
47244733
*["<nothing>" for _ in range(td.shape[0])]
47254734
)
47264735
else:
4727-
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
4736+
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
4737+
td.shape + (1,), dtype=torch.int64
4738+
)
47284739
return td
47294740

47304741
if batched:
47314742
# Tell the env that we want 3 sub-envs
47324743
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
47334744
assert r.ndim == 2
47344745
if str2str:
4735-
assert isinstance(r[0, 0]["observation"], str)
4736-
assert isinstance(r[0, 1]["observation"], str)
4746+
assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
4747+
assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
47374748
assert (
4738-
r[0, 0]["observation"]
4739-
== r[0, 1]["observation"][: -len(r[0, 0]["action"])]
4749+
r[0, 0][LLMEnv._DEFAULT_STR_KEY]
4750+
== r[0, 1][LLMEnv._DEFAULT_STR_KEY][
4751+
: -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
4752+
]
47404753
)
47414754
assert (
4742-
r[0, 1]["observation"]
4743-
== r[0, 2]["observation"][: -len(r[0, 1]["action"])]
4755+
r[0, 1][LLMEnv._DEFAULT_STR_KEY]
4756+
== r[0, 2][LLMEnv._DEFAULT_STR_KEY][
4757+
: -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
4758+
]
47444759
)
47454760
assert (
4746-
r[-1, 0]["observation"]
4747-
== r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
4761+
r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
4762+
== r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
4763+
: -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
4764+
]
47484765
)
47494766
assert (
4750-
r[-1, 1]["observation"]
4751-
== r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
4767+
r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
4768+
== r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
4769+
: -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
4770+
]
47524771
)
47534772
else:
4754-
assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
4755-
assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
47564773
assert (
4757-
r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
4774+
r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4775+
== r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
4776+
).all()
4777+
assert (
4778+
r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
4779+
== r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
47584780
).all()
47594781
assert (
4760-
r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
4782+
r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4783+
== r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
4784+
).all()
4785+
assert (
4786+
r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY]
4787+
== r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
47614788
).all()
47624789
else:
47634790
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
@@ -4783,7 +4810,7 @@ def test_llm_from_dataloader_repeats(
47834810
if str2str:
47844811
kwargs = {
47854812
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4786-
"data_keys": ["observation"],
4813+
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
47874814
"example_data": "a string!",
47884815
"repeats": repeats,
47894816
}
@@ -4794,12 +4821,19 @@ def test_llm_from_dataloader_repeats(
47944821
"dataloader": self.DummyTensorDataLoader(
47954822
padding=True, batch_size=batch_size
47964823
),
4797-
"data_keys": ["observation"],
4824+
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
47984825
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
47994826
"stack_method": stack_method,
48004827
"repeats": repeats,
48014828
}
4802-
kwargs.update({"str2str": str2str, "device": device})
4829+
kwargs.update(
4830+
{
4831+
"str2str": str2str,
4832+
"device": device,
4833+
"has_attention": False,
4834+
"no_stack": False,
4835+
}
4836+
)
48034837
env = LLMEnv.from_dataloader(**kwargs)
48044838
assert env.transform.repeats == repeats
48054839

@@ -4809,13 +4843,15 @@ def test_llm_from_dataloader_repeats(
48094843
def policy(td):
48104844
if str2str:
48114845
if not td.shape:
4812-
td["action"] = "<nothing>"
4846+
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
48134847
else:
4814-
td["action"] = NonTensorStack(
4848+
td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
48154849
*["<nothing>" for _ in range(td.shape[0])]
48164850
)
48174851
else:
4818-
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
4852+
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
4853+
td.shape + (1,), dtype=torch.int64
4854+
)
48194855
return td
48204856

48214857
if batched:
@@ -4831,34 +4867,58 @@ def policy(td):
48314867
r_reset = r[..., ::max_steps]
48324868
if not batched:
48334869
if str2str:
4834-
assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4835-
assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4836-
assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4870+
assert (
4871+
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4872+
== r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY]
4873+
)
4874+
assert (
4875+
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4876+
== r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY]
4877+
)
4878+
assert (
4879+
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4880+
!= r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY]
4881+
)
48374882
else:
48384883
assert (
4839-
r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4884+
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4885+
== r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY]
48404886
).all()
48414887
assert (
4842-
r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4888+
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4889+
== r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY]
48434890
).all()
48444891
assert (
4845-
r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4892+
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4893+
!= r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY]
48464894
).any()
48474895
else:
48484896
# When batched, each block contains the 3 reset packs
48494897
if str2str:
4850-
assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4851-
assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4852-
assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4898+
assert (
4899+
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
4900+
== r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY]
4901+
)
4902+
assert (
4903+
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
4904+
== r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY]
4905+
)
4906+
assert (
4907+
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
4908+
!= r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY]
4909+
)
48534910
else:
48544911
assert (
4855-
r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4912+
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4913+
== r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
48564914
).all()
48574915
assert (
4858-
r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4916+
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4917+
== r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY]
48594918
).all()
48604919
assert (
4861-
r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4920+
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4921+
!= r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
48624922
).any()
48634923

48644924
@pytest.mark.parametrize(
@@ -4892,7 +4952,7 @@ def test_done_and_reward(
48924952
if str2str:
48934953
kwargs = {
48944954
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4895-
"data_keys": ["observation"],
4955+
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
48964956
"example_data": "a string!",
48974957
"repeats": repeats,
48984958
"assign_reward": assign_reward,
@@ -4905,20 +4965,27 @@ def test_done_and_reward(
49054965
"dataloader": self.DummyTensorDataLoader(
49064966
padding=True, batch_size=batch_size
49074967
),
4908-
"data_keys": ["observation"],
4968+
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
49094969
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
49104970
"stack_method": stack_method,
49114971
"repeats": repeats,
49124972
"assign_reward": assign_reward,
49134973
"assign_done": assign_done,
49144974
}
4915-
kwargs.update({"str2str": str2str, "device": device})
4975+
kwargs.update(
4976+
{
4977+
"str2str": str2str,
4978+
"device": device,
4979+
"has_attention": False,
4980+
"no_stack": False,
4981+
}
4982+
)
49164983
env = LLMEnv.from_dataloader(**kwargs)
49174984
# We want to make sure that transforms that rely on the done state work appropriately
49184985
env.append_transform(StepCounter(max_steps=10))
49194986

49204987
def policy(td):
4921-
td["action"] = torch.ones(
4988+
td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
49224989
td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
49234990
)
49244991
return td

torchrl/data/llm/utils.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -626,11 +626,12 @@ class LLMData(TensorClass["nocast"]):
626626
627627
"""
628628

629-
tokens: torch.Tensor
629+
tokens: torch.Tensor | None = None
630630
tokens_response: torch.Tensor | None = None
631631
attention_mask: torch.Tensor | None = None
632632
token_list: list[int] | list[list[int]] | None = None
633633
tokens_response_list: list[list[int]] | None = None
634634
logits: torch.Tensor | None = None
635635
log_probs: torch.Tensor | None = None
636636
text: str | list[str] | None = None
637+
text_response: torch.Tensor | None = None

0 commit comments

Comments
 (0)