Skip to content

Commit

Permalink
Avoid using skip() in hf_datasets
Browse files Browse the repository at this point in the history
We found out that skip() may not behave as expected. This is a workaround solution while we are investigating the root cause.

ghstack-source-id: 412810bc4a83de39df3fdfe62bc249d80c7deb19
Pull Request resolved: #835
  • Loading branch information
fegin committed Feb 12, 2025
1 parent 189c9f0 commit 7099098
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
from typing import Any, Callable, Dict, List, Optional

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node


def _load_c4_dataset(dataset_path: str):
"""Load C4 dataset with default configuration."""
Expand Down Expand Up @@ -99,13 +99,15 @@ def __init__(
self._all_tokens: List[int] = []

def _get_data_iter(self):
    """Return an iterator over ``self._data``, resumed at ``self._sample_idx``.

    ``datasets`` ``skip()`` was observed to misbehave (see PR #835), so
    resumption is done by manually consuming the first ``self._sample_idx``
    samples from a fresh iterator instead of calling ``skip()``.

    Returns:
        An iterator positioned after the already-consumed samples; an empty
        iterator if a map-style ``Dataset`` has been fully consumed.
    """
    # Map-style dataset already exhausted: nothing left to yield.
    # NOTE(review): len() is only valid for map-style Dataset, hence the
    # isinstance guard; iterable datasets fall through.
    if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
        return iter([])

    it = iter(self._data)
    # Workaround for skip(): advance the iterator we are about to return.
    # (Previously this called `next(self._data.next())`, which never moved
    # `it` forward and relied on a `.next()` method datasets don't have.)
    for _ in range(self._sample_idx):
        next(it)

    return it

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len
Expand Down

0 comments on commit 7099098

Please sign in to comment.