WIP: Adds test coverage for NumpyFSLDatasetMixture #155

Open · wants to merge 1 commit into base: main

src/olmo_core/data/__init__.py (2 additions, 0 deletions)

@@ -25,6 +25,7 @@
     NumpyDatasetBase,
     NumpyDatasetConfig,
     NumpyFSLDataset,
+    NumpyFSLDatasetMixture,
     NumpyPaddedFSLDataset,
     NumpyVSLDataset,
     VSLCurriculum,
@@ -41,6 +42,7 @@
 __all__ = [
     "NumpyDatasetBase",
     "NumpyFSLDataset",
+    "NumpyFSLDatasetMixture",
     "NumpyPaddedFSLDataset",
     "NumpyVSLDataset",
     "VSLCurriculum",
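
With this change, the mixture class is re-exported from the package root alongside the other dataset types. A one-line import check (a trivial sketch, assuming an installed olmo-core checkout):

```python
# NumpyFSLDatasetMixture is now importable from olmo_core.data directly.
from olmo_core.data import NumpyFSLDatasetMixture
```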

src/olmo_core/data/numpy_dataset.py (1 addition, 0 deletions)

@@ -55,6 +55,7 @@
 __all__ = [
     "NumpyDatasetBase",
     "NumpyFSLDataset",
+    "NumpyFSLDatasetMixture",
     "NumpyPaddedFSLDataset",
     "VSLCurriculum",
     "VSLNaturalCurriculum",

src/test/data/data_loader_test.py (65 additions, 0 deletions)

@@ -15,6 +15,8 @@
     VSLNaturalCurriculum,
 )
 
+from .fixtures import get_fsl_mixture
+
 
 @pytest.mark.parametrize(
     "num_tokens, sequence_length, world_size, num_workers, num_threads, batch_size",
@@ -82,6 +84,69 @@ def get_all_batches() -> List[List[int]]:
     assert set(all_tokens) == set(range(len(all_tokens)))
 
 
+@pytest.mark.parametrize(
+    "num_tokens, sequence_length, world_size, num_workers, num_threads, batch_size",
+    [
+        (100, 4, 2, 2, 2, 8),  # 2 instances per batch, 12 batches total
+    ],
+)
+def test_fsl_mixture_data_loader(
+    tmp_path: Path,
+    num_tokens: int,
+    sequence_length: int,
+    world_size: int,
+    num_workers: int,
+    num_threads: int,
+    batch_size: int,  # in tokens
+):
+    assert batch_size % sequence_length == 0
+    assert batch_size % world_size == 0
+    rank_batch_size = batch_size // world_size
+    assert rank_batch_size > 0
+    num_batches = num_tokens // batch_size
+
+    def get_all_batches() -> List[List[int]]:
+        all_batches: List[List[int]] = [[] for _ in range(num_batches)]
+        dataset = get_fsl_mixture(
+            tmp_path,
+            vocab_size=32_000,
+            pad_token_id=-1,
+            dtype=np.uint16,
+            sequence_length=sequence_length,
+            num_tokens=num_tokens,
+            eos=0,
+        )
+        dataset.prepare()
+        for rank in range(world_size):
+            data_loader = NumpyFSLDataLoader(
+                dataset,
+                global_batch_size=batch_size,
+                collator=DataCollator(pad_token_id=-1),
+                shuffle=True,
+                num_threads=num_threads,
+                work_dir=tmp_path,
+                dp_rank=rank,
+                dp_world_size=world_size,
+                num_workers=num_workers,
+            )
+            data_loader.reshuffle(epoch=1)
+            batches = list(data_loader)
+            assert len(batches) == num_batches
+            for i, batch in enumerate(batches):
+                for instance in batch["input_ids"]:
+                    all_batches[i].extend(instance.tolist())
+        return all_batches
+
+    all_batches = get_all_batches()
+    all_tokens = []
+    assert len(all_batches) == num_batches
+    for batch in all_batches:
+        assert len(batch) == batch_size
+        all_tokens.extend(batch)
+
+    assert len(all_tokens) == num_batches * batch_size
+
+
 @pytest.mark.parametrize(
     "num_tokens, sequence_length, num_workers, batch_size",
     [
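
For reference, the sizing arithmetic behind the parametrized case above, as a worked check of the test's assertions (not part of the PR itself):

```python
# Parameters from the single parametrized case: (100, 4, 2, 2, 2, 8).
num_tokens, sequence_length, world_size, batch_size = 100, 4, 2, 8

instances_per_batch = batch_size // sequence_length  # 8 tokens / 4 per sequence = 2 instances
rank_batch_size = batch_size // world_size           # 4 tokens per rank per global batch
num_batches = num_tokens // batch_size               # 12 full batches; the trailing 4 tokens are dropped

assert instances_per_batch == 2
assert rank_batch_size == 4
assert num_batches == 12
assert num_batches * batch_size == 96  # tokens the test expects to collect in total
```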

src/test/data/fixtures.py (21 additions, 13 deletions)

@@ -3,14 +3,14 @@
 
 import numpy as np
 
-from olmo_core.data import NumpyDatasetBase, NumpyDatasetConfig, TokenizerConfig
+from olmo_core.data import NumpyFSLDatasetMixture, TokenizerConfig
 from olmo_core.data.source_mixture import (
     SourceMixtureConfig,
     SourceMixtureDatasetConfig,
 )
 from olmo_core.data.types import NumpyDatasetDType
 
-from ..utils import mk_mmaps
+from ..utils import mk_mmaps, Mmaps
 
 
 def get_fsl_mixture(
@@ -20,7 +20,9 @@ def get_fsl_mixture(
     sequence_length: int = 4,
     num_tokens: int = 20 * 1000,
     eos: int = 0,
-) -> NumpyDatasetBase:
+    vocab_size: int = 32_000,
+    pad_token_id: int = -1,
+) -> NumpyFSLDatasetMixture:
     seed = 42
     mmap1 = mk_mmaps(
         tmp_path, "mmap1", 1, num_tokens * 2, dtype, eos=eos, seed=seed, seq_length=sequence_length
@@ -30,12 +32,13 @@
     )
 
     tokenizer = TokenizerConfig(
-        vocab_size=32_000,
+        vocab_size=vocab_size,
         eos_token_id=eos,
-        pad_token_id=-1,
+        pad_token_id=pad_token_id,
     )
 
     mixture_config = SourceMixtureDatasetConfig(
+        render_tables=False,
         max_tokens=num_tokens,
         sequence_length=sequence_length,
         source_configs=[
@@ -55,12 +58,17 @@
         seed=seed,
     )
 
-    ds = NumpyDatasetConfig(
-        source_mixture_config=mixture_config,
+    mixture = mixture_config.build()
+    return NumpyFSLDatasetMixture(
+        *mixture.to_paths(),
+        seed=mixture.seed,
         sequence_length=sequence_length,
-        tokenizer=tokenizer,
-        include_instance_metadata=False,
-    ).build()
-    ds.prepare()
-
-    return ds
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        vocab_size=tokenizer.vocab_size,
+        dtype=dtype,
+        metadata=None,
+        include_instance_metadata=None,
+        generate_doc_lengths=False,
+        path_offset_index=mixture.to_index(),
+    )
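
Note the behavioral change in the fixture: it previously called `ds.prepare()` before returning, but it now returns an unprepared `NumpyFSLDatasetMixture`, so callers must prepare it themselves, as the new data loader test does. A minimal usage sketch (the test name is illustrative; `tmp_path` is pytest's built-in fixture):

```python
import numpy as np

from .fixtures import get_fsl_mixture


def test_fsl_mixture_fixture_smoke(tmp_path):
    dataset = get_fsl_mixture(
        tmp_path,
        vocab_size=32_000,
        pad_token_id=-1,
        dtype=np.uint16,
        sequence_length=4,
        num_tokens=100,
        eos=0,
    )
    # The fixture no longer prepares the dataset; do it explicitly before use.
    dataset.prepare()
```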