-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
41 changed files
with
1,804 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Public imports for augmenters.""" | ||
|
||
from optformer.common.data.augmenters.base import Augmenter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Base class for augmenting objects during training.""" | ||
|
||
import abc | ||
from typing import Generic, TypeVar | ||
|
||
T = TypeVar('T')


class Augmenter(Generic[T], abc.ABC):
  """Abstract interface for objects that augment training data of type `T`."""

  @abc.abstractmethod
  def augment(self, obj: T, /) -> T:
    """Returns an augmented version of `obj`.

    Implementations should prefer mutating `obj` in-place for efficiency and
    only fall back to copying when in-place modification is not possible.

    Args:
      obj: Object to be augmented (positional-only).

    Returns:
      The augmented object. When the augmentation happens in-place, this may be
      the very same object that was passed in.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Public imports for DatasetFns.""" | ||
|
||
from optformer.common.data.datasets.base import DatasetFn | ||
from optformer.common.data.datasets.distributed import DistributedDatasetFn | ||
from optformer.common.data.datasets.distributed import DistributedSeqioDatasetFn | ||
from optformer.common.data.datasets.featurized import FeaturizedDatasetFn | ||
from optformer.common.data.datasets.generator import GeneratorDatasetFn | ||
from optformer.common.data.datasets.inference import SeqIOInferenceDatasetFn | ||
from optformer.common.data.datasets.inference import T5XInferenceDatasetFn | ||
from optformer.common.data.datasets.shuffling import ShuffleDatasetFn | ||
from optformer.common.data.datasets.wrappers import SeqioDatasetFnFunctor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Base abstractions for dataset functions.""" | ||
from typing import Protocol, TypeVar | ||
|
||
import tensorflow as tf | ||
|
||
_S = TypeVar('_S')


class DatasetFn(Protocol[_S]):
  """Structural type for callables that build a TF dataset from a source.

  The type parameter is the type of the `source` argument accepted by
  `__call__`.
  """

  def __call__(self, source: _S) -> tf.data.Dataset:
    """Converts `source` into a `tf.data.Dataset`.

    Args:
      source: Object the dataset is built from.

    Returns:
      The resulting TF dataset.
    """
    ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Distributed dataset for scaling data-loading to many CPUs.""" | ||
|
||
from typing import Callable, Optional | ||
|
||
from absl import flags | ||
import attrs | ||
from optformer.common.data.datasets import base | ||
import reverb | ||
import seqio | ||
import tensorflow as tf | ||
import tree | ||
|
||
# Address of the Reverb server to sample from. The default is None; downstream
# binaries are expected to supply the actual value.
REVERB_ADDRESS = flags.DEFINE_string(
    name='reverb_address',
    default=None,
    help='Address of Reverb server, of form `host:port`.',
)

