
Commit 7099098

Avoid using skip() in hf_datasets
We found that skip() may not behave as expected. This is a workaround while we investigate the root cause.

ghstack-source-id: 412810b
Pull Request resolved: #835
1 parent 189c9f0 commit 7099098
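For context, the workaround replaces the dataset-level skip() call with a fresh iterator that is advanced manually. Below is a minimal, self-contained sketch of that resume pattern using a plain Python list as a stand-in for the dataset; the names resume_iter and samples are hypothetical and not part of torchtitan:

from typing import Any, Iterable, Iterator

def resume_iter(data: Iterable[Any], sample_idx: int) -> Iterator[Any]:
    # Build a fresh iterator and advance it by hand instead of relying on
    # a dataset-level skip(); this mirrors the workaround in this commit.
    it = iter(data)
    for _ in range(sample_idx):
        next(it)
    return it

samples = ["a", "b", "c", "d"]        # stand-in for dataset samples
print(list(resume_iter(samples, 2)))  # -> ['c', 'd']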

1 file changed: +9 -7 lines
torchtitan/datasets/hf_datasets.py

Lines changed: 9 additions & 7 deletions
@@ -9,16 +9,16 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
+
+from datasets import Dataset, load_dataset
+from datasets.distributed import split_dataset_by_node
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import IterableDataset
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging import logger
 
-from datasets import Dataset, load_dataset
-from datasets.distributed import split_dataset_by_node
-
 
 def _load_c4_dataset(dataset_path: str):
     """Load C4 dataset with default configuration."""
@@ -99,13 +99,15 @@ def __init__(
         self._all_tokens: List[int] = []
 
     def _get_data_iter(self):
-        if self._sample_idx == 0:
-            return iter(self._data)
-
         if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
             return iter([])
 
-        return iter(self._data.skip(self._sample_idx))
+        it = iter(self._data)
+
+        for _ in range(self._sample_idx):
+            next(it)
+
+        return it
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
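As a rough usage sketch of the patched method, the toy DummyLoader class below is hypothetical and keeps only the two attributes _get_data_iter() touches (a list substitutes for the Hugging Face dataset); it shows iteration resuming after _sample_idx consumed samples, and the empty iterator returned once a map-style dataset is exhausted:

from typing import Any, Iterator, List

class DummyLoader:
    # Minimal stand-in holding only the state _get_data_iter() reads.
    def __init__(self, data: List[Any], sample_idx: int = 0) -> None:
        self._data = data
        self._sample_idx = sample_idx

    def _get_data_iter(self) -> Iterator[Any]:
        # Mirrors the patched method: for a fully consumed map-style dataset,
        # return an empty iterator rather than skipping past the end.
        if self._sample_idx == len(self._data):
            return iter([])
        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

print(list(DummyLoader(list(range(5)), sample_idx=3)._get_data_iter()))  # -> [3, 4]
print(list(DummyLoader(list(range(5)), sample_idx=5)._get_data_iter()))  # -> []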
