Skip to content

Commit be827e5

Browse files
committed
[Feature] LLMEnv and DataLoadingPrimer
ghstack-source-id: 1ad4943 Pull Request resolved: #2818
1 parent a40da99 commit be827e5

File tree

7 files changed

+786
-13
lines changed

7 files changed

+786
-13
lines changed

docs/source/reference/envs.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ TorchRL offers a series of custom built-in environments.
440440
ChessEnv
441441
PendulumEnv
442442
TicTacToeEnv
443+
LLMEnv
443444
LLMHashingEnv
444445

445446

@@ -1033,6 +1034,7 @@ to be able to create this other composition:
10331034
Compose
10341035
ConditionalSkip
10351036
Crop
1037+
DataLoadingPrimer
10361038
DTypeCastTransform
10371039
DeviceCastTransform
10381040
DiscreteActionProjection

test/test_env.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pickle
1313
import random
1414
import re
15+
import string
1516
from collections import defaultdict
1617
from functools import partial
1718
from sys import platform
@@ -43,9 +44,11 @@
4344
CatTensors,
4445
ChessEnv,
4546
ConditionalSkip,
47+
DataLoadingPrimer,
4648
DoubleToFloat,
4749
EnvBase,
4850
EnvCreator,
51+
LLMEnv,
4952
LLMHashingEnv,
5053
ParallelEnv,
5154
PendulumEnv,
@@ -57,6 +60,7 @@
5760
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
5861
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
5962
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
63+
from torchrl.envs.transforms.rlhf import as_padded_tensor
6064
from torchrl.envs.transforms.transforms import (
6165
AutoResetEnv,
6266
AutoResetTransform,
@@ -95,7 +99,7 @@
9599

96100
try:
97101
this_dir = os.path.dirname(os.path.realpath(__file__))
98-
with open(os.path.join(this_dir, "configs", "atari.yaml"), "r") as file:
102+
with open(os.path.join(this_dir, "configs", "atari.yaml")) as file:
99103
atari_confs = yaml.load(file, Loader=yaml.FullLoader)
100104
_atari_found = True
101105
except FileNotFoundError:
@@ -4503,6 +4507,141 @@ def test_skipping_history_env_collector(self, device_env, collector_cls):
45034507
count += 1
45044508

45054509

4510+
class TestLLMEnv:
    """Tests for ``LLMEnv`` and the ``DataLoadingPrimer`` transform.

    Covers both string observations (``str2str=True``) and tokenized int64
    observations (``str2str=False``), in batched and unbatched mode, built
    either by appending the primer manually or via ``LLMEnv.from_dataloader``.
    """

    @pytest.fixture(scope="class", autouse=True)
    def set_capture(self):
        # Disable non-tensor stacking capture for every test in this class;
        # the context manager restores the previous setting on teardown.
        with set_capture_non_tensor_stack(False):
            yield None
        return

    class DummyDataLoader:
        """Infinite iterator that yields random lowercase strings.

        With ``batch_size == 0`` each ``__next__`` returns a single string;
        otherwise it returns a list of ``batch_size`` strings.
        """

        def __init__(self, batch_size=0):
            # batch_size == 0 means "unbatched" (scalar samples).
            self.batch_size = batch_size

        def generate_random_string(self, length=10):
            """Generate a random string of a given length."""
            return "".join(random.choice(string.ascii_lowercase) for _ in range(length))

        def __iter__(self):
            return self

        def __next__(self):
            if self.batch_size == 0:
                return self.generate_random_string()
            else:
                return [self.generate_random_string() for _ in range(self.batch_size)]

    class DummyTensorDataLoader:
        """Infinite iterator that yields random int64 token tensors.

        Each sample has a random length in ``[1, max_length]``. With
        ``padding=True`` samples are left-padded with zeros to ``max_length``
        (and stacked into one tensor when batched); with ``batch_size == 0``
        a single tensor is returned per ``__next__``, otherwise a batch.
        """

        def __init__(self, batch_size=0, max_length=10, padding=False):
            self.batch_size = batch_size
            self.max_length = max_length
            self.padding = padding

        def generate_random_tensor(self):
            """Generate a tensor of random int64 values."""
            length = random.randint(1, self.max_length)
            return torch.tensor(
                [random.randint(0, 100) for _ in range(length)], dtype=torch.int64
            )

        def pad_tensor(self, tensor):
            """Pad a tensor to the maximum length."""
            # Left-pads with zeros so every sample ends with the real tokens.
            padding_length = self.max_length - len(tensor)
            return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor))

        def __iter__(self):
            return self

        def __next__(self):
            if self.batch_size == 0:
                tensor = self.generate_random_tensor()
                return self.pad_tensor(tensor) if self.padding else tensor
            else:
                tensors = [
                    self.generate_random_tensor() for _ in range(self.batch_size)
                ]
                if self.padding:
                    tensors = [self.pad_tensor(tensor) for tensor in tensors]
                    return torch.stack(tensors)
                else:
                    return tensors

    @pytest.mark.parametrize(
        "str2str,stack_method",
        [
            [True, None],
            # NOTE(review): "as_padded_tensor" is passed through as a string
            # here — presumably DataLoadingPrimer resolves stack-method names;
            # confirm against the transform's API.
            [False, "as_padded_tensor"],
            # TODO: a bit experimental, fails with check_env_specs
            # [False, "as_nested_tensor"],
            [False, None],
        ],
    )
    @pytest.mark.parametrize("batched", [True, False])
    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_llm_env(self, str2str, batched, stack_method, device):
        """Build an LLMEnv, append a DataLoadingPrimer, and validate its specs."""
        env = LLMEnv(str2str=str2str, device=device)
        if str2str:
            # String observations: primer needs example data to infer the spec.
            primer = DataLoadingPrimer(
                dataloader=self.DummyDataLoader(),
                data_keys=["observation"],
                example_data="a string!",
            )
        else:
            # Tokenized observations: variable-length int64 spec, padded stacking.
            if stack_method is None:
                stack_method = as_padded_tensor
            primer = DataLoadingPrimer(
                dataloader=self.DummyTensorDataLoader(padding=True),
                data_keys=["observation"],
                data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
                stack_method=stack_method,
            )
        # The env must not be batch-locked, before or after the transform.
        assert not env.batch_locked
        env = env.append_transform(primer)
        assert not env.batch_locked
        if batched:
            # Batched reset: supply an explicit batch_size through the tensordict.
            td = env.reset(TensorDict(batch_size=[3]))
            env.check_env_specs(break_when_any_done="both", tensordict=td)
        else:
            env.check_env_specs(break_when_any_done="both")

    @pytest.mark.parametrize(
        "str2str,stack_method",
        [
            [True, None],
            [False, "as_padded_tensor"],
            # TODO: a bit experimental, fails with check_env_specs
            # [False, "as_nested_tensor"],
            [False, None],
        ],
    )
    @pytest.mark.parametrize("batched", [True, False])
    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
        """Same coverage as test_llm_env, but built via LLMEnv.from_dataloader."""
        if str2str:
            kwargs = {
                "dataloader": self.DummyDataLoader(),
                "data_keys": ["observation"],
                "example_data": "a string!",
            }
        else:
            if stack_method is None:
                stack_method = as_padded_tensor
            kwargs = {
                "dataloader": self.DummyTensorDataLoader(padding=True),
                "data_keys": ["observation"],
                "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                "stack_method": stack_method,
            }
        kwargs.update({"str2str": str2str, "device": device})
        env = LLMEnv.from_dataloader(**kwargs)
        assert not env.batch_locked
        if batched:
            td = env.reset(TensorDict(batch_size=[3]))
            env.check_env_specs(break_when_any_done="both", tensordict=td)
        else:
            env.check_env_specs(break_when_any_done="both")
45064645
if __name__ == "__main__":
45074646
args, unknown = argparse.ArgumentParser().parse_known_args()
45084647
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .batched_envs import ParallelEnv, SerialEnv
77
from .common import EnvBase, EnvMetaData, make_tensordict
8-
from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
8+
from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
99
from .env_creator import env_creator, EnvCreator, get_env_metadata
1010
from .gym_like import default_info_dict_reader, GymLikeEnv
1111
from .libs import (
@@ -58,6 +58,7 @@
5858
Compose,
5959
ConditionalSkip,
6060
Crop,
61+
DataLoadingPrimer,
6162
DeviceCastTransform,
6263
DiscreteActionProjection,
6364
DoubleToFloat,

torchrl/envs/custom/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from .chess import ChessEnv
7-
from .llm import LLMHashingEnv
7+
from .llm import LLMEnv, LLMHashingEnv
88
from .pendulum import PendulumEnv
99
from .tictactoeenv import TicTacToeEnv

0 commit comments

Comments (0)