Commit 4664712

[Feature] LLMEnv and DataLoadingPrimer

ghstack-source-id: 0785280
Pull Request resolved: #2818

1 parent 034f248 commit 4664712

File tree: 10 files changed (+781, -24 lines)
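For orientation, here is a minimal usage sketch of the two new components, mirroring the str2str path of the tests added in this commit. The in-memory prompt iterator is a stand-in for a real dataloader, and the sketch omits the set_capture_non_tensor_stack(False) setting that the test fixture applies class-wide; treat it as a sketch of the composition, not a canonical API reference.

    from torchrl.envs import DataLoadingPrimer, LLMEnv

    # Stand-in dataloader: any iterable yielding strings works here; the tests
    # below use a DummyDataLoader that produces random lowercase strings.
    dataloader = iter(["a string!"] * 100)

    env = LLMEnv(str2str=True)  # string-in / string-out mode, as in the tests
    primer = DataLoadingPrimer(
        dataloader=dataloader,
        data_keys=["observation"],  # key under which each prompt is delivered
        example_data="a string!",
    )
    env = env.append_transform(primer)  # the env stays non-batch-locked
    reset_td = env.reset()              # the primer supplies the next prompt at reset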

docs/source/reference/envs.rst

Lines changed: 2 additions & 0 deletions
@@ -440,6 +440,7 @@ TorchRL offers a series of custom built-in environments.
     ChessEnv
     PendulumEnv
     TicTacToeEnv
+    LLMEnv
     LLMHashingEnv

@@ -1033,6 +1034,7 @@ to be able to create this other composition:
     Compose
     ConditionalSkip
     Crop
+    DataLoadingPrimer
     DTypeCastTransform
     DeviceCastTransform
     DiscreteActionProjection

test/test_env.py

Lines changed: 140 additions & 1 deletion
@@ -12,6 +12,7 @@
 import pickle
 import random
 import re
+import string
 from collections import defaultdict
 from functools import partial
 from sys import platform
@@ -43,9 +44,11 @@
     CatTensors,
     ChessEnv,
     ConditionalSkip,
+    DataLoadingPrimer,
     DoubleToFloat,
     EnvBase,
     EnvCreator,
+    LLMEnv,
     LLMHashingEnv,
     ParallelEnv,
     PendulumEnv,
@@ -66,6 +69,7 @@
     TransformedEnv,
     UnsqueezeTransform,
 )
+from torchrl.envs.transforms.rlhf import as_padded_tensor
 from torchrl.envs.utils import (
     _StepMDP,
     _terminated_or_truncated,
@@ -97,7 +101,7 @@

 try:
     this_dir = os.path.dirname(os.path.realpath(__file__))
-    with open(os.path.join(this_dir, "configs", "atari.yaml"), "r") as file:
+    with open(os.path.join(this_dir, "configs", "atari.yaml")) as file:
         atari_confs = yaml.load(file, Loader=yaml.FullLoader)
     _atari_found = True
 except FileNotFoundError:
@@ -4526,6 +4530,141 @@ def test_skipping_history_env_collector(self, device_env, collector_cls):
             count += 1


+class TestLLMEnv:
+    @pytest.fixture(scope="class", autouse=True)
+    def set_capture(self):
+        with set_capture_non_tensor_stack(False):
+            yield None
+        return
+
+    class DummyDataLoader:
+        def __init__(self, batch_size=0):
+            self.batch_size = batch_size
+
+        def generate_random_string(self, length=10):
+            """Generate a random string of a given length."""
+            return "".join(random.choice(string.ascii_lowercase) for _ in range(length))
+
+        def __iter__(self):
+            return self
+
+        def __next__(self):
+            if self.batch_size == 0:
+                return self.generate_random_string()
+            else:
+                return [self.generate_random_string() for _ in range(self.batch_size)]
+
+    class DummyTensorDataLoader:
+        def __init__(self, batch_size=0, max_length=10, padding=False):
+            self.batch_size = batch_size
+            self.max_length = max_length
+            self.padding = padding
+
+        def generate_random_tensor(self):
+            """Generate a tensor of random int64 values."""
+            length = random.randint(1, self.max_length)
+            return torch.tensor(
+                [random.randint(0, 100) for _ in range(length)], dtype=torch.int64
+            )
+
+        def pad_tensor(self, tensor):
+            """Pad a tensor to the maximum length."""
+            padding_length = self.max_length - len(tensor)
+            return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor))
+
+        def __iter__(self):
+            return self
+
+        def __next__(self):
+            if self.batch_size == 0:
+                tensor = self.generate_random_tensor()
+                return self.pad_tensor(tensor) if self.padding else tensor
+            else:
+                tensors = [
+                    self.generate_random_tensor() for _ in range(self.batch_size)
+                ]
+                if self.padding:
+                    tensors = [self.pad_tensor(tensor) for tensor in tensors]
+                    return torch.stack(tensors)
+                else:
+                    return tensors
+
+    @pytest.mark.parametrize(
+        "str2str,stack_method",
+        [
+            [True, None],
+            [False, "as_padded_tensor"],
+            # TODO: a bit experimental, fails with check_env_specs
+            # [False, "as_nested_tensor"],
+            [False, None],
+        ],
+    )
+    @pytest.mark.parametrize("batched", [True, False])
+    @pytest.mark.parametrize("device", [None, "cpu"])
+    def test_llm_env(self, str2str, batched, stack_method, device):
+        env = LLMEnv(str2str=str2str, device=device)
+        if str2str:
+            primer = DataLoadingPrimer(
+                dataloader=self.DummyDataLoader(),
+                data_keys=["observation"],
+                example_data="a string!",
+            )
+        else:
+            if stack_method is None:
+                stack_method = as_padded_tensor
+            primer = DataLoadingPrimer(
+                dataloader=self.DummyTensorDataLoader(padding=True),
+                data_keys=["observation"],
+                data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
+                stack_method=stack_method,
+            )
+        assert not env.batch_locked
+        env = env.append_transform(primer)
+        assert not env.batch_locked
+        if batched:
+            td = env.reset(TensorDict(batch_size=[3]))
+            env.check_env_specs(break_when_any_done="both", tensordict=td)
+        else:
+            env.check_env_specs(break_when_any_done="both")
+
+    @pytest.mark.parametrize(
+        "str2str,stack_method",
+        [
+            [True, None],
+            [False, "as_padded_tensor"],
+            # TODO: a bit experimental, fails with check_env_specs
+            # [False, "as_nested_tensor"],
+            [False, None],
+        ],
+    )
+    @pytest.mark.parametrize("batched", [True, False])
+    @pytest.mark.parametrize("device", [None, "cpu"])
+    def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
+        if str2str:
+            kwargs = {
+                "dataloader": self.DummyDataLoader(),
+                "data_keys": ["observation"],
+                "example_data": "a string!",
+            }
+        else:
+            if stack_method is None:
+                stack_method = as_padded_tensor
+            kwargs = {
+                "dataloader": self.DummyTensorDataLoader(padding=True),
+                "data_keys": ["observation"],
+                "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
+                "stack_method": stack_method,
+            }
+        kwargs.update({"str2str": str2str, "device": device})
+        env = LLMEnv.from_dataloader(**kwargs)
+        assert not env.batch_locked
+        if batched:
+            td = env.reset(TensorDict(batch_size=[3]))
+            env.check_env_specs(break_when_any_done="both", tensordict=td)
+        else:
+            env.check_env_specs(break_when_any_done="both")
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
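The second test above, test_llm_from_dataloader, exercises the one-call shortcut LLMEnv.from_dataloader(**kwargs), which forwards the primer arguments (dataloader, data_keys, and either example_data or data_specs/stack_method) together with str2str and device. The new tests can be run in isolation from the repository root with a standard pytest node id:

    pytest test/test_env.py::TestLLMEnv -v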

torchrl/data/tensor_specs.py

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@
     Dict,
     Generic,
     List,
-    Optional,
     overload,
     Sequence,
     Tuple,

torchrl/envs/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@

 from .batched_envs import ParallelEnv, SerialEnv
 from .common import EnvBase, EnvMetaData, make_tensordict
-from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
+from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
 from .env_creator import env_creator, EnvCreator, get_env_metadata
 from .gym_like import default_info_dict_reader, GymLikeEnv
 from .libs import (
@@ -58,6 +58,7 @@
     Compose,
     ConditionalSkip,
     Crop,
+    DataLoadingPrimer,
     DeviceCastTransform,
     DiscreteActionProjection,
     DoubleToFloat,
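With these two additions, both new classes resolve from the top-level envs namespace, matching the enlarged import block in test/test_env.py:

    from torchrl.envs import DataLoadingPrimer, LLMEnv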

torchrl/envs/common.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 import warnings
 from copy import deepcopy
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Iterator

 import numpy as np
 import torch

torchrl/envs/custom/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,6 +4,6 @@
 # LICENSE file in the root directory of this source tree.

 from .chess import ChessEnv
-from .llm import LLMHashingEnv
+from .llm import LLMEnv, LLMHashingEnv
 from .pendulum import PendulumEnv
 from .tictactoeenv import TicTacToeEnv
