Skip to content

Commit 7dd2915

Browse files
committed
Update (base update)
[ghstack-poisoned]
2 parents a839bc6 + 73c7b0a commit 7dd2915

27 files changed: +1101 -918 lines changed

Diff for: docs/source/reference/data.rst

+3
Original file line number | Diff line number | Diff line change
@@ -1133,6 +1133,9 @@ efficient sampling.
11331133
get_dataloader
11341134
ConstantKLController
11351135
AdaptiveKLController
1136+
LLMData
1137+
LLMInput
1138+
LLMOutput
11361139

11371140

11381141
Utils

Diff for: examples/rlhf/data/__init__.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
1-
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr
1+
from torchrl.data.llm.prompt import get_prompt_dataloader_tldr
22

33
__all__ = ["get_prompt_dataloader_tldr"]

Diff for: examples/rlhf/models/reward.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from tensordict.nn import TensorDictModule
99
from torchrl._utils import logger as torchrl_logger
1010

11-
from torchrl.modules.models.rlhf import GPT2RewardModel
11+
from torchrl.modules.models.llm import GPT2RewardModel
1212

1313

1414
def init_reward_model(

Diff for: examples/rlhf/train.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,8 @@
1717
from torch.optim.lr_scheduler import CosineAnnealingLR
1818
from torchrl._utils import logger as torchrl_logger
1919

20-
from torchrl.data.rlhf.dataset import get_dataloader
21-
from torchrl.data.rlhf.prompt import PromptData
20+
from torchrl.data.llm.dataset import get_dataloader
21+
from torchrl.data.llm.prompt import PromptData
2222
from utils import get_file_logger, resolve_name_or_path, setup
2323

2424

Diff for: examples/rlhf/train_reward.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@
99
from models.reward import init_reward_model
1010
from torch.optim.lr_scheduler import CosineAnnealingLR
1111
from torchrl._utils import logger as torchrl_logger
12-
from torchrl.data.rlhf.dataset import get_dataloader
13-
from torchrl.data.rlhf.reward import PairwiseDataset
12+
from torchrl.data.llm.dataset import get_dataloader
13+
from torchrl.data.llm.reward import PairwiseDataset
1414
from utils import get_file_logger, resolve_name_or_path, setup
1515

1616

Diff for: examples/rlhf/train_rlhf.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import hydra
77
import torch
88
from models.actor_critic import init_actor_critic
9-
from torchrl.data.rlhf.utils import AdaptiveKLController, RolloutFromModel
9+
from torchrl.data.llm.utils import AdaptiveKLController, RolloutFromModel
1010

1111
from torchrl.record.loggers import get_logger
1212

Diff for: examples/rlhf/utils.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,9 @@
2222
TensorDictReplayBuffer,
2323
TensorStorage,
2424
)
25+
from torchrl.data.llm.dataset import get_dataloader
26+
from torchrl.data.llm.prompt import PromptData
2527
from torchrl.data.replay_buffers import SamplerWithoutReplacement
26-
from torchrl.data.rlhf.dataset import get_dataloader
27-
from torchrl.data.rlhf.prompt import PromptData
2828
from torchrl.objectives import ClipPPOLoss
2929
from torchrl.objectives.value import GAE
3030

Diff for: test/assets/generate.py

+8 -4
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
"""Script used to generate the mini datasets."""
77
import multiprocessing as mp
8+
import pathlib
89

910
try:
1011
mp.set_start_method("spawn")
@@ -14,8 +15,8 @@
1415

1516
from datasets import Dataset, DatasetDict, load_dataset
1617

17-
from torchrl.data.rlhf.dataset import get_dataloader
18-
from torchrl.data.rlhf.prompt import PromptData
18+
from torchrl.data.llm.dataset import get_dataloader
19+
from torchrl.data.llm.prompt import PromptData
1920

2021

2122
def generate_small_dataset(comparison=True):
@@ -42,7 +43,7 @@ def get_minibatch():
4243
batch_size=16,
4344
block_size=33,
4445
tensorclass_type=PromptData,
45-
dataset_name="../datasets_mini/openai_summarize_tldr",
46+
dataset_name=f"{pathlib.Path(__file__).parent}/../datasets_mini/openai_summarize_tldr",
4647
device="cpu",
4748
num_workers=2,
4849
infinite=False,
@@ -52,9 +53,12 @@ def get_minibatch():
5253
root_dir=tmpdir,
5354
)
5455
for data in dl:
55-
data = data.clone().memmap_("test/datasets_mini/tldr_batch/")
56+
data = data.clone().memmap_(
57+
f"{pathlib.Path(__file__).parent}/../datasets_mini/tldr_batch/"
58+
)
5659
break
5760

5861

5962
if __name__ == "__main__":
63+
generate_small_dataset(False)
6064
get_minibatch()

Diff for: test/assets/tldr_batch.zip

2 Bytes
Binary file not shown.

