Skip to content

Commit 249733f

Browse files
committed
Update
[ghstack-poisoned]
2 parents bb7d2f9 + 7dd2915 commit 249733f

27 files changed

+1056
-873
lines changed

Diff for: docs/source/reference/data.rst

+3
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,9 @@ efficient sampling.
11331133
get_dataloader
11341134
ConstantKLController
11351135
AdaptiveKLController
1136+
LLMData
1137+
LLMInput
1138+
LLMOutput
11361139

11371140

11381141
Utils

Diff for: examples/rlhf/data/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr
1+
from torchrl.data.llm.prompt import get_prompt_dataloader_tldr
22

33
__all__ = ["get_prompt_dataloader_tldr"]

Diff for: examples/rlhf/models/reward.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tensordict.nn import TensorDictModule
99
from torchrl._utils import logger as torchrl_logger
1010

11-
from torchrl.modules.models.rlhf import GPT2RewardModel
11+
from torchrl.modules.models.llm import GPT2RewardModel
1212

1313

1414
def init_reward_model(

Diff for: examples/rlhf/train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from torch.optim.lr_scheduler import CosineAnnealingLR
1818
from torchrl._utils import logger as torchrl_logger
1919

20-
from torchrl.data.rlhf.dataset import get_dataloader
21-
from torchrl.data.rlhf.prompt import PromptData
20+
from torchrl.data.llm.dataset import get_dataloader
21+
from torchrl.data.llm.prompt import PromptData
2222
from utils import get_file_logger, resolve_name_or_path, setup
2323

2424

Diff for: examples/rlhf/train_reward.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from models.reward import init_reward_model
1010
from torch.optim.lr_scheduler import CosineAnnealingLR
1111
from torchrl._utils import logger as torchrl_logger
12-
from torchrl.data.rlhf.dataset import get_dataloader
13-
from torchrl.data.rlhf.reward import PairwiseDataset
12+
from torchrl.data.llm.dataset import get_dataloader
13+
from torchrl.data.llm.reward import PairwiseDataset
1414
from utils import get_file_logger, resolve_name_or_path, setup
1515

1616

Diff for: examples/rlhf/train_rlhf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import hydra
77
import torch
88
from models.actor_critic import init_actor_critic
9-
from torchrl.data.rlhf.utils import AdaptiveKLController, RolloutFromModel
9+
from torchrl.data.llm.utils import AdaptiveKLController, RolloutFromModel
1010

1111
from torchrl.record.loggers import get_logger
1212

Diff for: examples/rlhf/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
TensorDictReplayBuffer,
2323
TensorStorage,
2424
)
25+
from torchrl.data.llm.dataset import get_dataloader
26+
from torchrl.data.llm.prompt import PromptData
2527
from torchrl.data.replay_buffers import SamplerWithoutReplacement
26-
from torchrl.data.rlhf.dataset import get_dataloader
27-
from torchrl.data.rlhf.prompt import PromptData
2828
from torchrl.objectives import ClipPPOLoss
2929
from torchrl.objectives.value import GAE
3030

Diff for: test/assets/generate.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
"""Script used to generate the mini datasets."""
77
import multiprocessing as mp
8+
import pathlib
89

910
try:
1011
mp.set_start_method("spawn")
@@ -14,8 +15,8 @@
1415

1516
from datasets import Dataset, DatasetDict, load_dataset
1617

17-
from torchrl.data.rlhf.dataset import get_dataloader
18-
from torchrl.data.rlhf.prompt import PromptData
18+
from torchrl.data.llm.dataset import get_dataloader
19+
from torchrl.data.llm.prompt import PromptData
1920

2021

2122
def generate_small_dataset(comparison=True):
@@ -42,7 +43,7 @@ def get_minibatch():
4243
batch_size=16,
4344
block_size=33,
4445
tensorclass_type=PromptData,
45-
dataset_name="../datasets_mini/openai_summarize_tldr",
46+
dataset_name=f"{pathlib.Path(__file__).parent}/../datasets_mini/openai_summarize_tldr",
4647
device="cpu",
4748
num_workers=2,
4849
infinite=False,
@@ -52,9 +53,12 @@ def get_minibatch():
5253
root_dir=tmpdir,
5354
)
5455
for data in dl:
55-
data = data.clone().memmap_("test/datasets_mini/tldr_batch/")
56+
data = data.clone().memmap_(
57+
f"{pathlib.Path(__file__).parent}/../datasets_mini/tldr_batch/"
58+
)
5659
break
5760