# Escape hatch: when set, the SeqIO wrapper in this module returns the original
# (non-distributed) dataset unchanged.
DISABLE_REVERB = flags.DEFINE_bool(
    name='disable_reverb',
    default=False,
    help='If true disables distributed Reverb logic (makes wrapper no-op).',
)
|
||
|
||
@attrs.define
class DistributedDatasetFn(base.DatasetFn[tf.data.Dataset]):
  """Creates a distributed version of a dataset using the Reverb API.

  The distributed dataset will request data from a specific table of a reverb
  server which collects units of data from multiple clients rather than a file
  source, but still needs to know the incoming dtypes/shapes from a template.

  Attributes:
    table_name: Name of the Reverb table to sample from.
    num_workers_per_iterator: Forwarded to `reverb.TimestepDataset`.
    max_samples_per_stream: Forwarded to `reverb.TimestepDataset`.
    max_in_flight_samples_per_worker: Forwarded to `reverb.TimestepDataset`.
    prefetch: Buffer size passed to `Dataset.prefetch`. If None, no prefetching
      is applied.
  """

  table_name: str = attrs.field(init=True)

  num_workers_per_iterator: int = attrs.field(default=1, kw_only=True)
  max_samples_per_stream: int = attrs.field(default=120, kw_only=True)
  max_in_flight_samples_per_worker: int = attrs.field(default=2, kw_only=True)
  prefetch: Optional[int] = attrs.field(default=8, kw_only=True)

  def __call__(self, source: tf.data.Dataset) -> tf.data.Dataset:
    """Creates the dataset which samples from the Reverb server.

    Args:
      source: Single-process dataset version, used **only as a template** to
        obtain dtype/shape information. It is never iterated here.

    Returns:
      Dataset yielding the data payload of each sample drawn from the Reverb
      table. Since the Reverb dataset is post-processed with `map` (and
      optionally `prefetch`), the result is a generic `tf.data.Dataset`, not a
      `reverb.TimestepDataset`.

    Raises:
      ValueError: If the `reverb_address` flag has not been set.
    """
    server_address = REVERB_ADDRESS.value
    if server_address is None:
      raise ValueError('`reverb_address` flag is still unset!')

    # Only the element signature of the template dataset is needed.
    element_spec = source.element_spec

    ds = reverb.TimestepDataset(
        server_address=server_address,
        table=self.table_name,
        dtypes=tree.map_structure(lambda x: x.dtype, element_spec),
        shapes=tree.map_structure(lambda x: x.shape, element_spec),
        num_workers_per_iterator=self.num_workers_per_iterator,
        max_samples_per_stream=self.max_samples_per_stream,
        max_in_flight_samples_per_worker=self.max_in_flight_samples_per_worker,
    )

    # Change output from default `ReplaySample` struct to actual data.
    ds = ds.map(lambda rs: rs.data, num_parallel_calls=tf.data.AUTOTUNE)

    if self.prefetch is not None:
      ds = ds.prefetch(buffer_size=self.prefetch)

    options = tf.data.Options()
    options.experimental_optimization.apply_default_optimizations = True
    options.experimental_optimization.map_and_batch_fusion = True
    options.experimental_optimization.parallel_batch = True
    return ds.with_options(options)
|
||
|
||
@attrs.define
class DistributedSeqioDatasetFn(seqio.DatasetFnCallable):
  """SeqIO dataset function that reads from a Reverb table named by split.

  The wrapped SeqIO dataset function is always invoked first. Its result is
  either returned directly (when the `disable_reverb` flag is set) or handed to
  a `DistributedDatasetFn` built from the split name.
  """

  # SeqIO dataset function being wrapped.
  seqio_dataset_fn: seqio.DatasetFnCallable = attrs.field()

  # Called with the split name to build the distributed dataset function.
  distributed_dataset_fn_factory: Callable[[str], DistributedDatasetFn] = (
      attrs.field(default=DistributedDatasetFn)
  )

  def __call__(
      self, split: str, shuffle_files: bool, seed: Optional[int] = None
  ) -> tf.data.Dataset:
    """Builds the dataset for `split`.

    Args:
      split: Split name; forwarded to the wrapped function and also passed to
        the distributed dataset function factory.
      shuffle_files: Forwarded to the wrapped SeqIO dataset function.
      seed: Forwarded to the wrapped SeqIO dataset function.

    Returns:
      The wrapped function's dataset if the `disable_reverb` flag is set,
      otherwise the output of the distributed dataset function.
    """
    local_dataset = self.seqio_dataset_fn(split, shuffle_files, seed)
    if not DISABLE_REVERB.value:
      to_distributed = self.distributed_dataset_fn_factory(split)
      return to_distributed(local_dataset)
    return local_dataset

  # TODO: Previously fixed compatibility w/ SeqIO. Is it still needed?
  @property
  def __name__(self):
    return 'DistributedSeqioDatasetFn'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from absl.testing import flagsaver | ||
