-
Notifications
You must be signed in to change notification settings - Fork 309
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ghstack-source-id: c9999c58aab4c0148e3a9f882e19cdc7eb29b55d Pull Request resolved: #836
- Loading branch information
Showing
9 changed files
with
159 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
# Copyright (c) Meta Platforms, Inc. All Rights Reserved. | ||
|
||
import pickle | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, TypeAlias | ||
|
||
from torch.distributed.checkpoint.stateful import Stateful | ||
from torch.utils.data import IterableDataset | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
|
||
from torchtitan.datasets.tokenizer import Tokenizer | ||
|
||
|
||
class BaseDataLoader(Stateful, ABC): | ||
"""Base class for all dataloaders. | ||
This is used to enforce that all dataloaders have the methods defined in ``Stateful``, | ||
``state_dict()`` and ``load_state_dict()``. | ||
""" | ||
|
||
@abstractmethod | ||
def __iter__(self): | ||
... | ||
|
||
|
||
class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): | ||
"""Dataloader that is aware of distributed data parallelism. | ||
This dataloader is used to load data in a distributed data parallel fashion. It also | ||
utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary | ||
methods such as ``__iter__``. | ||
Args: | ||
dataset (IterableDataset): The dataset to iterate over. | ||
tokenizer (Tokenizer): The tokenizer to use to tokenize the dataset. | ||
dp_rank: Data parallelism rank for this dataloader. | ||
dp_world_size: The world size of the data parallelism. | ||
batch_size: The batch size to use for each iteration. | ||
""" | ||
|
||
tokenizer: Tokenizer | ||
dp_rank: int | ||
dp_world_size: int | ||
batch_size: int | ||
|
||
def __init__( | ||
self, | ||
dataset: IterableDataset, | ||
tokenizer: Tokenizer, | ||
dp_rank: int, | ||
dp_world_size: int, | ||
batch_size: int, | ||
): | ||
self.dp_world_size = dp_world_size | ||
self.dp_rank = dp_rank | ||
self.batch_size = batch_size | ||
self.tokenizer = tokenizer | ||
super().__init__(dataset, batch_size) | ||
self._rank_id = f"dp_rank_{dp_rank}" | ||
|
||
def state_dict(self) -> dict[str, Any]: | ||
# Store state only for dp rank to avoid replicating the same state across other dimensions. | ||
return { | ||
# We don't have to use pickle as DCP will serialize the state_dict. However, | ||
# we have to keep this for backward compatibility. | ||
self._rank_id: pickle.dumps(super().state_dict()), | ||
"world_size": self.dp_world_size, | ||
} | ||
|
||
def load_state_dict(self, state_dict: dict[str, Any]) -> None: | ||
# State being empty is valid. | ||
if not state_dict: | ||
return | ||
|
||
if self._rank_id not in state_dict: | ||
logger.warning( | ||
f"DataLoader state is empty for dp rank {self.dp_rank}, " | ||
"expected key {self._rank_id}" | ||
) | ||
return | ||
|
||
assert self.dp_world_size == state_dict["world_size"], ( | ||
"dp_degree is inconsistent before and after checkpoint, " | ||
"dataloader resharding is not supported yet." | ||
) | ||
# We don't have to use pickle as DCP will serialize the state_dict. However, we have to | ||
# keep this for backward compatibility. | ||
super().load_state_dict(pickle.loads(state_dict[self._rank_id])) | ||
|
||
|
||
DataLoaderBuilder: TypeAlias = Callable[[...], BaseDataLoader] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.