Skip to content

[Feature] LLMEnv and DataLoadingPrimer #2818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ TorchRL offers a series of custom built-in environments.
ChessEnv
PendulumEnv
TicTacToeEnv
LLMEnv
LLMHashingEnv


Expand Down Expand Up @@ -1033,6 +1034,7 @@ to be able to create this other composition:
Compose
ConditionalSkip
Crop
DataLoadingPrimer
DTypeCastTransform
DeviceCastTransform
DiscreteActionProjection
Expand Down
141 changes: 140 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pickle
import random
import re
import string
from collections import defaultdict
from functools import partial
from sys import platform
Expand Down Expand Up @@ -43,9 +44,11 @@
CatTensors,
ChessEnv,
ConditionalSkip,
DataLoadingPrimer,
DoubleToFloat,
EnvBase,
EnvCreator,
LLMEnv,
LLMHashingEnv,
ParallelEnv,
PendulumEnv,
Expand All @@ -57,6 +60,7 @@
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
from torchrl.envs.transforms.rlhf import as_padded_tensor
from torchrl.envs.transforms.transforms import (
AutoResetEnv,
AutoResetTransform,
Expand Down Expand Up @@ -95,7 +99,7 @@

try:
this_dir = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(this_dir, "configs", "atari.yaml"), "r") as file:
with open(os.path.join(this_dir, "configs", "atari.yaml")) as file:
atari_confs = yaml.load(file, Loader=yaml.FullLoader)
_atari_found = True
except FileNotFoundError:
Expand Down Expand Up @@ -4503,6 +4507,141 @@ def test_skipping_history_env_collector(self, device_env, collector_cls):
count += 1


class TestLLMEnv:
    """Tests for ``LLMEnv`` and its integration with ``DataLoadingPrimer``."""

    @pytest.fixture(scope="class", autouse=True)
    def set_capture(self):
        # Disable non-tensor stack capture for every test in this class.
        with set_capture_non_tensor_stack(False):
            yield None
        return

    class DummyDataLoader:
        """Endless iterator yielding random lowercase strings (or lists of them)."""

        def __init__(self, batch_size=0):
            self.batch_size = batch_size

        def generate_random_string(self, length=10):
            """Generate a random string of a given length."""
            chars = [random.choice(string.ascii_lowercase) for _ in range(length)]
            return "".join(chars)

        def __iter__(self):
            return self

        def __next__(self):
            # batch_size == 0 means "unbatched": yield a single string.
            if self.batch_size == 0:
                return self.generate_random_string()
            return [self.generate_random_string() for _ in range(self.batch_size)]

    class DummyTensorDataLoader:
        """Endless iterator yielding random int64 tensors, optionally left-padded."""

        def __init__(self, batch_size=0, max_length=10, padding=False):
            self.batch_size = batch_size
            self.max_length = max_length
            self.padding = padding

        def generate_random_tensor(self):
            """Generate a tensor of random int64 values."""
            n = random.randint(1, self.max_length)
            values = [random.randint(0, 100) for _ in range(n)]
            return torch.tensor(values, dtype=torch.int64)

        def pad_tensor(self, tensor):
            """Left-pad a tensor with zeros up to ``max_length``."""
            fill = torch.zeros(self.max_length - len(tensor), dtype=torch.int64)
            return torch.cat((fill, tensor))

        def __iter__(self):
            return self

        def __next__(self):
            # batch_size == 0 means "unbatched": yield a single tensor.
            if self.batch_size == 0:
                sample = self.generate_random_tensor()
                if self.padding:
                    return self.pad_tensor(sample)
                return sample
            batch = [self.generate_random_tensor() for _ in range(self.batch_size)]
            if self.padding:
                # Padding makes all tensors the same length, so they can be stacked.
                return torch.stack([self.pad_tensor(t) for t in batch])
            return batch

    @pytest.mark.parametrize(
        "str2str,stack_method",
        [
            [True, None],
            [False, "as_padded_tensor"],
            # TODO: a bit experimental, fails with check_env_specs
            # [False, "as_nested_tensor"],
            [False, None],
        ],
    )
    @pytest.mark.parametrize("batched", [True, False])
    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_llm_env(self, str2str, batched, stack_method, device):
        """Build an LLMEnv and append a DataLoadingPrimer, then validate its specs."""
        env = LLMEnv(str2str=str2str, device=device)
        if str2str:
            primer = DataLoadingPrimer(
                dataloader=self.DummyDataLoader(),
                data_keys=["observation"],
                example_data="a string!",
            )
        else:
            if stack_method is None:
                stack_method = as_padded_tensor
            primer = DataLoadingPrimer(
                dataloader=self.DummyTensorDataLoader(padding=True),
                data_keys=["observation"],
                data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
                stack_method=stack_method,
            )
        # The env must stay batch-unlocked both before and after the transform.
        assert not env.batch_locked
        env = env.append_transform(primer)
        assert not env.batch_locked
        if batched:
            td = env.reset(TensorDict(batch_size=[3]))
            env.check_env_specs(break_when_any_done="both", tensordict=td)
        else:
            env.check_env_specs(break_when_any_done="both")

    @pytest.mark.parametrize(
        "str2str,stack_method",
        [
            [True, None],
            [False, "as_padded_tensor"],
            # TODO: a bit experimental, fails with check_env_specs
            # [False, "as_nested_tensor"],
            [False, None],
        ],
    )
    @pytest.mark.parametrize("batched", [True, False])
    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
        """Build the env via ``LLMEnv.from_dataloader`` and validate its specs."""
        if str2str:
            kwargs = dict(
                dataloader=self.DummyDataLoader(),
                data_keys=["observation"],
                example_data="a string!",
            )
        else:
            if stack_method is None:
                stack_method = as_padded_tensor
            kwargs = dict(
                dataloader=self.DummyTensorDataLoader(padding=True),
                data_keys=["observation"],
                data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
                stack_method=stack_method,
            )
        kwargs["str2str"] = str2str
        kwargs["device"] = device
        env = LLMEnv.from_dataloader(**kwargs)
        assert not env.batch_locked
        if batched:
            td = env.reset(TensorDict(batch_size=[3]))
            env.check_env_specs(break_when_any_done="both", tensordict=td)
        else:
            env.check_env_specs(break_when_any_done="both")


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
3 changes: 2 additions & 1 deletion torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .batched_envs import ParallelEnv, SerialEnv
from .common import EnvBase, EnvMetaData, make_tensordict
from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
from .env_creator import env_creator, EnvCreator, get_env_metadata
from .gym_like import default_info_dict_reader, GymLikeEnv
from .libs import (
Expand Down Expand Up @@ -58,6 +58,7 @@
Compose,
ConditionalSkip,
Crop,
DataLoadingPrimer,
DeviceCastTransform,
DiscreteActionProjection,
DoubleToFloat,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/custom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
# LICENSE file in the root directory of this source tree.

from .chess import ChessEnv
from .llm import LLMHashingEnv
from .llm import LLMEnv, LLMHashingEnv
from .pendulum import PendulumEnv
from .tictactoeenv import TicTacToeEnv
Loading
Loading