feat: support infinite dataset and move to logging
Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant committed Sep 5, 2024
1 parent a8d2057 commit 6448afb
Showing 2 changed files with 36 additions and 7 deletions.
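The "move to logging" half of the change replaces the transformers.utils logger with a standard-library, module-level logger in both files. A minimal sketch of the resulting pattern, assuming the application configures logging itself (the basicConfig call below is an assumption and is not part of this diff):

import logging

logger = logging.getLogger(__name__)  # one logger per module, named after the module path

# assumed application-side setup so module-level warnings are formatted and emitted
logging.basicConfig(level=logging.WARNING, format="%(name)s - %(levelname)s - %(message)s")

logger.warning("packing is enabled; avoid packing for fine-tuning use cases.")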
18 changes: 15 additions & 3 deletions tuning/config/configs.py
@@ -15,13 +15,12 @@
# Standard
from dataclasses import dataclass, field
from typing import List, Optional, Union
import logging
import os

# Third Party
from datasets import Dataset, IterableDataset, interleave_datasets
from pyarrow.lib import ArrowInvalid
from tqdm import tqdm
from transformers.utils import logging
import datasets
import torch
import transformers
@@ -42,7 +41,7 @@
DEFAULT_UNK_TOKEN = "<unk>"


logger = logging.get_logger("sft_trainer")
logger = logging.getLogger(__name__)


def _load_data(data_path, split, streaming, config_kwargs):
@@ -366,6 +365,11 @@ def __post_init__(self):
) = load_multi_dataset_with_sampling(
data_config=data_config, column_name_options=column_name_options
)
if self.packing:
logger.warning(
"packing is enabled and strictly avoid using packing for non pretraining use cases \
like fine-tuning to avoid cross contamination."
)
if data_config.data_sampler == "tokens_based":
if not self.packing:
raise ValueError(
@@ -395,6 +399,13 @@ def __post_init__(self):
cache_dir=self.cache_dir,
use_fast=True,
)
if self.max_steps > 0:
logger.warning(
f"dataset will be iterated infinitely until max_steps {self.max_steps} is met."
)
logger.warning(
f"num_train_epochs {self.num_train_epochs} is ignored by the trainer"
)
self.train_dataset = ConstantLengthHybridDataset(
train_datasets,
train_probs,
@@ -405,6 +416,7 @@ def __post_init__(self):
self.dataset_text_field,
self.add_bos_token,
self.add_eos_token,
self.max_steps > 0,
)
if validation_datasets:
self.validation_dataset = ConstantLengthHybridDataset(
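In configs.py the new behaviour keys off max_steps: when it is positive, the ConstantLengthHybridDataset is built with its infinite flag enabled, so the stream never ends on its own and training is bounded only by max_steps (num_train_epochs is ignored). A minimal, trainer-independent sketch of that contract, with illustrative names and values:

from itertools import islice

def infinite_stream(dataset):
    # cycle over the dataset forever; the consumer decides when to stop
    while True:
        yield from dataset

max_steps = 1000  # illustrative value
small_dataset = range(10)  # stands in for a finite tokenized dataset

for step, batch in enumerate(islice(infinite_stream(small_dataset), max_steps)):
    pass  # a train_step(batch) call would go here; epochs are never counted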
25 changes: 21 additions & 4 deletions tuning/utils/data_loaders.py
@@ -1,15 +1,15 @@
# Standard
from typing import List, Union
import logging

# Third Party
from datasets import Dataset
from datasets import IterableDataset as HFIterableDataset
from datasets import interleave_datasets
from torch.utils.data import IterableDataset
from transformers.utils import logging
import torch

logger = logging.get_logger("transformers")
logger = logging.getLogger(__name__)


class ConstantLengthHybridDataset(
@@ -26,6 +26,7 @@ def __init__( # pylint: disable=super-init-not-called
text_field="contents",
add_bos_token=True,
add_eos_token=True,
infinite=False,
):
"""packing for pretokenized datasets for pretraining only
since all tokens are attended upon packing.
@@ -47,6 +48,8 @@ def __init__( # pylint: disable=super-init-not-called
Defaults to True.
add_eos_token (bool, optional): add eos token at the end of each sample.
Defaults to True.
infinite (`bool`, *optional*, defaults to `False`):
If True, the iterator is reset when a dataset is exhausted; otherwise iteration stops.
"""
self.datasets = datasets
self.sampling_probs = sampling_probs
@@ -62,6 +65,12 @@ def __init__( # pylint: disable=super-init-not-called
self.add_eos_token = add_eos_token
self.dataset = interleave_datasets(datasets=self.datasets, split="train")
self.column_names = self.dataset.column_names
self.infinite = infinite
if self.infinite:
logger.warning(
"samples will be provided infinitely.\
Datasets that are exhausted will be reiterated from start."
)
# self._info = self.dataset._info
# self._epoch = 0
logger.warning("add_bos_token: {}".format(self.add_bos_token))
@@ -125,8 +134,16 @@ def __iter__(self):
)
buffer_len = len(buffer)
except StopIteration:
more_examples = False
break
if self.infinite:
iterators[dataset_id_which_needs_more_tokens] = iter(
self.datasets[dataset_id_which_needs_more_tokens]
)
logger.warning(
"iterator is reset for one of the datasets since it is exhausted."
)
else:
more_examples = False
break
all_token_ids = buffer
examples = []
for i in range(0, len(all_token_ids), self.seq_length):
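The heart of the feature is the StopIteration handling in __iter__: when one of the per-dataset iterators runs dry and infinite is set, it is rebuilt from the underlying dataset and iteration continues; otherwise the generator stops producing examples. A condensed sketch of that reset-on-exhaustion pattern, stripped of the buffering and sampling logic (the helper name is illustrative):

def next_sample_or_reset(iterators, datasets, idx, infinite):
    # Return (sample, keep_going); mirrors the except StopIteration branch above.
    try:
        return next(iterators[idx]), True
    except StopIteration:
        if infinite:
            iterators[idx] = iter(datasets[idx])  # restart the exhausted dataset
            return next(iterators[idx]), True
        return None, False  # corresponds to more_examples = False in the real loop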
