Skip to content

Commit

Permalink
/data/ folder.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597024032
  • Loading branch information
xingyousong authored and copybara-github committed Jan 9, 2024
1 parent 03edb54 commit c0bf84a
Show file tree
Hide file tree
Showing 41 changed files with 1,804 additions and 29 deletions.
17 changes: 17 additions & 0 deletions optformer/common/data/augmenters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Public imports for augmenters."""

from optformer.common.data.augmenters.base import Augmenter
39 changes: 39 additions & 0 deletions optformer/common/data/augmenters/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for augmenting objects during training."""

import abc
from typing import Generic, TypeVar

T = TypeVar('T')


class Augmenter(Generic[T], abc.ABC):
"""Base data augmenter class."""

@abc.abstractmethod
def augment(self, obj: T, /) -> T:
"""Augments the object.
For efficiency, ideally the object should be augmented in-place. Copying
should be used as a last resort.
Args:
obj: Object to be augmented.
Returns:
Augmented object. Could be a reference to the original input object if
modifying in-place.
"""
25 changes: 25 additions & 0 deletions optformer/common/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All import DatasetFns."""

from optformer.common.data.datasets.base import DatasetFn
from optformer.common.data.datasets.distributed import DistributedDatasetFn
from optformer.common.data.datasets.distributed import DistributedSeqioDatasetFn
from optformer.common.data.datasets.featurized import FeaturizedDatasetFn
from optformer.common.data.datasets.generator import GeneratorDatasetFn
from optformer.common.data.datasets.inference import SeqIOInferenceDatasetFn
from optformer.common.data.datasets.inference import T5XInferenceDatasetFn
from optformer.common.data.datasets.shuffling import ShuffleDatasetFn
from optformer.common.data.datasets.wrappers import SeqioDatasetFnFunctor
27 changes: 27 additions & 0 deletions optformer/common/data/datasets/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base abstractions for dataset functions."""
from typing import Protocol, TypeVar

import tensorflow as tf

_S = TypeVar('_S')


class DatasetFn(Protocol[_S]):
"""Base Dataset Function class."""

def __call__(self, source: _S) -> tf.data.Dataset:
"""Transforms a source to a TF Dataset."""
119 changes: 119 additions & 0 deletions optformer/common/data/datasets/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Distributed dataset for scaling data-loading to many CPUs."""

from typing import Callable, Optional

from absl import flags
import attrs
from optformer.common.data.datasets import base
import reverb
import seqio
import tensorflow as tf
import tree

# This flag's value will be passed from downstream binaries.
REVERB_ADDRESS = flags.DEFINE_string(
'reverb_address',
None,
'Address of Reverb server, of form `host:port`.',
)

DISABLE_REVERB = flags.DEFINE_bool(
'disable_reverb',
False,
'If true disables distributed Reverb logic (makes wrapper no-op).',
)


@attrs.define
class DistributedDatasetFn(base.DatasetFn[tf.data.Dataset]):
"""Creates a distributed version of a dataset using the Reverb API.
The distributed dataset will request data from a specific table of a reverb
server which collects units of data from multiple clients rather than a file
source, but still needs to know the incoming dtypes/shapes from a template.
"""

table_name: str = attrs.field(init=True)

num_workers_per_iterator: int = attrs.field(default=1, kw_only=True)
max_samples_per_stream: int = attrs.field(default=120, kw_only=True)
max_in_flight_samples_per_worker: int = attrs.field(default=2, kw_only=True)
prefetch: Optional[int] = attrs.field(default=8, kw_only=True)

def __call__(self, source: tf.data.Dataset) -> reverb.TimestepDataset:
"""Creates the distributed Reverb dataset as a server.
Args:
source: Single-process dataset version, used **only as a template** to
obtain dtype/shape information.
Returns:
Reverb server dataset.
"""

template_ds = source

if REVERB_ADDRESS.value is None:
raise ValueError('`reverb_address` flag is still unset!')

ds = reverb.TimestepDataset(
server_address=REVERB_ADDRESS.value,
table=self.table_name,
dtypes=tree.map_structure(lambda x: x.dtype, template_ds.element_spec),
shapes=tree.map_structure(lambda x: x.shape, template_ds.element_spec),
num_workers_per_iterator=self.num_workers_per_iterator,
max_samples_per_stream=self.max_samples_per_stream,
max_in_flight_samples_per_worker=self.max_in_flight_samples_per_worker,
)

# Change output from default `ReplaySample` struct to actual data.
ds = ds.map(lambda rs: rs.data, num_parallel_calls=tf.data.AUTOTUNE)

if self.prefetch is not None:
ds = ds.prefetch(buffer_size=self.prefetch)

options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = True
options.experimental_optimization.map_and_batch_fusion = True
options.experimental_optimization.parallel_batch = True
return ds.with_options(options)


@attrs.define
class DistributedSeqioDatasetFn(seqio.DatasetFnCallable):
"""Creates a distributed dataset whose table name is according to split."""

seqio_dataset_fn: seqio.DatasetFnCallable = attrs.field(init=True)

distributed_dataset_fn_factory: Callable[[str], DistributedDatasetFn] = (
attrs.field(default=DistributedDatasetFn)
)

def __call__(
self, split: str, shuffle_files: bool, seed: Optional[int] = None
) -> tf.data.Dataset:
original_dataset = self.seqio_dataset_fn(split, shuffle_files, seed)
if DISABLE_REVERB.value:
return original_dataset

distributed_dataset_fn = self.distributed_dataset_fn_factory(split)
return distributed_dataset_fn(original_dataset)

# TODO: Previously fixed compatibility w/ SeqIO. Is it still needed?
@property
def __name__(self):
return 'DistributedSeqioDatasetFn'
70 changes: 70 additions & 0 deletions optformer/common/data/datasets/distributed_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import flagsaver
from optformer.common.data import vocabs
from optformer.common.data.datasets import distributed
import reverb
import tensorflow as tf
from absl.testing import absltest


class DistributedDatasetFnTest(absltest.TestCase):

def setUp(self):
super().setUp()
# Setup original dataset.
self.vocab = vocabs.AsciiVocab()
self.raw_data = [{"inputs": "hi", "targets": "bye"}]
self.original_dataset = tf.data.Dataset.from_generator(
lambda: self.raw_data,
output_types={"inputs": tf.string, "targets": tf.string},
output_shapes={"inputs": [], "targets": []},
)

# Setup distributed components.
# Server on separate process/machine.

self.table_name = "test_table"
self.server = reverb.Server(
tables=[
reverb.Table(
name=self.table_name,
sampler=reverb.selectors.Fifo(),
remover=reverb.selectors.Fifo(),
max_size=100000,
rate_limiter=reverb.rate_limiters.MinSize(1),
max_times_sampled=1,
),
]
)
self.server_address = f"localhost:{self.server.port}"

# Separate process/machine.
self.client = reverb.Client(self.server_address)

def test_client_server_interaction(self):
# Separate process/machine (ideally same as model training process)
with flagsaver.flagsaver(reverb_address=self.server_address):
dataset_fn = distributed.DistributedDatasetFn(self.table_name)
self.distributed_dataset = dataset_fn(self.original_dataset)

data = next(self.original_dataset.as_numpy_iterator())
self.client.insert(data, {self.table_name: 1.0})
distributed_data = next(self.distributed_dataset.as_numpy_iterator())
self.assertEqual(data, distributed_data)


if __name__ == "__main__":
absltest.main()
81 changes: 81 additions & 0 deletions optformer/common/data/datasets/featurized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for applying featurizers over datasets."""