Diff for: test/test_actors.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
from torch import distributions as dist, nn
1616
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
17-
from torchrl.data.rlhf.dataset import _has_transformers
17+
from torchrl.data.llm.dataset import _has_transformers
1818
from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
1919
from torchrl.modules.tensordict_module.actors import (
2020
_process_action_space_spec,

Diff for: test/test_env.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@
6161
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
6262
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
6363
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
64-
from torchrl.envs.transforms.rlhf import as_padded_tensor
64+
from torchrl.envs.transforms.llm import as_padded_tensor
6565
from torchrl.envs.transforms.transforms import (
6666
AutoResetEnv,
6767
AutoResetTransform,

Diff for: test/test_rlhf.py

+6 -6
Original file line number | Diff line number | Diff line change
@@ -21,17 +21,17 @@
2121
TensorDictBase,
2222
)
2323
from tensordict.nn import TensorDictModule
24-
from torchrl.data.rlhf import TensorDictTokenizer
25-
from torchrl.data.rlhf.dataset import (
24+
from torchrl.data.llm import TensorDictTokenizer
25+
from torchrl.data.llm.dataset import (
2626
_has_datasets,
2727
_has_transformers,
2828
get_dataloader,
2929
TokenizedDatasetLoader,
3030
)
31-
from torchrl.data.rlhf.prompt import PromptData, PromptTensorDictTokenizer
32-
from torchrl.data.rlhf.reward import PairwiseDataset, pre_tokenization_hook
33-
from torchrl.data.rlhf.utils import RolloutFromModel
34-
from torchrl.modules.models.rlhf import GPT2RewardModel
31+
from torchrl.data.llm.prompt import PromptData, PromptTensorDictTokenizer
32+
from torchrl.data.llm.reward import PairwiseDataset, pre_tokenization_hook
33+
from torchrl.data.llm.utils import RolloutFromModel
34+
from torchrl.modules.models.llm import GPT2RewardModel
3535

3636
if os.getenv("PYTORCH_TEST_FBCODE"):
3737
from pytorch.rl.test._utils_internal import get_default_devices

Diff for: test/test_transforms.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -117,8 +117,8 @@
117117
from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend
118118
from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents
119119
from torchrl.envs.transforms import VecNorm
120+
from torchrl.envs.transforms.llm import KLRewardTransform
120121
from torchrl.envs.transforms.r3m import _R3MNet
121-
from torchrl.envs.transforms.rlhf import KLRewardTransform
122122
from torchrl.envs.transforms.transforms import (
123123
_has_tv,
124124
ActionDiscretizer,

Diff for: torchrl/data/__init__.py

+64 -58
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,22 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from .llm import (
7+
AdaptiveKLController,
8+
ConstantKLController,
9+
create_infinite_iterator,
10+
get_dataloader,
11+
LLMData,
12+
LLMInput,
13+
LLMOutput,
14+
PairwiseDataset,
15+
PromptData,
16+
PromptTensorDictTokenizer,
17+
RewardData,
18+
RolloutFromModel,
19+
TensorDictTokenizer,
20+
TokenizedDatasetLoader,
21+
)
622
from .map import (
723
BinaryToDecimal,
824
HashToInt,
@@ -56,19 +72,6 @@
5672
Writer,
5773
WriterEnsemble,
5874
)
59-
from .rlhf import (
60-
AdaptiveKLController,
61-
ConstantKLController,
62-
create_infinite_iterator,
63-
get_dataloader,
64-
PairwiseDataset,
65-
PromptData,
66-
PromptTensorDictTokenizer,
67-
RewardData,
68-
RolloutFromModel,
69-
TensorDictTokenizer,
70-
TokenizedDatasetLoader,
71-
)
7275
from .tensor_specs import (
7376
Binary,
7477
BinaryDiscreteTensorSpec,
@@ -103,96 +106,99 @@
103106
from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec
104107

105108
__all__ = [
109+
"AdaptiveKLController",
110+
"Binary",
111+
"BinaryDiscreteTensorSpec",
106112
"BinaryToDecimal",
107-
"HashToInt",
108-
"MCTSForest",
109-
"QueryModule",
110-
"RandomProjectionHash",
111-
"SipHash",
112-
"TensorDictMap",
113-
"TensorMap",
114-
"Tree",
115-
"MultiStep",
113+
"Bounded",
114+
"BoundedContinuous",
115+
"BoundedTensorSpec",
116+
"Categorical",
117+
"Choice",
118+
"Composite",
119+
"CompositeSpec",
120+
"ConstantKLController",
121+
"DEVICE_TYPING",
122+
"DiscreteTensorSpec",
116123
"Flat2TED",
117124
"FlatStorageCheckpointer",
118125
"H5Combine",
119126
"H5Split",
120127
"H5StorageCheckpointer",
128+
"HashToInt",
121129
"ImmutableDatasetWriter",
130+
"LLMData",
131+
"LLMInput",
132+
"LLMOutput",
122133
"LazyMemmapStorage",
123134
"LazyStackStorage",
135+
"LazyStackedCompositeSpec",
136+
"LazyStackedTensorSpec",
124137
"LazyTensorStorage",
125138
"ListStorage",
126139
"ListStorageCheckpointer",
140+
"MCTSForest",
141+
"MultiCategorical",
142+
"MultiDiscreteTensorSpec",
143+
"MultiOneHot",
144+
"MultiOneHotDiscreteTensorSpec",
145+
"MultiStep",
127146
"Nested2TED",
128147
"NestedStorageCheckpointer",
148+
"NonTensor",
149+
"NonTensorSpec",
150+
"OneHot",
151+
"OneHotDiscreteTensorSpec",
152+
"PairwiseDataset",
129153
"PrioritizedReplayBuffer",
130154
"PrioritizedSampler",
131155
"PrioritizedSliceSampler",
156+
"PromptData",
157+
"PromptTensorDictTokenizer",
158+
"QueryModule",
159+
"RandomProjectionHash",
132160
"RandomSampler",
133161
"RemoteTensorDictReplayBuffer",
134162
"ReplayBuffer",
135163
"ReplayBufferEnsemble",
164+
"RewardData",
165+
"RolloutFromModel",
136166
"RoundRobinWriter",
137167
"SamplerEnsemble",
138168
"SamplerWithoutReplacement",
169+
"SipHash",
139170
"SliceSampler",
140171
"SliceSamplerWithoutReplacement",
172+
"Stacked",
173+
"StackedComposite",
141174
"Storage",
142175
"StorageCheckpointerBase",
143176
"StorageEnsemble",
144177
"StorageEnsembleCheckpointer",
145178
"TED2Flat",
146179
"TED2Nested",
180+
"TensorDictMap",
147181
"TensorDictMaxValueWriter",
148182
"TensorDictPrioritizedReplayBuffer",
149183
"TensorDictReplayBuffer",
150184
"TensorDictRoundRobinWriter",
185+
"TensorDictTokenizer",
186+
"TensorMap",
187+
"TensorSpec",
151188
"TensorStorage",
152189
"TensorStorageCheckpointer",
153-
"Writer",
154-
"WriterEnsemble",
155-
"AdaptiveKLController",
156-
"ConstantKLController",
157-
"create_infinite_iterator",
158-
"get_dataloader",
159-
"PairwiseDataset",
160-
"PromptData",
161-
"PromptTensorDictTokenizer",
162-
"RewardData",
163-
"RolloutFromModel",
164-
"TensorDictTokenizer",
165190
"TokenizedDatasetLoader",
166-
"Binary",
167-
"BinaryDiscreteTensorSpec",
168-
"Bounded",
169-
"BoundedContinuous",
170-
"BoundedTensorSpec",
171-
"Categorical",
172-
"Choice",
173-
"Composite",
174-
"CompositeSpec",
175-
"DEVICE_TYPING",
176-
"DiscreteTensorSpec",
177-
"LazyStackedCompositeSpec",
178-
"LazyStackedTensorSpec",
179-
"MultiCategorical",
180-
"MultiDiscreteTensorSpec",
181-
"MultiOneHot",
182-
"MultiOneHotDiscreteTensorSpec",
183-
"NonTensor",
184-
"NonTensorSpec",
185-
"OneHot",
186-
"OneHotDiscreteTensorSpec",
187-
"Stacked",
188-
"StackedComposite",
189-
"TensorSpec",
191+
"Tree",
190192
"Unbounded",
191193
"UnboundedContinuous",
192194
"UnboundedContinuousTensorSpec",
193195
"UnboundedDiscrete",
194196
"UnboundedDiscreteTensorSpec",
197+
"Writer",
198+
"WriterEnsemble",
195199
"check_no_exclusive_keys",
196200
"consolidate_spec",
197201
"contains_lazy_spec",
202+
"create_infinite_iterator",
203+
"get_dataloader",
198204
]

Diff for: torchrl/data/rlhf/__init__.py renamed to torchrl/data/llm/__init__.py

+11 -8
Original file line number | Diff line number | Diff line change
@@ -11,18 +11,21 @@
1111
)
1212
from .prompt import PromptData, PromptTensorDictTokenizer
1313
from .reward import PairwiseDataset, RewardData
14-
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel
14+
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel, LLMData, LLMOutput, LLMInput
1515

1616
__all__ = [
17-
"create_infinite_iterator",
18-
"get_dataloader",
19-
"TensorDictTokenizer",
20-
"TokenizedDatasetLoader",
17+
"AdaptiveKLController",
18+
"ConstantKLController",
19+
"LLMData",
20+
"LLMInput",
21+
"LLMOutput",
22+
"PairwiseDataset",
2123
"PromptData",
2224
"PromptTensorDictTokenizer",
23-
"PairwiseDataset",
2425
"RewardData",
25-
"AdaptiveKLController",
26-
"ConstantKLController",
2726
"RolloutFromModel",
27+
"TensorDictTokenizer",
28+
"TokenizedDatasetLoader",
29+
"create_infinite_iterator",
30+
"get_dataloader",
2831
]

0 commit comments

Comments
 (0)