
Commit 57e7e2d

Allow users to customize dataloader

ghstack-source-id: 55f5b9a
Pull Request resolved: #836

1 parent 36c6d2f

File tree

8 files changed: +170, -87 lines

tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torchtitan.datasets.hf_datasets import build_hf_data_loader
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.datasets.tokenizer import build_tokenizer


@@ -42,7 +42,7 @@ def _build_dataloader(
         self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
     ):
         tokenizer = build_tokenizer("tiktoken", "./tests/assets/test_tiktoken.model")
-        return build_hf_data_loader(
+        return build_hf_dataloader(
             dataset_name=dataset_name,
             dataset_path=dataset_path,
             tokenizer=tokenizer,

torchtitan/dataloader.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+import pickle
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Optional, Protocol
+
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.utils.data import IterableDataset
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from torchtitan.datasets.tokenizer import Tokenizer
+from torchtitan.logging import logger
+
+
+@dataclass
+class BaseDataLoader(Stateful, ABC):
+    """Base class for all dataloaders.
+
+    This is used to enforce that all dataloaders have the two methods from ``Stateful``,
+    ``state_dict()`` and ``load_state_dict()``.
+    """
+
+    tokenizer: Tokenizer
+    dp_rank: int
+    dp_world_size: int
+    batch_size: int
+
+    @abstractmethod
+    def __iter__(self):
+        ...
+
+
+class DPDataLoader(StatefulDataLoader, BaseDataLoader):
+    """Dataloader that is aware of data parallelism.
+
+    This dataloader is used to load data in a distributed fashion. It also utilizes
+    ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
+    methods such as ``__iter__``.
+    """
+
+    def __init__(
+        self,
+        dataset: IterableDataset,
+        tokenizer: Tokenizer,
+        dp_rank: int,
+        dp_world_size: int,
+        batch_size: int,
+    ):
+        BaseDataLoader.__init__(
+            self,
+            tokenizer=tokenizer,
+            dp_rank=dp_rank,
+            dp_world_size=dp_world_size,
+            batch_size=batch_size,
+        )
+        StatefulDataLoader.__init__(self, dataset, batch_size)
+        self._rank_id = f"dp_rank_{dp_rank}"
+
+    def state_dict(self) -> dict[str, Any]:
+        # Store state only for dp rank to avoid replicating the same state across other dimensions.
+        return {
+            # We don't have to use pickle as DCP will serialize the state_dict. However,
+            # we have to keep this for backward compatibility.
+            self._rank_id: pickle.dumps(StatefulDataLoader.state_dict(self)),
+            "world_size": self.dp_world_size,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        # State being empty is valid.
+        if not state_dict:
+            return
+
+        if self._rank_id not in state_dict:
+            logger.warning(
+                f"DataLoader state is empty for dp rank {self.dp_rank}, "
+                f"expected key {self._rank_id}"
+            )
+            return
+
+        assert self.dp_world_size == state_dict["world_size"], (
+            "dp_degree is inconsistent before and after checkpoint, "
+            "dataloader resharding is not supported yet."
+        )
+        # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
+        # keep this for backward compatibility.
+        StatefulDataLoader.load_state_dict(
+            self, pickle.loads(state_dict[self._rank_id])
+        )
+
+
+class DataLoaderBuilder(Protocol):
+    """This is a protocol to annotate ``build_dataloader_fn``.
+
+    While mypy.extensions provides Arg to annotate the name, it requires another dependency on
+    mypy-extensions. Mypy also supports this annotation and it is easier to read.
+    """
+
+    def __call__(
+        self,
+        dataset_name: str,
+        dataset_path: Optional[str],
+        tokenizer_path: str,
+        batch_size: int,
+        seq_len: int,
+        dp_rank: int,
+        dp_world_size: int,
+    ) -> BaseDataLoader:
+        ...
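
For orientation, the intent of this new module is that a project can ship its own dataloader: subclass BaseDataLoader (or reuse DPDataLoader) and expose a builder whose keyword signature matches DataLoaderBuilder. Below is a minimal sketch of that shape; MyTokenDataset and build_my_dataloader are hypothetical names used only for illustration and are not part of this commit.

import torch
from typing import Optional

from torch.utils.data import IterableDataset

from torchtitan.dataloader import BaseDataLoader, DPDataLoader
from torchtitan.datasets.tokenizer import build_tokenizer


class MyTokenDataset(IterableDataset):
    """Toy pre-tokenized dataset; label shifting and per-rank sharding are omitted."""

    def __init__(self, token_ids: list[int], seq_len: int):
        self.token_ids = token_ids
        self.seq_len = seq_len

    def __iter__(self):
        # Yield fixed-length chunks as tensors so the default collate can batch them.
        for start in range(0, len(self.token_ids) - self.seq_len, self.seq_len):
            yield torch.tensor(self.token_ids[start : start + self.seq_len])


def build_my_dataloader(
    dataset_name: str,
    dataset_path: Optional[str],
    tokenizer_path: str,
    batch_size: int,
    seq_len: int,
    dp_rank: int,
    dp_world_size: int,
) -> BaseDataLoader:
    # Matches the DataLoaderBuilder protocol. dataset_name/dataset_path are accepted
    # to satisfy the protocol but unused here; a real builder would shard by
    # dp_rank / dp_world_size instead of feeding the same toy data to every rank.
    tokenizer = build_tokenizer("tiktoken", tokenizer_path)
    dataset = MyTokenDataset(token_ids=list(range(100_000)), seq_len=seq_len)
    return DPDataLoader(
        dataset=dataset,
        tokenizer=tokenizer,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )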

torchtitan/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torchtitan.datasets.hf_datasets import build_hf_data_loader
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
 from torchtitan.datasets.tokenizer import build_tokenizer

 __all__ = [
-    "build_hf_data_loader",
+    "build_hf_dataloader",
     "build_tokenizer",
 ]

torchtitan/datasets/hf_datasets.py

Lines changed: 33 additions & 57 deletions
@@ -4,28 +4,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import pickle
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

 import torch
+
+from datasets import Dataset, load_dataset
+from datasets.distributed import split_dataset_by_node
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import IterableDataset
-from torchdata.stateful_dataloader import StatefulDataLoader

-from torchtitan.datasets.tokenizer import Tokenizer
+from torchtitan.datasets.tokenizer import build_tokenizer, Tokenizer
 from torchtitan.logging import logger
-
-from datasets import Dataset, load_dataset
-from datasets.distributed import split_dataset_by_node
+from torchtitan.dataloader import DPDataLoader


 def _load_c4_dataset(dataset_path: str):
     """Load C4 dataset with default configuration."""
     return load_dataset(dataset_path, name="en", split="train", streaming=True)


-def _process_c4_text(sample: Dict[str, Any]) -> str:
+def _process_c4_text(sample: dict[str, Any]) -> str:
     """Process C4 dataset sample text."""
     return sample["text"]

@@ -75,8 +74,8 @@ def __init__(
         dataset_path: Optional[str],
         tokenizer: Tokenizer,
         seq_len: int = 2048,
-        world_size: int = 1,
-        rank: int = 0,
+        dp_rank: int = 0,
+        dp_world_size: int = 1,
         infinite: bool = False,
     ) -> None:
         # Force lowercase for consistent comparison
@@ -88,15 +87,15 @@ def __init__(
         ds = dataset_loader(path)

         self.dataset_name = dataset_name
-        self._data = split_dataset_by_node(ds, rank, world_size)
+        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
         self.infinite = infinite
         self._text_processor = text_processor

         # Variables for checkpointing
         self._sample_idx = 0
-        self._all_tokens: List[int] = []
+        self._all_tokens: list[int] = []

     def _get_data_iter(self):
         if self._sample_idx == 0:
@@ -142,56 +141,33 @@ def state_dict(self):
         return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


-class DPAwareDataLoader(StatefulDataLoader, Stateful):
-    """
-    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
-    """
-
-    def __init__(
-        self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
-    ):
-        super().__init__(hf_ds, batch_size)
-        self._dp_rank = dp_rank
-        self._rank_id = f"dp_rank_{dp_rank}"
-        # Data loader resharding is not yet supported, so we need to store the world size to compare during loading
-        # raise error if dp_word_size does not match.
-        self._world_size = world_size
-
-    def state_dict(self) -> Dict[str, Any]:
-        # Store state only for dp rank to avoid replicating the same state across other dimensions
-        return {
-            self._rank_id: pickle.dumps(super().state_dict()),
-            "world_size": self._world_size,
-        }
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        # State being empty is valid
-        if not state_dict:
-            return
-
-        if self._rank_id not in state_dict:
-            logger.warning(
-                f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
-            )
-            return
-        assert (
-            self._world_size == state_dict["world_size"]
-        ), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
-        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
-
-
-def build_hf_data_loader(
+def build_hf_dataloader(
     dataset_name: str,
     dataset_path: Optional[str],
-    tokenizer: Tokenizer,
+    tokenizer_path: str,
     batch_size: int,
     seq_len: int,
-    world_size: int,
-    rank: int,
+    dp_rank: int,
+    dp_world_size: int,
     infinite: bool = True,
-):
+) -> DPDataLoader:
     """Build a data loader for HuggingFace datasets."""
+    tokenizer = build_tokenizer("tiktoken", tokenizer_path)
+
     hf_ds = HuggingFaceDataset(
-        dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        tokenizer=tokenizer,
+        seq_len=seq_len,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        infinite=infinite,
+    )
+
+    return DPDataLoader(
+        dataset=hf_ds,
+        tokenizer=tokenizer,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        batch_size=batch_size,
     )
-    return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)
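
As a quick usage sketch of the new signature: build_hf_dataloader now takes a tokenizer_path and constructs the tiktoken tokenizer internally (this replaces the model_name_to_tokenizer mapping removed below), and the returned DPDataLoader is checkpointable via its Stateful methods. The dataset name and paths here are illustrative placeholders modeled on the repository's test assets, not values mandated by this commit.

from torchtitan.datasets import build_hf_dataloader

dataloader = build_hf_dataloader(
    dataset_name="c4_test",
    dataset_path="./tests/assets/c4_test",
    tokenizer_path="./tests/assets/test_tiktoken.model",
    batch_size=8,
    seq_len=2048,
    dp_rank=0,
    dp_world_size=1,
)

# Each batch is an (input, label) pair produced by HuggingFaceDataset.
inputs, labels = next(iter(dataloader))

# DPDataLoader is Stateful, so its position can be saved and restored.
state = dataloader.state_dict()
dataloader.load_state_dict(state)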

torchtitan/models/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -8,6 +8,3 @@
 # Import the built-in models here so that the corresponding register_model_spec()
 # will be called.
 import torchtitan.models.llama  # noqa: F401
-
-
-model_name_to_tokenizer = {"llama3": "tiktoken"}

torchtitan/models/llama/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

+from torchtitan.datasets import build_hf_dataloader
 from torchtitan.models.llama.model import Transformer, TransformerModelArgs
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.train_spec import register_train_spec, TrainSpec
@@ -65,5 +66,6 @@
         pipelining_fn=pipeline_llama,
         build_optimizers_fn=build_optimizers,
         build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
     )
 )

torchtitan/train_spec.py

Lines changed: 6 additions & 7 deletions
@@ -6,14 +6,14 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

-
 from dataclasses import dataclass
-from typing import Callable, Dict, Protocol, Type, TypeAlias
+from typing import Callable, Protocol, Type, TypeAlias

 import torch.nn as nn
 from torch.distributed.pipelining.schedules import _PipelineSchedule

 from torchtitan.config_manager import JobConfig
+from torchtitan.dataloader import DataLoaderBuilder
 from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@@ -36,8 +36,7 @@ class ModelProtocol(Protocol):
     """

     @staticmethod
-    def from_model_args(args: BaseModelArgs) -> nn.Module:
-        ...
+    def from_model_args(args: BaseModelArgs) -> nn.Module: ...


 OptimizersBuilder: TypeAlias = Callable[
@@ -49,19 +48,19 @@ def from_model_args(args: BaseModelArgs) -> nn.Module:
 LRSchedulersBuilder: TypeAlias = Callable[[OptimizersContainer], LRSchedulersContainer]


+
 @dataclass
 class TrainSpec:
     name: str
     cls: Type[nn.Module]
-    config: Dict[str, BaseModelArgs]
+    config: dict[str, BaseModelArgs]
     parallelize_fn: Callable[[nn.Module], None]
     pipelining_fn: Callable[
         [nn.Module], tuple[_PipelineSchedule, list[nn.Module], bool, bool]
     ]
     build_optimizers_fn: OptimizersBuilder
     build_lr_schedulers_fn: LRSchedulersBuilder
-
-    # TODO: Add a ``build_dataloader_fn``
+    build_dataloader_fn: DataLoaderBuilder

     # TODO: Add a FQN convert fn to allow users to load checkpoints from
     # HuggingFace or other sources that have different FQN conventions.
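
Since build_dataloader_fn is now a required TrainSpec field typed as DataLoaderBuilder, any callable with the matching keyword signature can be plugged in, and mypy can check conformance through a plain annotation. A minimal sketch (the variable name builder is illustrative only):

from torchtitan.dataloader import DataLoaderBuilder
from torchtitan.datasets import build_hf_dataloader

# Assigning to a DataLoaderBuilder-typed name lets mypy verify that the builder's
# parameters (dataset_name, dataset_path, tokenizer_path, batch_size, seq_len,
# dp_rank, dp_world_size) match the protocol that TrainSpec.build_dataloader_fn expects.
builder: DataLoaderBuilder = build_hf_dataloader

A custom builder such as the build_my_dataloader sketch above would be passed the same way as build_dataloader_fn when calling register_train_spec, exactly as the llama spec does with build_hf_dataloader in torchtitan/models/llama/__init__.py.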