from typing import Sequence

from absl import logging
import attrs
from optformer.common.data import featurizers
from optformer.common.data.datasets import base
import tensorflow as tf


@attrs.define
class FeaturizedDatasetFn(base.DatasetFn[tf.data.Dataset]):
"""Featurizes a dataset."""

featurizer: featurizers.Featurizer

def __call__(self, source: tf.data.Dataset) -> tf.data.Dataset:
"""Returns dataset processed via a Featurizer.
Args:
source: Dataset whose unit of data is a valid input to the featurizer.
Returns:
Immediate TF dataset, after applying the Featurizer and filtering out
empty features. Each unit of data will be a `Dict[str, tf.Tensor]`.
"""
ds = source

# Apply the featurizer.
def featurize_fn(s) -> Sequence[tf.Tensor]:
# `tf.numpy_function` requires output type as tuples, not dicts. Shapes
# must also be consistent across all return statements.
#
# First output tuple element indicates if featurizer was successful.
try:
return (True, *self.featurizer.to_features(s).values()) # pytype: disable=bad-return-type # py311-upgrade
except Exception as e: # pylint:disable=broad-exception-caught
logging.exception('Failed to featurize: %s', e)
return (False, *self.featurizer.empty_output.values()) # pytype: disable=bad-return-type # py311-upgrade

t_out = (tf.bool, *self.featurizer.output_types.values())

ds = ds.map(
lambda s: tf.numpy_function(featurize_fn, [s], t_out),
num_parallel_calls=tf.data.AUTOTUNE,
)

# Filter failed results.
ds = ds.filter(lambda success, *_: success)

# NOTE: Downstream tokenization requires inputs w/ known shapes.
def set_shapes(values: Sequence[tf.Tensor]) -> Sequence[tf.Tensor]:
for v in values:
v.set_shape(())
return values

# Drop success boolean, and re-provide shape on each value.
ds = ds.map(
lambda _, *v: set_shapes(v),
num_parallel_calls=tf.data.AUTOTUNE,
)
# Reconstruct the dict from tuple.
return ds.map(
lambda *v: dict(zip(self.featurizer.output_types.keys(), v)),
num_parallel_calls=tf.data.AUTOTUNE,
)
Loading

0 comments on commit c0bf84a

Please sign in to comment.