Commit 183a7c3

Allow users to customize dataloader
ghstack-source-id: c9999c5 Pull Request resolved: #836
1 parent ab94a99 commit 183a7c3

File tree: 9 files changed, +159 additions, -90 deletions


tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 5 additions & 7 deletions
@@ -5,8 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torchtitan.datasets.hf_datasets import build_hf_data_loader
-from torchtitan.datasets.tokenizer import build_tokenizer
+from torchtitan.datasets.hf_datasets import build_hf_dataloader


 class TestDatasetCheckpointing:
@@ -41,13 +40,12 @@ def test_c4_resumption(self):
     def _build_dataloader(
         self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
     ):
-        tokenizer = build_tokenizer("tiktoken", "./tests/assets/test_tiktoken.model")
-        return build_hf_data_loader(
+        return build_hf_dataloader(
            dataset_name=dataset_name,
            dataset_path=dataset_path,
-            tokenizer=tokenizer,
+            tokenizer_path="./tests/assets/test_tiktoken.model",
            batch_size=1,
            seq_len=1024,
-            world_size=4,
-            rank=0,
+            dp_world_size=4,
+            dp_rank=0,
        )
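The resumption test exercises the new builder end to end. For orientation, here is a minimal sketch of the same save/restore cycle outside the test harness; the dataset name "c4_test" and its path are assumptions about the repo's test assets, and only the tokenizer asset path appears in the diff above.

from torchtitan.datasets.hf_datasets import build_hf_dataloader


def make_loader():
    # Mirrors the updated _build_dataloader helper; dataset_name/dataset_path
    # values are illustrative assumptions, not taken from this commit.
    return build_hf_dataloader(
        dataset_name="c4_test",
        dataset_path="./tests/assets/c4_test",
        tokenizer_path="./tests/assets/test_tiktoken.model",
        batch_size=1,
        seq_len=1024,
        dp_rank=0,
        dp_world_size=4,
    )


dl = make_loader()
it = iter(dl)
for _ in range(3):
    next(it)  # consume a few batches

state = dl.state_dict()  # stored under the per-rank key "dp_rank_0"

resumed = make_loader()
resumed.load_state_dict(state)  # picks up where the first loader stopped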

tests/unit_tests/test_train_spec.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 from torchtitan.config_manager import JobConfig
+from torchtitan.datasets import build_hf_dataloader
 from torchtitan.models.llama import parallelize_llama, pipeline_llama
 from torchtitan.optimizer import (
     build_lr_schedulers,
@@ -60,6 +61,7 @@ def test_register_train_spec(self):
             pipelining_fn=pipeline_llama,
             build_optimizers_fn=build_optimizers,
             build_lr_schedulers_fn=build_lr_schedulers,
+            build_dataloader_fn=build_hf_dataloader,
         )
         register_train_spec(spec)
         new_spec = get_train_spec("fake")
@@ -78,6 +80,7 @@ def test_optim_hook(self):
             pipelining_fn=pipeline_llama,
             build_optimizers_fn=fake_build_optimizers,
             build_lr_schedulers_fn=build_lr_schedulers,
+            build_dataloader_fn=build_hf_dataloader,
         )
         register_train_spec(spec)
         new_spec = get_train_spec("fake2")

torchtitan/dataloader.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+import pickle
+from abc import ABC, abstractmethod
+from typing import Any, Callable, TypeAlias
+
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.utils.data import IterableDataset
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from torchtitan.datasets.tokenizer import Tokenizer
+from torchtitan.logging import logger
+
+
+class BaseDataLoader(Stateful, ABC):
+    """Base class for all dataloaders.
+
+    This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
+    ``state_dict()`` and ``load_state_dict()``.
+    """
+
+    @abstractmethod
+    def __iter__(self):
+        ...
+
+
+class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
+    """Dataloader that is aware of distributed data parallelism.
+
+    This dataloader is used to load data in a distributed data parallel fashion. It also
+    utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
+    methods such as ``__iter__``.
+
+    Args:
+        dataset (IterableDataset): The dataset to iterate over.
+        tokenizer (Tokenizer): The tokenizer to use to tokenize the dataset.
+        dp_rank: Data parallelism rank for this dataloader.
+        dp_world_size: The world size of the data parallelism.
+        batch_size: The batch size to use for each iteration.
+    """
+
+    tokenizer: Tokenizer
+    dp_rank: int
+    dp_world_size: int
+    batch_size: int
+
+    def __init__(
+        self,
+        dataset: IterableDataset,
+        tokenizer: Tokenizer,
+        dp_rank: int,
+        dp_world_size: int,
+        batch_size: int,
+    ):
+        self.dp_world_size = dp_world_size
+        self.dp_rank = dp_rank
+        self.batch_size = batch_size
+        self.tokenizer = tokenizer
+        super().__init__(dataset, batch_size)
+        self._rank_id = f"dp_rank_{dp_rank}"
+
+    def state_dict(self) -> dict[str, Any]:
+        # Store state only for dp rank to avoid replicating the same state across other dimensions.
+        return {
+            # We don't have to use pickle as DCP will serialize the state_dict. However,
+            # we have to keep this for backward compatibility.
+            self._rank_id: pickle.dumps(super().state_dict()),
+            "world_size": self.dp_world_size,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        # State being empty is valid.
+        if not state_dict:
+            return
+
+        if self._rank_id not in state_dict:
+            logger.warning(
+                f"DataLoader state is empty for dp rank {self.dp_rank}, "
+                f"expected key {self._rank_id}"
+            )
+            return
+
+        assert self.dp_world_size == state_dict["world_size"], (
+            "dp_degree is inconsistent before and after checkpoint, "
+            "dataloader resharding is not supported yet."
+        )
+        # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
+        # keep this for backward compatibility.
+        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
+
+
+DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
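This new module is the extension point the commit title refers to: a recipe can supply any dataloader that satisfies BaseDataLoader, and ParallelAwareDataloader covers the common case by storing checkpoint state once per data-parallel rank. Below is a hedged sketch of a custom loader built on top of it; apart from ParallelAwareDataloader and build_tokenizer, every name here (the JSONL dataset, its round-robin sharding, build_my_dataloader) is hypothetical and not part of this commit.

import json

import torch
from torch.utils.data import IterableDataset

from torchtitan.dataloader import ParallelAwareDataloader
from torchtitan.datasets.tokenizer import build_tokenizer


class MyJsonlDataset(IterableDataset):
    """Illustrative only: streams (input, label) pairs from a local JSONL file."""

    def __init__(self, path, tokenizer, seq_len, dp_rank, dp_world_size):
        self.path = path
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.dp_rank = dp_rank
        self.dp_world_size = dp_world_size

    def __iter__(self):
        buffer: list[int] = []
        with open(self.path) as f:
            for i, line in enumerate(f):
                # Naive round-robin sharding across data-parallel ranks; a real
                # implementation would use a proper distributed split.
                if i % self.dp_world_size != self.dp_rank:
                    continue
                text = json.loads(line)["text"]
                # Assumed tokenizer call, mirroring how hf_datasets.py tokenizes text.
                buffer.extend(self.tokenizer.encode(text, bos=True, eos=True))
                # Pack tokens into fixed-length chunks and emit shifted input/label pairs.
                while len(buffer) >= self.seq_len + 1:
                    chunk, buffer = buffer[: self.seq_len + 1], buffer[self.seq_len + 1 :]
                    yield torch.LongTensor(chunk[:-1]), torch.LongTensor(chunk[1:])


def build_my_dataloader(
    dataset_path: str,
    tokenizer_path: str,
    batch_size: int,
    seq_len: int,
    dp_rank: int,
    dp_world_size: int,
    **kwargs,
) -> ParallelAwareDataloader:
    tokenizer = build_tokenizer("tiktoken", tokenizer_path)
    dataset = MyJsonlDataset(dataset_path, tokenizer, seq_len, dp_rank, dp_world_size)
    # ParallelAwareDataloader supplies __iter__, state_dict(), and load_state_dict().
    return ParallelAwareDataloader(
        dataset=dataset,
        tokenizer=tokenizer,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )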

torchtitan/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torchtitan.datasets.hf_datasets import build_hf_data_loader
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.datasets.tokenizer import build_tokenizer

 __all__ = [
-    "build_hf_data_loader",
+    "build_hf_dataloader",
     "build_tokenizer",
 ]

torchtitan/datasets/hf_datasets.py

Lines changed: 34 additions & 57 deletions
@@ -4,28 +4,28 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import pickle
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

 import torch
+
+from datasets import Dataset, load_dataset
+from datasets.distributed import split_dataset_by_node
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import IterableDataset
-from torchdata.stateful_dataloader import StatefulDataLoader

-from torchtitan.datasets.tokenizer import Tokenizer
-from torchtitan.logging import logger
+from torchtitan.dataloader import ParallelAwareDataloader

-from datasets import Dataset, load_dataset
-from datasets.distributed import split_dataset_by_node
+from torchtitan.datasets.tokenizer import build_tokenizer, Tokenizer
+from torchtitan.logging import logger


 def _load_c4_dataset(dataset_path: str):
     """Load C4 dataset with default configuration."""
     return load_dataset(dataset_path, name="en", split="train", streaming=True)


-def _process_c4_text(sample: Dict[str, Any]) -> str:
+def _process_c4_text(sample: dict[str, Any]) -> str:
     """Process C4 dataset sample text."""
     return sample["text"]

@@ -75,8 +75,8 @@ def __init__(
         dataset_path: Optional[str],
         tokenizer: Tokenizer,
         seq_len: int = 2048,
-        world_size: int = 1,
-        rank: int = 0,
+        dp_rank: int = 0,
+        dp_world_size: int = 1,
         infinite: bool = False,
     ) -> None:
         # Force lowercase for consistent comparison
@@ -88,15 +88,15 @@ def __init__(
         ds = dataset_loader(path)

         self.dataset_name = dataset_name
-        self._data = split_dataset_by_node(ds, rank, world_size)
+        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
         self.infinite = infinite
         self._text_processor = text_processor

         # Variables for checkpointing
         self._sample_idx = 0
-        self._all_tokens: List[int] = []
+        self._all_tokens: list[int] = []

     def _get_data_iter(self):
         if self._sample_idx == 0:
@@ -142,56 +142,33 @@ def state_dict(self):
         return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


-class DPAwareDataLoader(StatefulDataLoader, Stateful):
-    """
-    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
-    """
-
-    def __init__(
-        self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
-    ):
-        super().__init__(hf_ds, batch_size)
-        self._dp_rank = dp_rank
-        self._rank_id = f"dp_rank_{dp_rank}"
-        # Data loader resharding is not yet supported, so we need to store the world size to compare during loading
-        # raise error if dp_word_size does not match.
-        self._world_size = world_size
-
-    def state_dict(self) -> Dict[str, Any]:
-        # Store state only for dp rank to avoid replicating the same state across other dimensions
-        return {
-            self._rank_id: pickle.dumps(super().state_dict()),
-            "world_size": self._world_size,
-        }
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        # State being empty is valid
-        if not state_dict:
-            return
-
-        if self._rank_id not in state_dict:
-            logger.warning(
-                f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
-            )
-            return
-        assert (
-            self._world_size == state_dict["world_size"]
-        ), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
-        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
-
-
-def build_hf_data_loader(
+def build_hf_dataloader(
     dataset_name: str,
     dataset_path: Optional[str],
-    tokenizer: Tokenizer,
+    tokenizer_path: str,
     batch_size: int,
     seq_len: int,
-    world_size: int,
-    rank: int,
+    dp_rank: int,
+    dp_world_size: int,
     infinite: bool = True,
-):
+) -> ParallelAwareDataloader:
     """Build a data loader for HuggingFace datasets."""
+    tokenizer = build_tokenizer("tiktoken", tokenizer_path)
+
     hf_ds = HuggingFaceDataset(
-        dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        tokenizer=tokenizer,
+        seq_len=seq_len,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        infinite=infinite,
+    )
+
+    return ParallelAwareDataloader(
+        dataset=hf_ds,
+        tokenizer=tokenizer,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        batch_size=batch_size,
     )
-    return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)
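For existing callers, the practical change is that build_hf_data_loader becomes build_hf_dataloader, the builder now constructs the tokenizer itself from tokenizer_path, and the parallelism arguments are named for the data-parallel mesh (dp_rank/dp_world_size instead of rank/world_size). A short before/after sketch; the dataset name and the reuse of the test tokenizer asset are illustrative:

from torchtitan.datasets import build_hf_dataloader

# Before: callers built the tokenizer and passed global rank/world_size:
#   tokenizer = build_tokenizer("tiktoken", tokenizer_path)
#   dl = build_hf_data_loader("c4", None, tokenizer, batch_size, seq_len, world_size, rank)

# After: the builder owns tokenizer construction.
dl = build_hf_dataloader(
    dataset_name="c4",
    dataset_path=None,  # assumed to fall back to this dataset's registered default path
    tokenizer_path="./tests/assets/test_tiktoken.model",
    batch_size=8,
    seq_len=2048,
    dp_rank=0,
    dp_world_size=8,
)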

torchtitan/models/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -8,6 +8,3 @@
 # Import the built-in models here so that the corresponding register_model_spec()
 # will be called.
 import torchtitan.models.llama  # noqa: F401
-
-
-model_name_to_tokenizer = {"llama3": "tiktoken"}

torchtitan/models/llama/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

+from torchtitan.datasets import build_hf_dataloader
 from torchtitan.models.llama.model import Transformer, TransformerModelArgs
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.train_spec import register_train_spec, TrainSpec
@@ -65,5 +66,6 @@
         pipelining_fn=pipeline_llama,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
     )
 )

torchtitan/train_spec.py

Lines changed: 4 additions & 5 deletions
@@ -6,14 +6,14 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

-
 from dataclasses import dataclass
-from typing import Callable, Dict, Protocol, Type, TypeAlias
+from typing import Callable, Protocol, Type, TypeAlias

 import torch.nn as nn
 from torch.distributed.pipelining.schedules import _PipelineSchedule

 from torchtitan.config_manager import JobConfig
+from torchtitan.dataloader import DataLoaderBuilder
 from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@@ -53,15 +53,14 @@ def from_model_args(args: BaseModelArgs) -> nn.Module:
 class TrainSpec:
     name: str
     cls: Type[nn.Module]
-    config: Dict[str, BaseModelArgs]
+    config: dict[str, BaseModelArgs]
     parallelize_fn: Callable[[nn.Module], None]
     pipelining_fn: Callable[
         [nn.Module], tuple[_PipelineSchedule, list[nn.Module], bool, bool]
     ]
     build_optimizers_fn: OptimizersBuilder
     build_lr_schedulers_fn: LRSchedulersBuilder
-
-    # TODO: Add a ``build_dataloader_fn``
+    build_dataloader_fn: DataLoaderBuilder

     # TODO: Add a FQN convert fn to allow users to load checkpoints from
     # HuggingFace or other sources that have different FQN conventions.
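With build_dataloader_fn now a field of TrainSpec, the choice of dataloader travels with the model registration instead of being hard-coded in the trainer. A hedged sketch of registering a spec that wires in a builder; the spec name and model-args values are hypothetical, while build_hf_dataloader is the stock builder touched by this commit:

from torchtitan.datasets import build_hf_dataloader
from torchtitan.models.llama import (
    parallelize_llama,
    pipeline_llama,
    Transformer,
    TransformerModelArgs,
)
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec

register_train_spec(
    TrainSpec(
        name="llama3_custom_data",  # hypothetical spec name
        cls=Transformer,
        config={"debugmodel": TransformerModelArgs(dim=256, n_layers=2, n_heads=8)},
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        # Any callable matching DataLoaderBuilder (returning a BaseDataLoader) works
        # here, e.g. the hypothetical build_my_dataloader sketched further up.
        build_dataloader_fn=build_hf_dataloader,
    )
)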
