Skip to content

Commit e978ddb

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent f4713f9 commit e978ddb

File tree

9 files changed

+508
-244
lines changed

9 files changed

+508
-244
lines changed

docs/source/reference/data.rst

+6-5
Original file line numberDiff line numberDiff line change
@@ -1144,16 +1144,17 @@ Utils
11441144
:toctree: generated/
11451145
:template: rl_template.rst
11461146

1147-
MultiStep
1148-
consolidate_spec
1149-
check_no_exclusive_keys
1150-
contains_lazy_spec
1151-
Nested2TED
1147+
DensifyReward
11521148
Flat2TED
11531149
H5Combine
11541150
H5Split
1151+
MultiStep
1152+
Nested2TED
11551153
TED2Flat
11561154
TED2Nested
1155+
check_no_exclusive_keys
1156+
consolidate_spec
1157+
contains_lazy_spec
11571158

11581159
.. currentmodule:: torchrl.envs.transforms.rb_transforms
11591160

test/test_env.py

+98
Original file line numberDiff line numberDiff line change
@@ -4700,6 +4700,104 @@ def policy(td):
47004700
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
47014701
assert r.ndim == 1
47024702

@pytest.mark.parametrize(
    "str2str,stack_method",
    [
        [True, None],
        [False, "as_padded_tensor"],
        # TODO: a bit experimental, fails with check_env_specs
        # [False, "as_nested_tensor"],
        [False, None],
    ],
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("device", [None, "cpu"])
@pytest.mark.parametrize("batch_size", [0, 4])
@pytest.mark.parametrize("repeats", [3])
def test_llm_from_dataloader_repeats(
    self, str2str, batched, stack_method, device, batch_size, repeats
):
    """Check that ``LLMEnv.from_dataloader(repeats=...)`` replays reset data.

    The env is rolled out for 100 steps with a ``StepCounter(max_steps=3)``
    appended, and the observation seen at every third step is inspected:
    the first three such observations must be identical and the next one
    must differ.

    NOTE(review): the layout of the batched rollout (first index = batch
    entry, second index = reset number) is taken from the indexing used in
    the assertions -- confirm against ``LLMEnv``.
    """
    # Arguments that only make sense for one of the two data modes.
    if str2str:
        loader = self.DummyDataLoader(batch_size=batch_size)
        mode_kwargs = {"example_data": "a string!"}
    else:
        loader = self.DummyTensorDataLoader(padding=True, batch_size=batch_size)
        mode_kwargs = {
            "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
            # ``None`` falls back to the ``as_padded_tensor`` callable.
            "stack_method": (
                as_padded_tensor if stack_method is None else stack_method
            ),
        }

    env = LLMEnv.from_dataloader(
        dataloader=loader,
        data_keys=["observation"],
        repeats=repeats,
        str2str=str2str,
        device=device,
        **mode_kwargs,
    )
    assert env.transform.repeats == repeats

    max_steps = 3
    env.append_transform(StepCounter(max_steps=max_steps))

    def policy(td):
        # Dummy policy: writes a constant action matching the batch shape.
        if not str2str:
            td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
        elif td.shape:
            td["action"] = NonTensorStack(*(["<nothing>"] * td.shape[0]))
        else:
            td["action"] = "<nothing>"
        return td

    rollout_kwargs = {"break_when_any_done": False}
    if batched:
        rollout_kwargs["tensordict"] = TensorDict(batch_size=[3])
    r = env.rollout(100, policy, **rollout_kwargs)

    # check that r at reset is always the same
    r_reset = r[..., ::max_steps]

    if batched:
        # When batched, each block contains the 3 reset packs
        anchor_idx = (0, 0)
        repeated_idxs = [(1, 0), (2, 0)]
        fresh_idx = (0, 1)
    else:
        anchor_idx = (..., 0)
        repeated_idxs = [(..., 1), (..., 2)]
        fresh_idx = (..., 3)

    def observation_at(idx):
        return r_reset[idx]["observation"]

    def is_same(lhs, rhs):
        # String observations compare directly; tensors element-wise.
        outcome = lhs == rhs
        return outcome if str2str else outcome.all()

    def is_different(lhs, rhs):
        outcome = lhs != rhs
        return outcome if str2str else outcome.any()

    for idx in repeated_idxs:
        assert is_same(observation_at(anchor_idx), observation_at(idx))
    assert is_different(observation_at(anchor_idx), observation_at(fresh_idx))

47034801

47044802
if __name__ == "__main__":
47054803
args, unknown = argparse.ArgumentParser().parse_known_args()

0 commit comments

Comments
 (0)