from optformer.common.data import vocabs | ||
from optformer.common.data.datasets import distributed | ||
import reverb | ||
import tensorflow as tf | ||
from absl.testing import absltest | ||
|
||
|
||
class DistributedDatasetFnTest(absltest.TestCase):
  """Round-trips one example through an in-process Reverb server."""

  def setUp(self):
    super().setUp()
    # Setup original dataset.
    self.vocab = vocabs.AsciiVocab()
    self.raw_data = [{"inputs": "hi", "targets": "bye"}]
    self.original_dataset = tf.data.Dataset.from_generator(
        lambda: self.raw_data,
        output_types={"inputs": tf.string, "targets": tf.string},
        output_shapes={"inputs": [], "targets": []},
    )

    # Setup distributed components.
    # Server on separate process/machine.
    self.table_name = "test_table"
    self.server = reverb.Server(
        tables=[
            reverb.Table(
                name=self.table_name,
                sampler=reverb.selectors.Fifo(),
                remover=reverb.selectors.Fifo(),
                max_size=100000,
                rate_limiter=reverb.rate_limiters.MinSize(1),
                max_times_sampled=1,
            ),
        ]
    )
    # Shut the server down after each test so its port and threads are
    # released deterministically instead of whenever the object is collected.
    self.addCleanup(self.server.stop)
    self.server_address = f"localhost:{self.server.port}"

    # Separate process/machine.
    self.client = reverb.Client(self.server_address)

  def test_client_server_interaction(self):
    # Separate process/machine (ideally same as model training process)
    with flagsaver.flagsaver(reverb_address=self.server_address):
      dataset_fn = distributed.DistributedDatasetFn(self.table_name)
      self.distributed_dataset = dataset_fn(self.original_dataset)

    # Insert one example through the client, then read it back through the
    # distributed dataset.
    data = next(self.original_dataset.as_numpy_iterator())
    self.client.insert(data, {self.table_name: 1.0})
    distributed_data = next(self.distributed_dataset.as_numpy_iterator())
    self.assertEqual(data, distributed_data)


if __name__ == "__main__":
  absltest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright 2024 Google LLC. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Classes for applying featurizers over datasets.""" | ||
|
||
from typing import Sequence | ||
|
||
from absl import logging | ||
import attrs | ||
from optformer.common.data import featurizers | ||
from optformer.common.data.datasets import base | ||
import tensorflow as tf | ||
|
||
|
||
@attrs.define
class FeaturizedDatasetFn(base.DatasetFn[tf.data.Dataset]):
  """Dataset function which runs a `Featurizer` over every element."""

  featurizer: featurizers.Featurizer

  def __call__(self, source: tf.data.Dataset) -> tf.data.Dataset:
    """Applies the featurizer to each element and drops failed elements.

    Args:
      source: Dataset whose unit of data is a valid input to the featurizer.

    Returns:
      Dataset of dicts keyed by the featurizer's output names. Elements for
      which featurization raised an exception are logged and filtered out.
    """
    feature_names = tuple(self.featurizer.output_types.keys())
    # The leading bool reports whether featurization succeeded.
    numpy_fn_dtypes = (tf.bool, *self.featurizer.output_types.values())

    def featurize_fn(s) -> Sequence[tf.Tensor]:
      # `tf.numpy_function` requires output type as tuples, not dicts. Shapes
      # must also be consistent across all return statements, hence the
      # placeholder `empty_output` on failure.
      try:
        features = self.featurizer.to_features(s)
        return (True, *features.values())  # pytype: disable=bad-return-type  # py311-upgrade
      except Exception as e:  # pylint:disable=broad-exception-caught
        logging.exception('Failed to featurize: %s', e)
        return (False, *self.featurizer.empty_output.values())  # pytype: disable=bad-return-type  # py311-upgrade

    def run_featurizer(s):
      return tf.numpy_function(featurize_fn, [s], numpy_fn_dtypes)

    def is_success(success, *_):
      return success

    # NOTE: Downstream tokenization requires inputs w/ known shapes, which
    # `tf.numpy_function` outputs do not carry.
    def set_shapes(values: Sequence[tf.Tensor]) -> Sequence[tf.Tensor]:
      for value in values:
        value.set_shape(())
      return values

    def drop_flag_and_set_shapes(_, *values):
      return set_shapes(values)

    def to_feature_dict(*values):
      return dict(zip(feature_names, values))

    # Stage 1: featurize. Stage 2: filter failed results.
    featurized = source.map(
        run_featurizer, num_parallel_calls=tf.data.AUTOTUNE
    ).filter(is_success)

    # Stage 3: drop success boolean, and re-provide shape on each value.
    shaped = featurized.map(
        drop_flag_and_set_shapes, num_parallel_calls=tf.data.AUTOTUNE
    )

    # Stage 4: reconstruct the dict from tuple.
    return shaped.map(to_feature_dict, num_parallel_calls=tf.data.AUTOTUNE)
Oops, something went wrong.