Skip to content

Commit 98fdba9

Browse files
committed
[Feature] transformers policy
ghstack-source-id: c2349a8ddff03e4fcd55c07d3dd1a64592dad186 Pull Request resolved: #2825
1 parent 115283f commit 98fdba9

File tree

11 files changed

+507
-142
lines changed

11 files changed

+507
-142
lines changed

Diff for: test/test_env.py

+113-46
Original file line numberDiff line numberDiff line change
@@ -4616,11 +4616,13 @@ def __next__(self):
46164616
@pytest.mark.parametrize("batch_size", [0, 4])
46174617
@pytest.mark.parametrize("device", [None, "cpu"])
46184618
def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
4619-
env = LLMEnv(str2str=str2str, device=device)
4619+
env = LLMEnv(
4620+
str2str=str2str, device=device, has_attention=False, no_stack=False
4621+
)
46204622
if str2str:
46214623
primer = DataLoadingPrimer(
46224624
dataloader=self.DummyDataLoader(batch_size=batch_size),
4623-
data_keys=["observation"],
4625+
data_keys=[LLMEnv._DEFAULT_STR_KEY],
46244626
example_data="a string!",
46254627
)
46264628
else:
@@ -4630,7 +4632,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46304632
dataloader=self.DummyTensorDataLoader(
46314633
batch_size=batch_size, padding=True
46324634
),
4633-
data_keys=["observation"],
4635+
data_keys=[LLMEnv._DEFAULT_TOKEN_KEY],
46344636
data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
46354637
stack_method=stack_method,
46364638
)
@@ -4640,7 +4642,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
46404642
if batched:
46414643
td = env.reset(TensorDict(batch_size=[3]))
46424644
env.check_env_specs(break_when_any_done="both", tensordict=td)
4643-
r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
4645+
env.rollout(10, tensordict=TensorDict(batch_size=[3]))
46444646
else:
46454647
env.check_env_specs(break_when_any_done="both")
46464648

@@ -4663,7 +4665,7 @@ def test_llm_from_dataloader(
46634665
if str2str:
46644666
kwargs = {
46654667
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4666-
"data_keys": ["observation"],
4668+
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
46674669
"example_data": "a string!",
46684670
}
46694671
else:
@@ -4673,11 +4675,18 @@ def test_llm_from_dataloader(
46734675
"dataloader": self.DummyTensorDataLoader(
46744676
padding=True, batch_size=batch_size
46754677
),
4676-
"data_keys": ["observation"],
4678+
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
46774679
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
46784680
"stack_method": stack_method,
46794681
}
4680-
kwargs.update({"str2str": str2str, "device": device})
4682+
kwargs.update(
4683+
{
4684+
"str2str": str2str,
4685+
"device": device,
4686+
"has_attention": False,
4687+
"no_stack": False,
4688+
}
4689+
)
46814690
env = LLMEnv.from_dataloader(**kwargs)
46824691
assert not env.batch_locked
46834692
if batched:
@@ -4690,46 +4699,64 @@ def test_llm_from_dataloader(
46904699
def policy(td):
46914700
if str2str:
46924701
if not td.shape:
4693-
td["action"] = "<nothing>"
4702+
td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
46944703
else:
4695-
td["action"] = NonTensorStack(
4704+
td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
46964705
*["<nothing>" for _ in range(td.shape[0])]
46974706
)
46984707
else:
4699-
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
4708+
td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
4709+
td.shape + (1,), dtype=torch.int64
4710+
)
47004711
return td
47014712

47024713
if batched:
47034714
# Tell the env that we want 3 sub-envs
47044715
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
47054716
assert r.ndim == 2
47064717
if str2str:
4707-
assert isinstance(r[0, 0]["observation"], str)
4708-
assert isinstance(r[0, 1]["observation"], str)
4718+
assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
4719+
assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
47094720
assert (
4710-
r[0, 0]["observation"]
4711-
== r[0, 1]["observation"][: -len(r[0, 0]["action"])]
4721+
r[0, 0][LLMEnv._DEFAULT_STR_KEY]
4722+
== r[0, 1][LLMEnv._DEFAULT_STR_KEY][
4723+
: -len(r[0, 0][LLMEnv._DEFAULT_ACTION_KEY])
4724+
]
47124725
)
47134726
assert (
4714-
r[0, 1]["observation"]
4715-
== r[0, 2]["observation"][: -len(r[0, 1]["action"])]
4727+
r[0, 1][LLMEnv._DEFAULT_STR_KEY]
4728+
== r[0, 2][LLMEnv._DEFAULT_STR_KEY][
4729+
: -len(r[0, 1][LLMEnv._DEFAULT_ACTION_KEY])
4730+
]
47164731
)
47174732
assert (
4718-
r[-1, 0]["observation"]
4719-
== r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
4733+
r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
4734+
== r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
4735+
: -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_KEY])
4736+
]
47204737
)
47214738
assert (
4722-
r[-1, 1]["observation"]
4723-
== r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
4739+
r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
4740+
== r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
4741+
: -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_KEY])
4742+
]
47244743
)
47254744
else:
4726-
assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
4727-
assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
47284745
assert (
4729-
r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
4746+
r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4747+
== r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
4748+
).all()
4749+
assert (
4750+
r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
4751+
== r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
47304752
).all()
47314753
assert (
4732-
r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
4754+
r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4755+
== r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
4756+
).all()
4757+
assert (
4758+
r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY]
4759+
== r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
47334760
).all()
47344761
else:
47354762
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
@@ -4755,7 +4782,7 @@ def test_llm_from_dataloader_repeats(
47554782
if str2str:
47564783
kwargs = {
47574784
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4758-
"data_keys": ["observation"],
4785+
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
47594786
"example_data": "a string!",
47604787
"repeats": repeats,
47614788
}
@@ -4766,12 +4793,19 @@ def test_llm_from_dataloader_repeats(
47664793
"dataloader": self.DummyTensorDataLoader(
47674794
padding=True, batch_size=batch_size
47684795
),
4769-
"data_keys": ["observation"],
4796+
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
47704797
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
47714798
"stack_method": stack_method,
47724799
"repeats": repeats,
47734800
}
4774-
kwargs.update({"str2str": str2str, "device": device})
4801+
kwargs.update(
4802+
{
4803+
"str2str": str2str,
4804+
"device": device,
4805+
"has_attention": False,
4806+
"no_stack": False,
4807+
}
4808+
)
47754809
env = LLMEnv.from_dataloader(**kwargs)
47764810
assert env.transform.repeats == repeats
47774811

@@ -4781,13 +4815,15 @@ def test_llm_from_dataloader_repeats(
47814815
def policy(td):
47824816
if str2str:
47834817
if not td.shape:
4784-
td["action"] = "<nothing>"
4818+
td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
47854819
else:
4786-
td["action"] = NonTensorStack(
4820+
td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
47874821
*["<nothing>" for _ in range(td.shape[0])]
47884822
)
47894823
else:
4790-
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
4824+
td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
4825+
td.shape + (1,), dtype=torch.int64
4826+
)
47914827
return td
47924828

47934829
if batched:
@@ -4803,34 +4839,58 @@ def policy(td):
48034839
r_reset = r[..., ::max_steps]
48044840
if not batched:
48054841
if str2str:
4806-
assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4807-
assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4808-
assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4842+
assert (
4843+
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4844+
== r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY]
4845+
)
4846+
assert (
4847+
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4848+
== r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY]
4849+
)
4850+
assert (
4851+
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4852+
!= r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY]
4853+
)
48094854
else:
48104855
assert (
4811-
r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4856+
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4857+
== r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY]
48124858
).all()
48134859
assert (
4814-
r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4860+
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4861+
== r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY]
48154862
).all()
48164863
assert (
4817-
r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4864+
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4865+
!= r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY]
48184866
).any()
48194867
else:
48204868
# When batched, each block contains the 3 reset packs
48214869
if str2str:
4822-
assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4823-
assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4824-
assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4870+
assert (
4871+
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
4872+
== r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY]
4873+
)
4874+
assert (
4875+
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
4876+
== r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY]
4877+
)
4878+
assert (
4879+
r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
4880+
!= r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY]
4881+
)
48254882
else:
48264883
assert (
4827-
r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4884+
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4885+
== r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
48284886
).all()
48294887
assert (
4830-
r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4888+
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4889+
== r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY]
48314890
).all()
48324891
assert (
4833-
r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4892+
r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
4893+
!= r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
48344894
).any()
48354895

