|
12 | 12 | import pickle
|
13 | 13 | import random
|
14 | 14 | import re
|
| 15 | +import string |
15 | 16 | from collections import defaultdict
|
16 | 17 | from functools import partial
|
17 | 18 | from sys import platform
|
|
43 | 44 | CatTensors,
|
44 | 45 | ChessEnv,
|
45 | 46 | ConditionalSkip,
|
| 47 | + DataLoadingPrimer, |
46 | 48 | DoubleToFloat,
|
47 | 49 | EnvBase,
|
48 | 50 | EnvCreator,
|
| 51 | + LLMEnv, |
49 | 52 | LLMHashingEnv,
|
50 | 53 | ParallelEnv,
|
51 | 54 | PendulumEnv,
|
|
66 | 69 | TransformedEnv,
|
67 | 70 | UnsqueezeTransform,
|
68 | 71 | )
|
| 72 | +from torchrl.envs.transforms.rlhf import as_padded_tensor |
69 | 73 | from torchrl.envs.utils import (
|
70 | 74 | _StepMDP,
|
71 | 75 | _terminated_or_truncated,
|
|
97 | 101 |
|
98 | 102 | try:
|
99 | 103 | this_dir = os.path.dirname(os.path.realpath(__file__))
|
100 |
| - with open(os.path.join(this_dir, "configs", "atari.yaml"), "r") as file: |
| 104 | + with open(os.path.join(this_dir, "configs", "atari.yaml")) as file: |
101 | 105 | atari_confs = yaml.load(file, Loader=yaml.FullLoader)
|
102 | 106 | _atari_found = True
|
103 | 107 | except FileNotFoundError:
|
@@ -4526,6 +4530,141 @@ def test_skipping_history_env_collector(self, device_env, collector_cls):
|
4526 | 4530 | count += 1
|
4527 | 4531 |
|
4528 | 4532 |
|
class TestLLMEnv:
    """Tests for ``LLMEnv``: spec checks with string and tensor dataloaders,
    both via an explicit ``DataLoadingPrimer`` transform and via
    ``LLMEnv.from_dataloader``.
    """

    # Shared (str2str, stack_method) matrix for both tests below.
    _STR2STR_STACK_PARAMS = [
        [True, None],
        [False, "as_padded_tensor"],
        # TODO: a bit experimental, fails with check_env_specs
        # [False, "as_nested_tensor"],
        [False, None],
    ]

    @pytest.fixture(scope="class", autouse=True)
    def set_capture(self):
        # Disable non-tensor stacking capture for every test in this class;
        # the context manager restores the previous setting on teardown.
        with set_capture_non_tensor_stack(False):
            yield None

    class DummyDataLoader:
        """Infinite iterator yielding random lowercase strings.

        With ``batch_size == 0`` each ``next()`` returns a single string;
        otherwise it returns a list of ``batch_size`` strings.
        """

        def __init__(self, batch_size=0):
            self.batch_size = batch_size

        def generate_random_string(self, length=10):
            """Generate a random string of a given length."""
            return "".join(random.choices(string.ascii_lowercase, k=length))

        def __iter__(self):
            return self

        def __next__(self):
            if self.batch_size == 0:
                return self.generate_random_string()
            return [self.generate_random_string() for _ in range(self.batch_size)]

    class DummyTensorDataLoader:
        """Infinite iterator yielding random int64 tensors of varying length.

        With ``padding=True`` every tensor is left-padded with zeros to
        ``max_length``; with ``batch_size > 0`` tensors are stacked (padded)
        or returned as a list (unpadded, since lengths may differ).
        """

        def __init__(self, batch_size=0, max_length=10, padding=False):
            self.batch_size = batch_size
            self.max_length = max_length
            self.padding = padding

        def generate_random_tensor(self):
            """Generate a tensor of random int64 values."""
            length = random.randint(1, self.max_length)
            return torch.tensor(
                [random.randint(0, 100) for _ in range(length)], dtype=torch.int64
            )

        def pad_tensor(self, tensor):
            """Pad a tensor to the maximum length (zeros prepended)."""
            padding_length = self.max_length - len(tensor)
            return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor))

        def __iter__(self):
            return self

        def __next__(self):
            if self.batch_size == 0:
                tensor = self.generate_random_tensor()
                return self.pad_tensor(tensor) if self.padding else tensor
            tensors = [self.generate_random_tensor() for _ in range(self.batch_size)]
            if self.padding:
                # Equal lengths after padding, so stacking is well-defined.
                return torch.stack([self.pad_tensor(tensor) for tensor in tensors])
            return tensors

    def _loading_kwargs(self, str2str, stack_method):
        """Build the data-loading keyword arguments shared by both tests.

        Returns the kwargs for a string dataloader when ``str2str`` is True,
        otherwise for a padded tensor dataloader (``stack_method`` defaulting
        to ``as_padded_tensor``).
        """
        if str2str:
            return {
                "dataloader": self.DummyDataLoader(),
                "data_keys": ["observation"],
                "example_data": "a string!",
            }
        if stack_method is None:
            stack_method = as_padded_tensor
        return {
            "dataloader": self.DummyTensorDataLoader(padding=True),
            "data_keys": ["observation"],
            "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
            "stack_method": stack_method,
        }

    def _check_env_specs(self, env, batched):
        """Run check_env_specs, with an explicit batched reset when requested."""
        if batched:
            td = env.reset(TensorDict(batch_size=[3]))
            env.check_env_specs(break_when_any_done="both", tensordict=td)
        else:
            env.check_env_specs(break_when_any_done="both")

    @pytest.mark.parametrize("str2str,stack_method", _STR2STR_STACK_PARAMS)
    @pytest.mark.parametrize("batched", [True, False])
    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_llm_env(self, str2str, batched, stack_method, device):
        env = LLMEnv(str2str=str2str, device=device)
        primer = DataLoadingPrimer(**self._loading_kwargs(str2str, stack_method))
        # The env must stay batch-unlocked both before and after the transform.
        assert not env.batch_locked
        env = env.append_transform(primer)
        assert not env.batch_locked
        self._check_env_specs(env, batched)

    @pytest.mark.parametrize("str2str,stack_method", _STR2STR_STACK_PARAMS)
    @pytest.mark.parametrize("batched", [True, False])
    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
        kwargs = self._loading_kwargs(str2str, stack_method)
        kwargs.update({"str2str": str2str, "device": device})
        env = LLMEnv.from_dataloader(**kwargs)
        assert not env.batch_locked
        self._check_env_specs(env, batched)
| 4667 | + |
if __name__ == "__main__":
    # Allow running this test module directly: any unrecognized CLI arguments
    # are forwarded to pytest; capture is disabled and the run stops at the
    # first failure.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
|
0 commit comments