
Commit 115283f

[Feature] batch_size, reward, done, attention_key in LLMEnv
ghstack-source-id: b6657fc202e42b25c76b19602e71e1aebd196abf
Pull Request resolved: #2824

1 parent: 7dcd859

File tree

11 files changed (+242, -37 lines)


test/test_env.py

+76
@@ -4833,6 +4833,82 @@ def policy(td):
             r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
         ).any()

+    @pytest.mark.parametrize(
+        "str2str,stack_method",
+        [
+            [True, None],
+            [False, "as_padded_tensor"],
+        ],
+    )
+    @pytest.mark.parametrize("batched", [True])
+    @pytest.mark.parametrize("device", [None])
+    @pytest.mark.parametrize("batch_size", [4])
+    @pytest.mark.parametrize("repeats", [3])
+    @pytest.mark.parametrize(
+        "assign_reward,assign_done", [[True, False], [True, True], [False, True]]
+    )
+    def test_done_and_reward(
+        self,
+        str2str,
+        batched,
+        stack_method,
+        device,
+        batch_size,
+        repeats,
+        assign_reward,
+        assign_done,
+    ):
+        with pytest.raises(
+            ValueError, match="str2str"
+        ) if str2str else contextlib.nullcontext():
+            if str2str:
+                kwargs = {
+                    "dataloader": self.DummyDataLoader(batch_size=batch_size),
+                    "data_keys": ["observation"],
+                    "example_data": "a string!",
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            else:
+                if stack_method is None:
+                    stack_method = as_padded_tensor
+                kwargs = {
+                    "dataloader": self.DummyTensorDataLoader(
+                        padding=True, batch_size=batch_size
+                    ),
+                    "data_keys": ["observation"],
+                    "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
+                    "stack_method": stack_method,
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            kwargs.update({"str2str": str2str, "device": device})
+            env = LLMEnv.from_dataloader(**kwargs)
+            # We want to make sure that transforms that rely on the done state work appropriately
+            env.append_transform(StepCounter(max_steps=10))
+
+            def policy(td):
+                td["action"] = torch.ones(
+                    td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
+                )
+                return td
+
+            if batched:
+                r = env.rollout(
+                    100,
+                    policy,
+                    tensordict=TensorDict(batch_size=[3]),
+                    break_when_any_done=False,
+                )
+            else:
+                r = env.rollout(100, policy, break_when_any_done=False)
+            if assign_done:
+                assert "terminated" in r
+                assert "done" in r
+            print(r)
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
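For orientation, a hedged usage sketch of the new LLMEnv arguments this test exercises (repeats, assign_reward, assign_done). It is assembled from the test itself, not from library docs: my_dataloader is a hypothetical stand-in for the test's DummyTensorDataLoader, and the LLMEnv import path is assumed and may differ across TorchRL versions.

import torch
from tensordict import TensorDict
from torchrl.envs import LLMEnv  # import path assumed

# `my_dataloader` is hypothetical: any iterable yielding
# {"observation": <token tensor>} batches should fit here.
env = LLMEnv.from_dataloader(
    dataloader=my_dataloader,
    data_keys=["observation"],
    str2str=False,       # token-tensor mode; per the test, str2str=True raises a ValueError
    repeats=3,           # each dataloader sample is served three times
    assign_reward=True,  # the env writes a reward entry itself
    assign_done=True,    # the env writes done/terminated entries itself
)

def policy(td):
    # Emit a random-length run of ones as the "generated tokens", as in the test.
    n = torch.randint(10, (1,)).item()
    td["action"] = torch.ones(td.shape + (n,), dtype=torch.int64)
    return td

r = env.rollout(100, policy, tensordict=TensorDict(batch_size=[3]), break_when_any_done=False)
assert "done" in r and "terminated" in r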

torchrl/data/map/tdstorage.py

-1
@@ -128,7 +128,6 @@ def __init__(
         self.in_keys = query_module.in_keys
         if out_keys is not None:
             self.out_keys = out_keys
-            assert not self._has_lazy_out_keys()

         self.query_module = query_module
         self.index_key = query_module.index_key

torchrl/data/postprocs/postprocs.py

+3 -2
@@ -11,7 +11,6 @@
 from tensordict.utils import expand_right
 from torch import nn

-from torchrl.objectives.value.functional import reward2go


 def _get_reward(
@@ -367,13 +366,15 @@ def __init__(
         time_dim: int = 2,
         discount: float = 1.0,
     ):
+        from torchrl.objectives.value.functional import reward2go
         super().__init__()
         self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
         if reward_key_out is None:
             reward_key_out = reward_key
         self.out_keys = [unravel_key(reward_key_out)]
         self.time_dim = time_dim
         self.discount = discount
+        self.reward2go = reward2go

     def forward(self, tensordict):
         # Get done
@@ -385,6 +386,6 @@ def forward(self, tensordict):
             f"reward and done state are expected to have the same shape. Got reard.shape={reward.shape} "
             f"and done.shape={done.shape}."
         )
-        reward = reward2go(reward, done, time_dim=-2, gamma=self.discount)
+        reward = self.reward2go(reward, done, time_dim=-2, gamma=self.discount)
         tensordict.set(("next", self.out_keys[0]), reward)
         return tensordict
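The two additions to __init__ implement a deferred import: reward2go is no longer imported at module level but resolved when the object is built and bound to the instance, presumably to break an import cycle between torchrl.data.postprocs and torchrl.objectives. A minimal sketch of the pattern, with a hypothetical class name:

from torch import nn


class MyPostProc(nn.Module):
    def __init__(self, discount: float = 1.0):
        # Deferred import: resolved at instantiation time, so merely importing
        # this module no longer pulls in torchrl.objectives.
        from torchrl.objectives.value.functional import reward2go

        super().__init__()
        self.reward2go = reward2go
        self.discount = discount

    def forward(self, reward, done):
        # Discounted reward-to-go along the time dimension, mirroring the diff above.
        return self.reward2go(reward, done, time_dim=-2, gamma=self.discount)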

torchrl/envs/common.py

+6 -2
@@ -2788,7 +2788,11 @@ def _reset_check_done(self, tensordict, tensordict_reset):
         if reset_value is not None:
             for done_key in done_key_group:
                 done_val = tensordict_reset.get(done_key)
-                if done_val[reset_value].any() and not self._allow_done_after_reset:
+                if (
+                    done_val.any()
+                    and done_val[reset_value].any()
+                    and not self._allow_done_after_reset
+                ):
                     raise RuntimeError(
                         f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
                     )
@@ -3588,7 +3592,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         """
         any_done = self.any_done(tensordict)
         if any_done:
-            return self.reset(tensordict, select_reset_only=True)
+            tensordict = self.reset(tensordict, select_reset_only=True)
         return tensordict

     def empty_cache(self):
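For context, a hedged sketch of where maybe_reset sits in a stepping loop; the environment choice and the step_mdp call are illustrative assumptions, not part of the commit. After the change, the reset result is bound to tensordict and flows through the single trailing return, so the method hands back a tensordict whether or not any done flag fired.

from torchrl.envs import GymEnv
from torchrl.envs.utils import step_mdp

env = GymEnv("CartPole-v1")  # illustrative; any EnvBase should work
td = env.reset()
for _ in range(100):
    td["action"] = env.action_spec.rand()  # random action drawn from the spec
    td = env.step(td)
    # Reset only the entries flagged done; pass the tensordict through otherwise.
    td = env.maybe_reset(step_mdp(td))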