48364896
@pytest.mark.parametrize(
@@ -4864,7 +4924,7 @@ def test_done_and_reward(
48644924
if str2str:
48654925
kwargs = {
48664926
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4867-
"data_keys": ["observation"],
4927+
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
48684928
"example_data": "a string!",
48694929
"repeats": repeats,
48704930
"assign_reward": assign_reward,
@@ -4877,20 +4937,27 @@ def test_done_and_reward(
48774937
"dataloader": self.DummyTensorDataLoader(
48784938
padding=True, batch_size=batch_size
48794939
),
4880-
"data_keys": ["observation"],
4940+
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
48814941
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
48824942
"stack_method": stack_method,
48834943
"repeats": repeats,
48844944
"assign_reward": assign_reward,
48854945
"assign_done": assign_done,
48864946
}
4887-
kwargs.update({"str2str": str2str, "device": device})
4947+
kwargs.update(
4948+
{
4949+
"str2str": str2str,
4950+
"device": device,
4951+
"has_attention": False,
4952+
"no_stack": False,
4953+
}
4954+
)
48884955
env = LLMEnv.from_dataloader(**kwargs)
48894956
# We want to make sure that transforms that rely on the done state work appropriately
48904957
env.append_transform(StepCounter(max_steps=10))
48914958

48924959
def policy(td):
4893-
td["action"] = torch.ones(
4960+
td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
48944961
td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
48954962
)
48964963
return td

Diff for: torchrl/data/postprocs/postprocs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torch import nn
1313

1414

15-
1615
def _get_reward(
1716
gamma: float,
1817
reward: torch.Tensor,
@@ -367,6 +366,7 @@ def __init__(
367366
discount: float = 1.0,
368367
):
369368
from torchrl.objectives.value.functional import reward2go
369+
370370
super().__init__()
371371
self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
372372
if reward_key_out is None:

Diff for: torchrl/data/replay_buffers/storages.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1536,10 +1536,10 @@ def _collate_id(x):
15361536

15371537

15381538
def _get_default_collate(storage, _is_tensordict=False):
1539-
if isinstance(storage, ListStorage):
1540-
return _stack_anything
1541-
elif isinstance(storage, TensorStorage):
1539+
if isinstance(storage, LazyStackStorage) or isinstance(storage, TensorStorage):
15421540
return _collate_id
1541+
elif isinstance(storage, ListStorage):
1542+
return _stack_anything
15431543
else:
15441544
raise NotImplementedError(
15451545
f"Could not find a default collate_fn for storage {type(storage)}."

0 commit comments

Comments (0)