5861

5962
if __name__ == "__main__":
63+
generate_small_dataset(False)
6064
get_minibatch()

Diff for: test/assets/tldr_batch.zip

2 Bytes
Binary file not shown.

Diff for: test/test_actors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from torch import distributions as dist, nn
1616
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
17-
from torchrl.data.rlhf.dataset import _has_transformers
17+
from torchrl.data.llm.dataset import _has_transformers
1818
from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
1919
from torchrl.modules.tensordict_module.actors import (
2020
_process_action_space_spec,

Diff for: test/test_env.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
6262
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
6363
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
64-
from torchrl.envs.transforms.rlhf import as_padded_tensor
64+
from torchrl.envs.transforms.llm import as_padded_tensor
6565
from torchrl.envs.transforms.transforms import (
6666
AutoResetEnv,
6767
AutoResetTransform,

Diff for: test/test_rlhf.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@
2121
TensorDictBase,
2222
)
2323
from tensordict.nn import TensorDictModule
24-
from torchrl.data.rlhf import TensorDictTokenizer
25-
from torchrl.data.rlhf.dataset import (
24+
from torchrl.data.llm import TensorDictTokenizer
25+
from torchrl.data.llm.dataset import (
2626
_has_datasets,
2727
_has_transformers,
2828
get_dataloader,
2929
TokenizedDatasetLoader,
3030
)
31-
from torchrl.data.rlhf.prompt import PromptData, PromptTensorDictTokenizer
32-
from torchrl.data.rlhf.reward import PairwiseDataset, pre_tokenization_hook
33-
from torchrl.data.rlhf.utils import RolloutFromModel
34-
from torchrl.modules.models.rlhf import GPT2RewardModel
31+
from torchrl.data.llm.prompt import PromptData, PromptTensorDictTokenizer
32+
from torchrl.data.llm.reward import PairwiseDataset, pre_tokenization_hook
33+
from torchrl.data.llm.utils import RolloutFromModel
34+
from torchrl.modules.models.llm import GPT2RewardModel
3535

3636
if os.getenv("PYTORCH_TEST_FBCODE"):
3737
from pytorch.rl.test._utils_internal import get_default_devices

Diff for: test/test_transforms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@
117117
from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend
118118
from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents
119119
from torchrl.envs.transforms import VecNorm
120+
from torchrl.envs.transforms.llm import KLRewardTransform
120121
from torchrl.envs.transforms.r3m import _R3MNet
121-
from torchrl.envs.transforms.rlhf import KLRewardTransform
122122
from torchrl.envs.transforms.transforms import (
123123
_has_tv,
124124
ActionDiscretizer,

Diff for: torchrl/data/__init__.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,22 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from .llm import (
7+
AdaptiveKLController,
8+
ConstantKLController,
9+
create_infinite_iterator,
10+
get_dataloader,
11+
LLMData,
12+
LLMInput,
13+
LLMOutput,
14+
PairwiseDataset,
15+
PromptData,
16+
PromptTensorDictTokenizer,
17+
RewardData,
18+
RolloutFromModel,
19+
TensorDictTokenizer,
20+
TokenizedDatasetLoader,
21+
)
622
from .map import (
723
BinaryToDecimal,
824
HashToInt,
@@ -56,19 +72,6 @@
5672
Writer,
5773
WriterEnsemble,
5874
)
59-
from .rlhf import (
60-
AdaptiveKLController,
61-
ConstantKLController,
62-
create_infinite_iterator,
63-
get_dataloader,
64-
PairwiseDataset,
65-
PromptData,
66-
PromptTensorDictTokenizer,
67-
RewardData,
68-
RolloutFromModel,
69-
TensorDictTokenizer,
70-
TokenizedDatasetLoader,
71-
)
7275
from .tensor_specs import (
7376
Binary,
7477
BinaryDiscreteTensorSpec,
@@ -125,6 +128,9 @@
125128
"H5StorageCheckpointer",
126129
"HashToInt",
127130
"ImmutableDatasetWriter",
131+
"LLMData",
132+
"LLMInput",
133+
"LLMOutput",
128134
"LazyMemmapStorage",
129135
"LazyStackStorage",
130136
"LazyStackedCompositeSpec",

Diff for: torchrl/data/rlhf/__init__.py renamed to torchrl/data/llm/__init__.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,21 @@
1111
)
1212
from .prompt import PromptData, PromptTensorDictTokenizer
1313
from .reward import PairwiseDataset, RewardData
14-
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel
14+
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel, LLMData, LLMOutput, LLMInput
1515

1616
__all__ = [
17-
"create_infinite_iterator",
18-
"get_dataloader",
19-
"TensorDictTokenizer",
20-
"TokenizedDatasetLoader",
17+
"AdaptiveKLController",
18+
"ConstantKLController",
19+
"LLMData",
20+
"LLMInput",
21+
"LLMOutput",
22+
"PairwiseDataset",
2123
"PromptData",
2224
"PromptTensorDictTokenizer",
23-
"PairwiseDataset",
2425
"RewardData",
25-
"AdaptiveKLController",
26-
"ConstantKLController",
2726
"RolloutFromModel",
27+
"TensorDictTokenizer",
28+
"TokenizedDatasetLoader",
29+
"create_infinite_iterator",
30+
"get_dataloader",
2831
]

Diff for: torchrl/data/rlhf/dataset.py renamed to torchrl/data/llm/dataset.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class TokenizedDatasetLoader:
3131
max_length (int): the maximum sequence length.
3232
dataset_name (str): the name of the dataset.
3333
tokenizer_fn (callable): the tokenizing method constructor, such as
34-
:class:`torchrl.data.rlhf.TensorDictTokenizer`. When called,
34+
:class:`torchrl.data.llm.TensorDictTokenizer`. When called,
3535
it should return a :class:`tensordict.TensorDict` instance
3636
or a dictionary-like structure with the tokenized data.
3737
pre_tokenization_hook (callable, optional): called on
@@ -62,8 +62,8 @@ class TokenizedDatasetLoader:
6262
The dataset will be stored in ``<root_dir>/<split>/<max_length>/``.
6363
6464
Examples:
65-
>>> from torchrl.data.rlhf import TensorDictTokenizer
66-
>>> from torchrl.data.rlhf.reward import pre_tokenization_hook
65+
>>> from torchrl.data.llm import TensorDictTokenizer
66+
>>> from torchrl.data.llm.reward import pre_tokenization_hook
6767
>>> split = "train"
6868
>>> max_length = 550
6969
>>> dataset_name = "CarperAI/openai_summarize_comparisons"
@@ -359,7 +359,7 @@ def get_dataloader(
359359
Defaults to ``max(os.cpu_count() // 2, 1)``.
360360
361361
Examples:
362-
>>> from torchrl.data.rlhf.reward import PairwiseDataset
362+
>>> from torchrl.data.llm.reward import PairwiseDataset
363363
>>> dataloader = get_dataloader(
364364
... batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
365365
>>> for d in dataloader:

Diff for: torchrl/data/rlhf/prompt.py renamed to torchrl/data/llm/prompt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from tensordict import tensorclass, TensorDict
99

10-
from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader
10+
from torchrl.data.llm.dataset import TensorDictTokenizer, TokenizedDatasetLoader
1111

1212
DEFAULT_DATASET = "CarperAI/openai_summarize_tldr"
1313

Diff for: torchrl/data/rlhf/reward.py renamed to torchrl/data/llm/reward.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import torch
1010
from tensordict import tensorclass
11-
from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader
11+
from torchrl.data.llm.dataset import TensorDictTokenizer, TokenizedDatasetLoader
1212

1313
DEFAULT_DATASET = "CarperAI/openai_summarize_comparisons"
1414
_has_datasets = importlib.util.find_spec("datasets") is not None

0 commit comments

Comments
 (0)