
Commit 8ace7bc

[Feature] transformers policy
ghstack-source-id: f5e1e6b
Pull Request resolved: #2825
1 parent 0c19d4e commit 8ace7bc

16 files changed: +760 −201 lines changed


test/test_actors.py

+51 −2
@@ -8,14 +8,14 @@
 import pytest
 import torch

-from tensordict import TensorDict
+from tensordict import NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor

 from torch import distributions as dist, nn
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
 from torchrl.data.llm.dataset import _has_transformers
-from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
+from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal
 from torchrl.modules.tensordict_module.actors import (
     _process_action_space_spec,
     ActorValueOperator,
@@ -907,6 +907,55 @@ def test_lmhead_actorvalueoperator(device):
     ) == len(policy_params)


+@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
+class TestTransformerActor:
+    @pytest.mark.parametrize(
+        "from_text, generate, tokens, attention_mask",
+        [
+            (True, True, None, None),
+            (True, False, None, None),
+            (
+                False,
+                True,
+                torch.randint(1024, (1, 10)),
+                torch.ones(1, 10, dtype=torch.int64),
+            ),
+            (False, True, torch.randint(1024, (1, 10)), None),
+        ],
+    )
+    def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask):
+        from torchrl.data.llm import LLMData
+        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        tokenizer.pad_token = tokenizer.eos_token
+        model = GPT2LMHeadModel(GPT2Config())
+        tokenizer.padding_side = "left"
+        m = from_hf_transformers(
+            model, tokenizer=tokenizer, from_text=from_text, generate=generate
+        )
+        if from_text:
+            tdin = LLMData(text=NonTensorStack("a text"), batch_size=1)
+        else:
+            tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1)
+        td = m(tdin)
+        assert td is tdin
+        assert isinstance(td, LLMData)
+        if from_text and generate:
+            assert td.text_response is not None
+        else:
+            assert td.text_response is None
+        if attention_mask is not None or from_text:
+            assert td.attention_mask is not None
+        else:
+            assert td.attention_mask is None
+        if not generate:
+            assert td.text_response is None
+            assert td.tokens_response is None
+            assert td.log_probs is not None
+            assert td.logits is not None
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
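
For orientation, the new test above exercises the from_hf_transformers wrapper roughly as in the standalone sketch below. It is paraphrased from the test code itself (a randomly initialised GPT-2, left padding, EOS used as pad token); the argument names mirror the test rather than separately documented API guarantees.

    from tensordict import NonTensorStack
    from torchrl.data.llm import LLMData
    from torchrl.modules import from_hf_transformers
    from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

    # Small, randomly initialised GPT-2 and its tokenizer (left padding, EOS as pad).
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    model = GPT2LMHeadModel(GPT2Config())

    # Wrap the Hugging Face model as a policy that reads and writes LLMData;
    # from_text=True feeds raw text, generate=True asks the wrapper for a response.
    policy = from_hf_transformers(
        model, tokenizer=tokenizer, from_text=True, generate=True
    )

    data = LLMData(text=NonTensorStack("a text"), batch_size=1)
    out = policy(data)  # same LLMData instance, updated in place
    assert out.text_response is not None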

test/test_env.py

+113 −47
@@ -4616,11 +4616,13 @@ def __next__(self):
     @pytest.mark.parametrize("batch_size", [0, 4])
     @pytest.mark.parametrize("device", [None, "cpu"])
     def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
-        env = LLMEnv(str2str=str2str, device=device)
+        env = LLMEnv(
+            str2str=str2str, device=device, has_attention=False, no_stack=False
+        )
         if str2str:
             primer = DataLoadingPrimer(
                 dataloader=self.DummyDataLoader(batch_size=batch_size),
-                data_keys=["observation"],
+                data_keys=[LLMEnv._DEFAULT_STR_KEY],
                 example_data="a string!",
             )
         else:
@@ -4630,7 +4632,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
                 dataloader=self.DummyTensorDataLoader(
                     batch_size=batch_size, padding=True
                 ),
-                data_keys=["observation"],
+                data_keys=[LLMEnv._DEFAULT_TOKEN_KEY],
                 data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
                 stack_method=stack_method,
             )
@@ -4640,7 +4642,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
         if batched:
             td = env.reset(TensorDict(batch_size=[3]))
             env.check_env_specs(break_when_any_done="both", tensordict=td)
-            r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
+            env.rollout(10, tensordict=TensorDict(batch_size=[3]))
         else:
             env.check_env_specs(break_when_any_done="both")

@@ -4663,7 +4665,7 @@ def test_llm_from_dataloader(
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                 "example_data": "a string!",
             }
         else:
@@ -4673,11 +4675,18 @@ def test_llm_from_dataloader(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         assert not env.batch_locked
         if batched:
@@ -4690,46 +4699,64 @@ def test_llm_from_dataloader(
         def policy(td):
             if str2str:
                 if not td.shape:
-                    td["action"] = "<nothing>"
+                    td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
                 else:
-                    td["action"] = NonTensorStack(
+                    td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
                         *["<nothing>" for _ in range(td.shape[0])]
                     )
             else:
-                td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
+                td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
+                    td.shape + (1,), dtype=torch.int64
+                )
             return td

         if batched:
             # Tell the env that we want 3 sub-envs
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
             assert r.ndim == 2
             if str2str:
-                assert isinstance(r[0, 0]["observation"], str)
-                assert isinstance(r[0, 1]["observation"], str)
+                assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
+                assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
                 assert (
-                    r[0, 0]["observation"]
-                    == r[0, 1]["observation"][: -len(r[0, 0]["action"])]
+                    r[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                    ]
                 )
                 assert (
-                    r[0, 1]["observation"]
-                    == r[0, 2]["observation"][: -len(r[0, 1]["action"])]
+                    r[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                    == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                    ]
                 )
                 assert (
-                    r[-1, 0]["observation"]
-                    == r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
+                    r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                    ]
                 )
                 assert (
-                    r[-1, 1]["observation"]
-                    == r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
+                    r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
+                    == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY])
+                    ]
                 )
             else:
-                assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
-                assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
                 assert (
-                    r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
+                    r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
+                ).all()
+                assert (
+                    r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
                 ).all()
                 assert (
-                    r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
+                    r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
+                ).all()
+                assert (
+                    r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
                 ).all()
         else:
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
@@ -4755,7 +4782,7 @@ def test_llm_from_dataloader_repeats(
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                 "example_data": "a string!",
                 "repeats": repeats,
             }
@@ -4766,12 +4793,19 @@ def test_llm_from_dataloader_repeats(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
                 "repeats": repeats,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         assert env.transform.repeats == repeats

@@ -4781,13 +4815,15 @@ def test_llm_from_dataloader_repeats(
         def policy(td):
             if str2str:
                 if not td.shape:
-                    td["action"] = "<nothing>"
+                    td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
                 else:
-                    td["action"] = NonTensorStack(
+                    td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack(
                         *["<nothing>" for _ in range(td.shape[0])]
                     )
             else:
-                td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
+                td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
+                    td.shape + (1,), dtype=torch.int64
+                )
             return td

         if batched:
@@ -4803,34 +4839,58 @@ def policy(td):
         r_reset = r[..., ::max_steps]
         if not batched:
             if str2str:
-                assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
-                assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
-                assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    != r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY]
+                )
             else:
                 assert (
-                    r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    != r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).any()
         else:
             # When batched, each block contains the 3 reset packs
             if str2str:
-                assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
-                assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
-                assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    != r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                )
             else:
                 assert (
-                    r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    != r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).any()

     @pytest.mark.parametrize(
@@ -4864,7 +4924,7 @@ def test_done_and_reward(
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                 "example_data": "a string!",
                 "repeats": repeats,
                 "assign_reward": assign_reward,
@@ -4877,20 +4937,27 @@ def test_done_and_reward(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
                 "repeats": repeats,
                 "assign_reward": assign_reward,
                 "assign_done": assign_done,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         # We want to make sure that transforms that rely on the done state work appropriately
         env.append_transform(StepCounter(max_steps=10))

         def policy(td):
-            td["action"] = torch.ones(
+            td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(
                 td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
             )
             return td
@@ -4907,7 +4974,6 @@ def policy(td):
         if assign_done:
             assert "terminated" in r
             assert "done" in r
-        print(r)


 if __name__ == "__main__":
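
Similarly, the reworked LLMEnv tests above reduce to roughly the sketch below. The dataloader is a placeholder (DummyDataLoader exists only inside the test file), the LLMEnv import path is assumed, and the keyword arguments simply mirror the kwargs dicts built in the tests.

    from tensordict import TensorDict
    from torchrl.envs import LLMEnv  # assumed import path

    # Data now lives under class-level default keys instead of "observation"/"action".
    env = LLMEnv.from_dataloader(
        dataloader=my_prompt_dataloader,  # placeholder: yields prompt strings
        data_keys=[LLMEnv._DEFAULT_STR_KEY],
        example_data="a string!",
        str2str=True,
        device=None,
        has_attention=False,
        no_stack=False,
    )

    def policy(td):
        # Write the response under the new default string-action key.
        td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "<nothing>"
        return td

    r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))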
