From d19dae5d6ec034cb4879799dd5dc6e44df6de76e Mon Sep 17 00:00:00 2001
From: Tyler Murray
Date: Mon, 3 Feb 2025 15:37:37 -0800
Subject: [PATCH] Adds test coverage for NumpyFSLDatasetMixture

---
 src/olmo_core/data/__init__.py      |  2 +
 src/olmo_core/data/numpy_dataset.py |  1 +
 src/test/data/data_loader_test.py   | 65 +++++++++++++++++++++++++++++
 src/test/data/fixtures.py           | 34 +++++++++------
 4 files changed, 89 insertions(+), 13 deletions(-)

diff --git a/src/olmo_core/data/__init__.py b/src/olmo_core/data/__init__.py
index b710100e1..ce06c9983 100644
--- a/src/olmo_core/data/__init__.py
+++ b/src/olmo_core/data/__init__.py
@@ -25,6 +25,7 @@
     NumpyDatasetBase,
     NumpyDatasetConfig,
     NumpyFSLDataset,
+    NumpyFSLDatasetMixture,
     NumpyPaddedFSLDataset,
     NumpyVSLDataset,
     VSLCurriculum,
@@ -41,6 +42,7 @@
 __all__ = [
     "NumpyDatasetBase",
     "NumpyFSLDataset",
+    "NumpyFSLDatasetMixture",
     "NumpyPaddedFSLDataset",
     "NumpyVSLDataset",
     "VSLCurriculum",
diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py
index 52e2fc854..cc7794356 100644
--- a/src/olmo_core/data/numpy_dataset.py
+++ b/src/olmo_core/data/numpy_dataset.py
@@ -55,6 +55,7 @@
 __all__ = [
     "NumpyDatasetBase",
     "NumpyFSLDataset",
+    "NumpyFSLDatasetMixture",
     "NumpyPaddedFSLDataset",
     "VSLCurriculum",
     "VSLNaturalCurriculum",
diff --git a/src/test/data/data_loader_test.py b/src/test/data/data_loader_test.py
index 3858b6055..3cd08d51f 100644
--- a/src/test/data/data_loader_test.py
+++ b/src/test/data/data_loader_test.py
@@ -15,6 +15,8 @@
     VSLNaturalCurriculum,
 )
 
+from .fixtures import get_fsl_mixture
+
 
 @pytest.mark.parametrize(
     "num_tokens, sequence_length, world_size, num_workers, num_threads, batch_size",
@@ -82,6 +84,69 @@ def get_all_batches() -> List[List[int]]:
     assert set(all_tokens) == set(range(len(all_tokens)))
 
 
+@pytest.mark.parametrize(
+    "num_tokens, sequence_length, world_size, num_workers, num_threads, batch_size",
+    [
+        (100, 4, 2, 2, 2, 8),  # 2 instances per batch, 12 batches total
+    ],
+)
+def test_fsl_mixture_data_loader(
+    tmp_path: Path,
+    num_tokens: int,
+    sequence_length: int,
+    world_size: int,
+    num_workers: int,
+    num_threads: int,
+    batch_size: int,  # in tokens
+):
+    assert batch_size % sequence_length == 0
+    assert batch_size % world_size == 0
+    rank_batch_size = batch_size // world_size
+    assert rank_batch_size > 0
+    num_batches = num_tokens // batch_size
+
+    def get_all_batches() -> List[List[int]]:
+        all_batches: List[List[int]] = [[] for _ in range(num_batches)]
+        dataset = get_fsl_mixture(
+            tmp_path,
+            vocab_size=32_000,
+            pad_token_id=-1,
+            dtype=np.uint16,
+            sequence_length=sequence_length,
+            num_tokens=num_tokens,
+            eos=0,
+        )
+        dataset.prepare()
+        for rank in range(world_size):
+            data_loader = NumpyFSLDataLoader(
+                dataset,
+                global_batch_size=batch_size,
+                collator=DataCollator(pad_token_id=-1),
+                shuffle=True,
+                num_threads=num_threads,
+                work_dir=tmp_path,
+                dp_rank=rank,
+                dp_world_size=world_size,
+                num_workers=num_workers,
+            )
+            data_loader.reshuffle(epoch=1)
+            batches = list(data_loader)
+            assert len(batches) == num_batches
+            for i, batch in enumerate(batches):
+                for instance in batch["input_ids"]:
+                    all_batches[i].extend(instance.tolist())
+        return all_batches
+
+    all_batches = get_all_batches()
+    all_tokens = []
+    assert len(all_batches) == num_batches
+    for batch in all_batches:
+        assert len(batch) == batch_size
+        all_tokens.extend(batch)
+
+    assert len(all_tokens) == num_batches * batch_size
+
+
 @pytest.mark.parametrize(
     "num_tokens, sequence_length, num_workers, batch_size",
     [
diff --git a/src/test/data/fixtures.py b/src/test/data/fixtures.py
index 4fcaa84d7..8fd46118f 100644
--- a/src/test/data/fixtures.py
+++ b/src/test/data/fixtures.py
@@ -3,14 +3,14 @@
 
 import numpy as np
 
-from olmo_core.data import NumpyDatasetBase, NumpyDatasetConfig, TokenizerConfig
+from olmo_core.data import NumpyFSLDatasetMixture, TokenizerConfig
 from olmo_core.data.source_mixture import (
     SourceMixtureConfig,
     SourceMixtureDatasetConfig,
 )
 from olmo_core.data.types import NumpyDatasetDType
 
-from ..utils import mk_mmaps
+from ..utils import mk_mmaps, Mmaps
 
 
 def get_fsl_mixture(
@@ -20,7 +20,9 @@ def get_fsl_mixture(
     sequence_length: int = 4,
     num_tokens: int = 20 * 1000,
     eos: int = 0,
-) -> NumpyDatasetBase:
+    vocab_size: int = 32_000,
+    pad_token_id: int = -1,
+) -> NumpyFSLDatasetMixture:
     seed = 42
     mmap1 = mk_mmaps(
         tmp_path, "mmap1", 1, num_tokens * 2, dtype, eos=eos, seed=seed, seq_length=sequence_length
@@ -30,12 +32,13 @@ def get_fsl_mixture(
     )
 
     tokenizer = TokenizerConfig(
-        vocab_size=32_000,
+        vocab_size=vocab_size,
         eos_token_id=eos,
-        pad_token_id=-1,
+        pad_token_id=pad_token_id,
     )
 
     mixture_config = SourceMixtureDatasetConfig(
+        render_tables=False,
         max_tokens=num_tokens,
         sequence_length=sequence_length,
         source_configs=[
@@ -55,12 +58,17 @@ def get_fsl_mixture(
         seed=seed,
     )
 
-    ds = NumpyDatasetConfig(
-        source_mixture_config=mixture_config,
+    mixture = mixture_config.build()
+    return NumpyFSLDatasetMixture(
+        *mixture.to_paths(),
+        seed=mixture.seed,
         sequence_length=sequence_length,
-        tokenizer=tokenizer,
-        include_instance_metadata=False,
-    ).build()
-    ds.prepare()
-
-    return ds
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        vocab_size=tokenizer.vocab_size,
+        dtype=dtype,
+        metadata=None,
+        include_instance_metadata=None,
+        generate_doc_lengths=False,
+        path_offset_index=mixture.to_index(),
+    )
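
Note for reviewers: the parametrization of test_fsl_mixture_data_loader is easy to sanity-check by hand. The following standalone sketch (plain Python, not part of the patch; parameter names mirror the test, the derived names are only illustrative) works through the token/batch arithmetic that the test's assertions rely on:

    # Mirrors the arithmetic in test_fsl_mixture_data_loader (illustrative only).
    num_tokens = 100      # total tokens drawn from the source mixture
    sequence_length = 4   # tokens per fixed-length instance
    world_size = 2        # data-parallel ranks
    batch_size = 8        # global batch size, in tokens

    assert batch_size % sequence_length == 0  # batches hold whole instances
    assert batch_size % world_size == 0       # batches split evenly across ranks

    instances_per_batch = batch_size // sequence_length  # 2
    rank_batch_size = batch_size // world_size           # 4 tokens per rank per batch
    num_batches = num_tokens // batch_size               # 12; the trailing 4 tokens are unused

    print(instances_per_batch, rank_batch_size, num_batches)  # -> 2 4 12

Since each rank iterates all num_batches batches and contributes rank_batch_size tokens per batch, each merged per-batch list must hold exactly batch_size tokens, which is what the final assertions verify. Unlike the plain FSL test above it, this test checks token counts rather than exact token identity, presumably because the mixture is assembled from the two seeded mk_mmaps source files rather than a contiguous token range.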