Allow users to customize dataloader
ghstack-source-id: 55f5b9a93f7e870a3b0494ad85e38e1c25f17982
Pull Request resolved: #836
fegin committed Feb 12, 2025
1 parent 36c6d2f commit 57e7e2d
Showing 8 changed files with 170 additions and 87 deletions.
4 changes: 2 additions & 2 deletions tests/unit_tests/test_dataset_checkpointing.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import build_tokenizer


@@ -42,7 +42,7 @@ def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer = build_tokenizer("tiktoken", "./tests/assets/test_tiktoken.model")
return build_hf_data_loader(
return build_hf_dataloader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
114 changes: 114 additions & 0 deletions torchtitan/dataloader.py
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import pickle
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Protocol

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger


@dataclass
class BaseDataLoader(Stateful, ABC):
"""Base class for all dataloaders.
This is used to enforce that all dataloaders have the two methods from ``Stateful``,
``state_dict()`` and ``load_state_dict()``.
"""

tokenizer: Tokenizer
dp_rank: int
dp_world_size: int
batch_size: int

@abstractmethod
def __iter__(self):
...


class DPDataLoader(StatefulDataLoader, BaseDataLoader):
"""Dataloader that is aware of data parallelism
This dataloader is used to load data in a distributed fashion. It also utilizes
``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
methods such as ``__iter__``.
"""

def __init__(
self,
dataset: IterableDataset,
tokenizer: Tokenizer,
dp_rank: int,
dp_world_size: int,
batch_size: int,
):
BaseDataLoader.__init__(
self,
tokenizer=tokenizer,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
StatefulDataLoader.__init__(self, dataset, batch_size)
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions.
return {
# We don't have to use pickle as DCP will serialize the state_dict. However,
# we have to keep this for backward compatibility.
self._rank_id: pickle.dumps(StatefulDataLoader.state_dict(self)),
"world_size": self.dp_world_size,
}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
# State being empty is valid.
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self.dp_rank}, "
"expected key {self._rank_id}"
)
return

assert self.dp_world_size == state_dict["world_size"], (
"dp_degree is inconsistent before and after checkpoint, "
"dataloader resharding is not supported yet."
)
# We don't have to use pickle as DCP will serialize the state_dict. However, we have to
# keep this for backward compatibility.
StatefulDataLoader.load_state_dict(
self, pickle.loads(state_dict[self._rank_id])
)


class DataLoaderBuilder(Protocol):
"""This is a protocol to annoate ``build_dataloader_fn``.
While mypy.extensions provides Arg to annotate the name, it requires another dependency on
mypy-extensions. Mypy also supports this annonation and it is easier to read.
"""

def __call__(
self,
dataset_name: str,
dataset_path: Optional[str],
tokenizer_path: str,
batch_size: int,
seq_len: int,
dp_rank: int,
dp_world_size: int,
) -> BaseDataLoader:
...
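With ``BaseDataLoader`` and the ``DataLoaderBuilder`` protocol in place, a user-defined dataloader only has to match the builder signature above. Below is a minimal sketch, not part of this commit; ``MyIterableDataset`` and ``build_my_dataloader`` are hypothetical names, and the dataset itself is a stand-in that emits constant tokens.

```python
from typing import Optional

import torch
from torch.utils.data import IterableDataset

from torchtitan.dataloader import BaseDataLoader, DPDataLoader
from torchtitan.datasets.tokenizer import build_tokenizer


class MyIterableDataset(IterableDataset):
    """Toy dataset that yields constant (input, label) pairs forever."""

    def __init__(self, seq_len: int) -> None:
        self.seq_len = seq_len

    def __iter__(self):
        while True:
            tokens = torch.zeros(self.seq_len + 1, dtype=torch.long)
            # Shift by one to form the (input, label) pair expected by the trainer.
            yield tokens[:-1], tokens[1:]


def build_my_dataloader(
    dataset_name: str,
    dataset_path: Optional[str],
    tokenizer_path: str,
    batch_size: int,
    seq_len: int,
    dp_rank: int,
    dp_world_size: int,
) -> BaseDataLoader:
    # Reusing DPDataLoader keeps checkpointing (state_dict/load_state_dict) for free.
    tokenizer = build_tokenizer("tiktoken", tokenizer_path)
    return DPDataLoader(
        dataset=MyIterableDataset(seq_len),
        tokenizer=tokenizer,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
```

Because the signature matches ``DataLoaderBuilder``, this function can be passed directly as a ``build_dataloader_fn``.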
4 changes: 2 additions & 2 deletions torchtitan/datasets/__init__.py
@@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer import build_tokenizer

__all__ = [
"build_hf_data_loader",
"build_hf_dataloader",
"build_tokenizer",
]
90 changes: 33 additions & 57 deletions torchtitan/datasets/hf_datasets.py
@@ -4,28 +4,27 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pickle
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.datasets.tokenizer import build_tokenizer, Tokenizer
from torchtitan.logging import logger

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torchtitan.dataloader import DPDataLoader


def _load_c4_dataset(dataset_path: str):
"""Load C4 dataset with default configuration."""
return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: Dict[str, Any]) -> str:
def _process_c4_text(sample: dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]

@@ -75,8 +74,8 @@ def __init__(
dataset_path: Optional[str],
tokenizer: Tokenizer,
seq_len: int = 2048,
world_size: int = 1,
rank: int = 0,
dp_rank: int = 0,
dp_world_size: int = 1,
infinite: bool = False,
) -> None:
# Force lowercase for consistent comparison
@@ -88,15 +87,15 @@ def __init__(
ds = dataset_loader(path)

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite
self._text_processor = text_processor

# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []
self._all_tokens: list[int] = []

def _get_data_iter(self):
if self._sample_idx == 0:
@@ -142,56 +141,33 @@ def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""

def __init__(
self, dp_rank: int, hf_ds: IterableDataset, batch_size: int, world_size: int
):
super().__init__(hf_ds, batch_size)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"
# Data loader resharding is not yet supported, so we need to store the world size to compare during loading
# raise error if dp_word_size does not match.
self._world_size = world_size

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {
self._rank_id: pickle.dumps(super().state_dict()),
"world_size": self._world_size,
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}"
)
return
assert (
self._world_size == state_dict["world_size"]
), "dp_degree is inconsistent before and after checkpoint, dataloader resharding is not supported yet."
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
def build_hf_dataloader(
dataset_name: str,
dataset_path: Optional[str],
tokenizer: Tokenizer,
tokenizer_path: str,
batch_size: int,
seq_len: int,
world_size: int,
rank: int,
dp_rank: int,
dp_world_size: int,
infinite: bool = True,
):
) -> DPDataLoader:
"""Build a data loader for HuggingFace datasets."""
tokenizer = build_tokenizer("tiktoken", tokenizer_path)

hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,
)

return DPDataLoader(
dataset=hf_ds,
tokenizer=tokenizer,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
)
return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size, world_size=world_size)
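For reference, a hedged usage sketch of the renamed builder: it now takes a ``tokenizer_path`` and constructs the tokenizer internally instead of receiving a ``Tokenizer`` instance. The dataset name, paths, and sizes below are illustrative placeholders, not values from this commit.

```python
from torchtitan.datasets import build_hf_dataloader

dataloader = build_hf_dataloader(
    dataset_name="c4_test",                                   # placeholder dataset name
    dataset_path="./tests/assets/c4_test",                    # placeholder dataset path
    tokenizer_path="./tests/assets/test_tiktoken.model",      # placeholder tokenizer path
    batch_size=8,
    seq_len=1024,
    dp_rank=0,
    dp_world_size=1,
)

# Each batch is a collated (input, label) pair of shape [batch_size, seq_len].
input_ids, labels = next(iter(dataloader))
```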
3 changes: 0 additions & 3 deletions torchtitan/models/__init__.py
@@ -8,6 +8,3 @@
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa: F401


model_name_to_tokenizer = {"llama3": "tiktoken"}
2 changes: 2 additions & 0 deletions torchtitan/models/llama/__init__.py
@@ -6,6 +6,7 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.datasets import build_hf_dataloader
from torchtitan.models.llama.model import Transformer, TransformerModelArgs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec
@@ -65,5 +66,6 @@
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
)
)
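Since ``build_dataloader_fn`` is now a ``TrainSpec`` field, a custom spec can swap in its own builder at registration time. A hedged sketch follows; ``build_my_dataloader`` is the hypothetical builder from the earlier sketch, ``my_project.data`` is a placeholder module, and ``get_train_spec`` is assumed to be exposed by ``torchtitan.train_spec``.

```python
import dataclasses

# Hypothetical import of the custom builder sketched earlier in this page.
from my_project.data import build_my_dataloader

from torchtitan.train_spec import get_train_spec, register_train_spec

# Clone the registered "llama3" spec, swapping only the dataloader builder,
# and register it under a new name so the original spec stays intact.
register_train_spec(
    dataclasses.replace(
        get_train_spec("llama3"),
        name="llama3_custom_data",
        build_dataloader_fn=build_my_dataloader,
    )
)
```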
13 changes: 6 additions & 7 deletions torchtitan/train_spec.py
@@ -6,14 +6,14 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, Protocol, Type, TypeAlias
from typing import Callable, Protocol, Type, TypeAlias

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.config_manager import JobConfig
from torchtitan.dataloader import DataLoaderBuilder
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@@ -36,8 +36,7 @@ class ModelProtocol(Protocol):
"""

@staticmethod
def from_model_args(args: BaseModelArgs) -> nn.Module:
...
def from_model_args(args: BaseModelArgs) -> nn.Module: ...


OptimizersBuilder: TypeAlias = Callable[
@@ -49,19 +48,19 @@ def from_model_args(args: BaseModelArgs) -> nn.Module:
LRSchedulersBuilder: TypeAlias = Callable[[OptimizersContainer], LRSchedulersContainer]



@dataclass
class TrainSpec:
name: str
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
config: dict[str, BaseModelArgs]
parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[
[nn.Module], tuple[_PipelineSchedule, list[nn.Module], bool, bool]
]
build_optimizers_fn: OptimizersBuilder
build_lr_schedulers_fn: LRSchedulersBuilder

# TODO: Add a ``build_dataloader_fn``
build_dataloader_fn: DataLoaderBuilder

# TODO: Add a FQN convert fn to allow users to load checkpoints from
# HuggingFace or other sources that have different FQN conventions.
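With the field in place, the trainer can build the dataloader through the spec rather than calling a dataset-specific builder directly. The sketch below only mirrors the expected call shape; the helper name and keyword plumbing are illustrative, not code from this commit.

```python
from typing import Optional

from torchtitan.dataloader import BaseDataLoader
from torchtitan.train_spec import TrainSpec


def build_dataloader_for_spec(
    spec: TrainSpec,
    dataset_name: str,
    dataset_path: Optional[str],
    tokenizer_path: str,
    batch_size: int,
    seq_len: int,
    dp_rank: int,
    dp_world_size: int,
) -> BaseDataLoader:
    """Invoke the spec's dataloader builder the way a trainer would."""
    return spec.build_dataloader_fn(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer_path=tokenizer_path,
        batch_size=batch_size,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
    )
```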