From c0bf84a5c0ab93b5e73316214eb9ac61a77d5f4a Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Tue, 9 Jan 2024 13:13:54 -0800 Subject: [PATCH] /data/ folder. PiperOrigin-RevId: 597024032 --- optformer/common/data/augmenters/__init__.py | 17 +++ optformer/common/data/augmenters/base.py | 39 ++++++ optformer/common/data/datasets/__init__.py | 25 ++++ optformer/common/data/datasets/base.py | 27 ++++ optformer/common/data/datasets/distributed.py | 119 ++++++++++++++++ .../common/data/datasets/distributed_test.py | 70 ++++++++++ optformer/common/data/datasets/featurized.py | 81 +++++++++++ .../common/data/datasets/featurized_test.py | 45 ++++++ optformer/common/data/datasets/generator.py | 42 ++++++ .../common/data/datasets/generator_test.py | 58 ++++++++ optformer/common/data/datasets/inference.py | 126 +++++++++++++++++ .../common/data/datasets/inference_test.py | 117 ++++++++++++++++ optformer/common/data/datasets/shuffling.py | 34 +++++ optformer/common/data/datasets/wrappers.py | 49 +++++++ optformer/common/data/featurizers/__init__.py | 18 +++ optformer/common/data/featurizers/base.py | 59 ++++++++ optformer/common/data/featurizers/testing.py | 50 +++++++ optformer/common/data/filters/__init__.py | 18 +++ optformer/common/data/filters/base.py | 39 ++++++ optformer/common/data/filters/features.py | 48 +++++++ .../common/data/filters/features_test.py | 41 ++++++ optformer/common/data/processors/__init__.py | 20 +++ optformer/common/data/processors/base.py | 37 +++++ optformer/common/data/processors/masking.py | 87 ++++++++++++ .../common/data/processors/masking_test.py | 57 ++++++++ .../common/data/processors/partitioning.py | 46 +++++++ .../data/processors/partitioning_test.py | 36 +++++ optformer/common/data/vocabs/__init__.py | 17 +++ optformer/common/data/vocabs/ascii.py | 61 ++++++++ optformer/common/data/vocabs/ascii_test.py | 57 ++++++++ optformer/{ => common}/inference/decoding.py | 2 +- .../{ => common}/inference/sequence_utils.py | 0 .../inference/sequence_utils_test.py | 2 +- optformer/common/serialization/__init__.py | 25 ++++ optformer/{ => common}/serialization/base.py | 0 .../common/serialization/numeric/__init__.py | 17 +++ .../common/serialization/numeric/tokens.py | 130 ++++++++++++++++++ .../serialization/numeric/tokens_test.py | 88 ++++++++++++ .../{ => common}/serialization/tokens.py | 2 +- .../{ => common}/serialization/tokens_test.py | 2 +- optformer/serialization/__init__.py | 25 ---- 41 files changed, 1804 insertions(+), 29 deletions(-) create mode 100644 optformer/common/data/augmenters/__init__.py create mode 100644 optformer/common/data/augmenters/base.py create mode 100644 optformer/common/data/datasets/__init__.py create mode 100644 optformer/common/data/datasets/base.py create mode 100644 optformer/common/data/datasets/distributed.py create mode 100644 optformer/common/data/datasets/distributed_test.py create mode 100644 optformer/common/data/datasets/featurized.py create mode 100644 optformer/common/data/datasets/featurized_test.py create mode 100644 optformer/common/data/datasets/generator.py create mode 100644 optformer/common/data/datasets/generator_test.py create mode 100644 optformer/common/data/datasets/inference.py create mode 100644 optformer/common/data/datasets/inference_test.py create mode 100644 optformer/common/data/datasets/shuffling.py create mode 100644 optformer/common/data/datasets/wrappers.py create mode 100644 optformer/common/data/featurizers/__init__.py create mode 100644 optformer/common/data/featurizers/base.py create mode 100644 
optformer/common/data/featurizers/testing.py create mode 100644 optformer/common/data/filters/__init__.py create mode 100644 optformer/common/data/filters/base.py create mode 100644 optformer/common/data/filters/features.py create mode 100644 optformer/common/data/filters/features_test.py create mode 100644 optformer/common/data/processors/__init__.py create mode 100644 optformer/common/data/processors/base.py create mode 100644 optformer/common/data/processors/masking.py create mode 100644 optformer/common/data/processors/masking_test.py create mode 100644 optformer/common/data/processors/partitioning.py create mode 100644 optformer/common/data/processors/partitioning_test.py create mode 100644 optformer/common/data/vocabs/__init__.py create mode 100644 optformer/common/data/vocabs/ascii.py create mode 100644 optformer/common/data/vocabs/ascii_test.py rename optformer/{ => common}/inference/decoding.py (97%) rename optformer/{ => common}/inference/sequence_utils.py (100%) rename optformer/{ => common}/inference/sequence_utils_test.py (99%) create mode 100644 optformer/common/serialization/__init__.py rename optformer/{ => common}/serialization/base.py (100%) create mode 100644 optformer/common/serialization/numeric/__init__.py create mode 100644 optformer/common/serialization/numeric/tokens.py create mode 100644 optformer/common/serialization/numeric/tokens_test.py rename optformer/{ => common}/serialization/tokens.py (99%) rename optformer/{ => common}/serialization/tokens_test.py (98%) delete mode 100644 optformer/serialization/__init__.py diff --git a/optformer/common/data/augmenters/__init__.py b/optformer/common/data/augmenters/__init__.py new file mode 100644 index 0000000..e3603da --- /dev/null +++ b/optformer/common/data/augmenters/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public imports for augmenters.""" + +from optformer.common.data.augmenters.base import Augmenter diff --git a/optformer/common/data/augmenters/base.py b/optformer/common/data/augmenters/base.py new file mode 100644 index 0000000..b283263 --- /dev/null +++ b/optformer/common/data/augmenters/base.py @@ -0,0 +1,39 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Base class for augmenting objects during training.""" + +import abc +from typing import Generic, TypeVar + +T = TypeVar('T') + + +class Augmenter(Generic[T], abc.ABC): + """Base data augmenter class.""" + + @abc.abstractmethod + def augment(self, obj: T, /) -> T: + """Augments the object. + + For efficiency, ideally the object should be augmented in-place. Copying + should be used as a last resort. + + Args: + obj: Object to be augmented. + + Returns: + Augmented object. Could be a reference to the original input object if + modifying in-place. + """ diff --git a/optformer/common/data/datasets/__init__.py b/optformer/common/data/datasets/__init__.py new file mode 100644 index 0000000..4290908 --- /dev/null +++ b/optformer/common/data/datasets/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All import DatasetFns.""" + +from optformer.common.data.datasets.base import DatasetFn +from optformer.common.data.datasets.distributed import DistributedDatasetFn +from optformer.common.data.datasets.distributed import DistributedSeqioDatasetFn +from optformer.common.data.datasets.featurized import FeaturizedDatasetFn +from optformer.common.data.datasets.generator import GeneratorDatasetFn +from optformer.common.data.datasets.inference import SeqIOInferenceDatasetFn +from optformer.common.data.datasets.inference import T5XInferenceDatasetFn +from optformer.common.data.datasets.shuffling import ShuffleDatasetFn +from optformer.common.data.datasets.wrappers import SeqioDatasetFnFunctor diff --git a/optformer/common/data/datasets/base.py b/optformer/common/data/datasets/base.py new file mode 100644 index 0000000..0079685 --- /dev/null +++ b/optformer/common/data/datasets/base.py @@ -0,0 +1,27 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base abstractions for dataset functions.""" +from typing import Protocol, TypeVar + +import tensorflow as tf + +_S = TypeVar('_S') + + +class DatasetFn(Protocol[_S]): + """Base Dataset Function class.""" + + def __call__(self, source: _S) -> tf.data.Dataset: + """Transforms a source to a TF Dataset.""" diff --git a/optformer/common/data/datasets/distributed.py b/optformer/common/data/datasets/distributed.py new file mode 100644 index 0000000..3e1e227 --- /dev/null +++ b/optformer/common/data/datasets/distributed.py @@ -0,0 +1,119 @@ +# Copyright 2024 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Distributed dataset for scaling data-loading to many CPUs.""" + +from typing import Callable, Optional + +from absl import flags +import attrs +from optformer.common.data.datasets import base +import reverb +import seqio +import tensorflow as tf +import tree + +# This flag's value will be passed from downstream binaries. +REVERB_ADDRESS = flags.DEFINE_string( + 'reverb_address', + None, + 'Address of Reverb server, of form `host:port`.', +) + +DISABLE_REVERB = flags.DEFINE_bool( + 'disable_reverb', + False, + 'If true disables distributed Reverb logic (makes wrapper no-op).', +) + + +@attrs.define +class DistributedDatasetFn(base.DatasetFn[tf.data.Dataset]): + """Creates a distributed version of a dataset using the Reverb API. + + The distributed dataset will request data from a specific table of a reverb + server which collects units of data from multiple clients rather than a file + source, but still needs to know the incoming dtypes/shapes from a template. + """ + + table_name: str = attrs.field(init=True) + + num_workers_per_iterator: int = attrs.field(default=1, kw_only=True) + max_samples_per_stream: int = attrs.field(default=120, kw_only=True) + max_in_flight_samples_per_worker: int = attrs.field(default=2, kw_only=True) + prefetch: Optional[int] = attrs.field(default=8, kw_only=True) + + def __call__(self, source: tf.data.Dataset) -> reverb.TimestepDataset: + """Creates the distributed Reverb dataset as a server. + + Args: + source: Single-process dataset version, used **only as a template** to + obtain dtype/shape information. + + Returns: + Reverb server dataset. + """ + + template_ds = source + + if REVERB_ADDRESS.value is None: + raise ValueError('`reverb_address` flag is still unset!') + + ds = reverb.TimestepDataset( + server_address=REVERB_ADDRESS.value, + table=self.table_name, + dtypes=tree.map_structure(lambda x: x.dtype, template_ds.element_spec), + shapes=tree.map_structure(lambda x: x.shape, template_ds.element_spec), + num_workers_per_iterator=self.num_workers_per_iterator, + max_samples_per_stream=self.max_samples_per_stream, + max_in_flight_samples_per_worker=self.max_in_flight_samples_per_worker, + ) + + # Change output from default `ReplaySample` struct to actual data. 
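+    # (A Reverb `ReplaySample` is a namedtuple of `(info, data)`; only the +    # `data` payload matches the template's element spec, so the sampling +    # metadata is dropped here.)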
+ ds = ds.map(lambda rs: rs.data, num_parallel_calls=tf.data.AUTOTUNE) + + if self.prefetch is not None: + ds = ds.prefetch(buffer_size=self.prefetch) + + options = tf.data.Options() + options.experimental_optimization.apply_default_optimizations = True + options.experimental_optimization.map_and_batch_fusion = True + options.experimental_optimization.parallel_batch = True + return ds.with_options(options) + + +@attrs.define +class DistributedSeqioDatasetFn(seqio.DatasetFnCallable): + """Creates a distributed dataset whose table name is chosen by split.""" + + seqio_dataset_fn: seqio.DatasetFnCallable = attrs.field(init=True) + + distributed_dataset_fn_factory: Callable[[str], DistributedDatasetFn] = ( + attrs.field(default=DistributedDatasetFn) + ) + + def __call__( + self, split: str, shuffle_files: bool, seed: Optional[int] = None + ) -> tf.data.Dataset: + original_dataset = self.seqio_dataset_fn(split, shuffle_files, seed) + if DISABLE_REVERB.value: + return original_dataset + + distributed_dataset_fn = self.distributed_dataset_fn_factory(split) + return distributed_dataset_fn(original_dataset) + + # TODO: Previously fixed compatibility w/ SeqIO. Is it still needed? + @property + def __name__(self): + return 'DistributedSeqioDatasetFn' diff --git a/optformer/common/data/datasets/distributed_test.py b/optformer/common/data/datasets/distributed_test.py new file mode 100644 index 0000000..32f6c0d --- /dev/null +++ b/optformer/common/data/datasets/distributed_test.py @@ -0,0 +1,70 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import flagsaver +from optformer.common.data import vocabs +from optformer.common.data.datasets import distributed +import reverb +import tensorflow as tf +from absl.testing import absltest + + +class DistributedDatasetFnTest(absltest.TestCase): + + def setUp(self): + super().setUp() + # Set up original dataset. + self.vocab = vocabs.AsciiVocab() + self.raw_data = [{"inputs": "hi", "targets": "bye"}] + self.original_dataset = tf.data.Dataset.from_generator( + lambda: self.raw_data, + output_types={"inputs": tf.string, "targets": tf.string}, + output_shapes={"inputs": [], "targets": []}, + ) + + # Set up distributed components. + # Server on separate process/machine. + + self.table_name = "test_table" + self.server = reverb.Server( + tables=[ + reverb.Table( + name=self.table_name, + sampler=reverb.selectors.Fifo(), + remover=reverb.selectors.Fifo(), + max_size=100000, + rate_limiter=reverb.rate_limiters.MinSize(1), + max_times_sampled=1, + ), + ] + ) + self.server_address = f"localhost:{self.server.port}" + + # Separate process/machine.
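+    # NOTE: In a real deployment, data-producing jobs on other machines push +    # featurized examples into the table via `reverb.Client.insert`, while +    # the training job consumes them through `DistributedDatasetFn`. This +    # test runs all roles in one process.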
+ self.client = reverb.Client(self.server_address) + + def test_client_server_interaction(self): + # Separate process/machine (ideally same as model training process) + with flagsaver.flagsaver(reverb_address=self.server_address): + dataset_fn = distributed.DistributedDatasetFn(self.table_name) + self.distributed_dataset = dataset_fn(self.original_dataset) + + data = next(self.original_dataset.as_numpy_iterator()) + self.client.insert(data, {self.table_name: 1.0}) + distributed_data = next(self.distributed_dataset.as_numpy_iterator()) + self.assertEqual(data, distributed_data) + + +if __name__ == "__main__": + absltest.main() diff --git a/optformer/common/data/datasets/featurized.py b/optformer/common/data/datasets/featurized.py new file mode 100644 index 0000000..c199abd --- /dev/null +++ b/optformer/common/data/datasets/featurized.py @@ -0,0 +1,81 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for applying featurizers over datasets.""" + +from typing import Sequence + +from absl import logging +import attrs +from optformer.common.data import featurizers +from optformer.common.data.datasets import base +import tensorflow as tf + + +@attrs.define +class FeaturizedDatasetFn(base.DatasetFn[tf.data.Dataset]): + """Featurizes a dataset.""" + + featurizer: featurizers.Featurizer + + def __call__(self, source: tf.data.Dataset) -> tf.data.Dataset: + """Returns dataset processed via a Featurizer. + + Args: + source: Dataset whose unit of data is a valid input to the featurizer. + + Returns: + Resulting TF dataset, after applying the Featurizer and filtering out + empty features. Each unit of data will be a `Dict[str, tf.Tensor]`. + """ + ds = source + + # Apply the featurizer. + def featurize_fn(s) -> Sequence[tf.Tensor]: + # `tf.numpy_function` requires output types to be tuples, not dicts. + # Shapes must also be consistent across all return statements. + # + # First output tuple element indicates if featurizer was successful. + try: + return (True, *self.featurizer.to_features(s).values()) # pytype: disable=bad-return-type # py311-upgrade + except Exception as e: # pylint:disable=broad-exception-caught + logging.exception('Failed to featurize: %s', e) + return (False, *self.featurizer.empty_output.values()) # pytype: disable=bad-return-type # py311-upgrade + + t_out = (tf.bool, *self.featurizer.output_types.values()) + + ds = ds.map( + lambda s: tf.numpy_function(featurize_fn, [s], t_out), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + # Filter failed results. + ds = ds.filter(lambda success, *_: success) + + # NOTE: Downstream tokenization requires inputs w/ known shapes. + def set_shapes(values: Sequence[tf.Tensor]) -> Sequence[tf.Tensor]: + for v in values: + v.set_shape(()) + return values + + # Drop success boolean, and re-provide shape on each value. + ds = ds.map( + lambda _, *v: set_shapes(v), + num_parallel_calls=tf.data.AUTOTUNE, + ) + # Reconstruct the dict from tuple.
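+    # (The zip below relies on `output_types` preserving the same key order +    # as `to_features().values()`, since `tf.numpy_function` only passes +    # positional tensors.)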
+ return ds.map( + lambda *v: dict(zip(self.featurizer.output_types.keys(), v)), + num_parallel_calls=tf.data.AUTOTUNE, + ) diff --git a/optformer/common/data/datasets/featurized_test.py b/optformer/common/data/datasets/featurized_test.py new file mode 100644 index 0000000..73f8a80 --- /dev/null +++ b/optformer/common/data/datasets/featurized_test.py @@ -0,0 +1,45 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from optformer.common.data import featurizers +from optformer.common.data.datasets import featurized +import seqio +import tensorflow as tf + +from absl.testing import absltest + +BAD_STRING = 'bad_string' + + +class FeaturizedDatasetFnTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.featurizer = featurizers.IdentityFeaturizer() + self.dataset_fn = featurized.FeaturizedDatasetFn(self.featurizer) + + def test_e2e(self): + objs = ['hello', 'goodbye'] + ds = tf.data.Dataset.from_tensor_slices(objs) + ds = self.dataset_fn(ds) + + expected = [self.featurizer.to_features(s) for s in objs] + + seqio.test_utils.assert_dataset(ds, expected) + for k, v in ds.element_spec.items(): + self.assertSequenceEqual(v.shape, (), msg=f'{k} must have empty shape.') + + +if __name__ == '__main__': + absltest.main() diff --git a/optformer/common/data/datasets/generator.py b/optformer/common/data/datasets/generator.py new file mode 100644 index 0000000..46be60b --- /dev/null +++ b/optformer/common/data/datasets/generator.py @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Dataset creation from a generator.""" + +from typing import Iterator, TypeVar + +import attrs +from optformer.common.data import featurizers +from optformer.common.data.datasets import base +import tensorflow as tf + +_S = TypeVar('_S') + + +@attrs.define +class GeneratorDatasetFn(base.DatasetFn[_S]): + """Dataset will load from a featurized generator.""" + + featurizer: featurizers.Featurizer[_S] = attrs.field(init=True, kw_only=True) + + def __call__(self, source: Iterator[_S]) -> tf.data.Dataset: + def _generator(): + for obj in source: + yield self.featurizer.to_features(obj) + + return tf.data.Dataset.from_generator( + _generator, + output_types=self.featurizer.output_types, + output_shapes=self.featurizer.output_shapes, + ) diff --git a/optformer/common/data/datasets/generator_test.py b/optformer/common/data/datasets/generator_test.py new file mode 100644 index 0000000..22f5b3d --- /dev/null +++ b/optformer/common/data/datasets/generator_test.py @@ -0,0 +1,58 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +from optformer.common.data import featurizers +from optformer.common.data.datasets import generator +import seqio +import tensorflow as tf + +from absl.testing import absltest + + +class DoNothingFeaturizer(featurizers.Featurizer[str]): + + def to_features(self, obj: str) -> Dict[str, tf.Tensor]: + return {'key': tf.constant(obj, dtype=tf.string)} + + @property + def output_types(self) -> Dict[str, tf.DType]: + return {'key': tf.string} + + @property + def output_shapes(self) -> Dict[str, tf.TensorShape]: + return {'key': tf.TensorShape([])} + + @property + def empty_output(self) -> Dict[str, tf.Tensor]: + return {'key': tf.constant('', dtype=tf.string)} + + +class GeneratorDatasetFnTest(absltest.TestCase): + + def test_buffer(self): + buffer = ['hello', 'goodbye'] + gen = (s for s in buffer) + + featurizer = DoNothingFeaturizer() + dataset_fn = generator.GeneratorDatasetFn(featurizer=featurizer) + + dataset = dataset_fn(gen) + expected = [{'key': b'hello'}, {'key': b'goodbye'}] + seqio.test_utils.assert_dataset(dataset, expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/optformer/common/data/datasets/inference.py b/optformer/common/data/datasets/inference.py new file mode 100644 index 0000000..fed81bf --- /dev/null +++ b/optformer/common/data/datasets/inference.py @@ -0,0 +1,126 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""T5 datasets.""" + +from typing import Any, Generic, Iterable, TypeVar + +import attrs +from optformer.common.data import featurizers +from optformer.common.data.datasets import base +import seqio +import tensorflow as tf + + +def _inference_feature_converter_validator( + instance: Any, + attribute: attrs.Attribute, + value: seqio.FeatureConverter, +) -> None: + del instance, attribute + if value.pack: + raise ValueError(f"Inference should not use packing. Given: {value.pack}") + + +@attrs.define(kw_only=True) +class SeqIOInferenceDatasetFn(base.DatasetFn[tf.data.Dataset]): + """Only meant to be used during inference. + + Performs tokenization + feature conversion on a featurized dataset via the + SeqIO Task API, with all other kwargs optimized for inference. + """ + + # SeqIO's Feature contains the actual vocabulary. The string keys should match + # the source dataset's keys. + output_features: dict[str, seqio.Feature] = attrs.field(init=True) + + # For final conversion to e.g. T5X model inputs. Normally the T5X model + # already contains a feature converter, but it's only used by training + # pipelines and not at all in the model's inference API (i.e. `predict_batch`) + # so we need to feature-convert the data ourselves. + feature_converter: seqio.FeatureConverter = attrs.field( + init=True, validator=_inference_feature_converter_validator + ) + + # Length of output tensors. The FeatureConverter should add '0' paddings if + # the input data is too short. The string keys should match the source + # dataset's keys. + task_feature_lengths: dict[str, int] = attrs.field(init=True) + + def __call__(self, source: tf.data.Dataset) -> tf.data.Dataset: + ds = source + ds = seqio.preprocessors.tokenize(ds, self.output_features) + ds = seqio.preprocessors.append_eos_after_trim(ds, self.output_features) + ds = seqio.trim_dataset(ds, self.task_feature_lengths, self.output_features) + ds = self.feature_converter(ds, self.task_feature_lengths) + return ds + + +_S = TypeVar("_S") + + +# TODO: Should this just be merged w/ SeqIOInferenceDatasetFn? +@attrs.define(init=False) +class T5XInferenceDatasetFn(Generic[_S], base.DatasetFn[Iterable[_S]]): + """Converts a batch of Python objects into a T5X Model input for inference. + + Python objects must be featurized first. + """ + + featurizer: featurizers.Featurizer[_S] = attrs.field(kw_only=True) + tokenizer_and_converter: SeqIOInferenceDatasetFn = attrs.field(kw_only=True) + + def __init__( + self, + featurizer: featurizers.Featurizer[_S], + input_vocabulary: seqio.Vocabulary, + output_vocabulary: seqio.Vocabulary, + feature_converter: seqio.feature_converters.FeatureConverter, + max_encoder_sequence_length: int, + max_decoder_sequence_length: int, + ): + """Custom init to reduce field ownership and align w/ T5X gin usage.""" + output_features = { + # Input format is already rigorously defined, no need for EOS. + "inputs": seqio.Feature(vocabulary=input_vocabulary, add_eos=False), + # We control the decoding length ourselves, no need for EOS. 
+ "targets": seqio.Feature(vocabulary=output_vocabulary, add_eos=False), + } + tokenizer_and_converter = SeqIOInferenceDatasetFn( + output_features=output_features, + feature_converter=feature_converter, + task_feature_lengths={ + "inputs": max_encoder_sequence_length, + "targets": max_decoder_sequence_length, + }, + ) + self.__attrs_init__( + featurizer=featurizer, tokenizer_and_converter=tokenizer_and_converter + ) + + def __call__(self, source: Iterable[_S]) -> tf.data.Dataset: + """Featurizes + tokenizes objects from a buffer. + + Args: + source: Ideally a live buffer which should not be deleted. + + Returns: + tf.Dataset holding a reference to the generator / buffer. + """ + generator_dataset = tf.data.Dataset.from_generator( + lambda: (self.featurizer.to_features(s) for s in source), + output_types=self.featurizer.output_types, + output_shapes=self.featurizer.output_shapes, + ) + return self.tokenizer_and_converter(generator_dataset) diff --git a/optformer/common/data/datasets/inference_test.py b/optformer/common/data/datasets/inference_test.py new file mode 100644 index 0000000..6dee665 --- /dev/null +++ b/optformer/common/data/datasets/inference_test.py @@ -0,0 +1,117 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from optformer.common.data import featurizers +from optformer.common.data import vocabs +from optformer.common.data.datasets import inference +import seqio +import tensorflow as tf + +from absl.testing import absltest + + +class SeqIOInferenceDatasetTest(absltest.TestCase): + + def setUp(self): + super().setUp() + + self.vocab = vocabs.AsciiVocab() + self.raw_data = [{"inputs": "hi", "targets": "bye"}] + self.dataset = tf.data.Dataset.from_generator( + lambda: self.raw_data, + output_types={"inputs": tf.string, "targets": tf.string}, + output_shapes={"inputs": [], "targets": []}, + ) + + self.output_features = { + "inputs": seqio.Feature(vocabulary=self.vocab), + "targets": seqio.Feature(vocabulary=self.vocab), + } + + self.task_feature_lengths = {"inputs": 6, "targets": 6} + + def test_dataset_default(self): + dataset_fn = inference.SeqIOInferenceDatasetFn( + output_features=self.output_features, + feature_converter=seqio.feature_converters.EncDecFeatureConverter( + pack=False + ), + task_feature_lengths=self.task_feature_lengths, + ) + + dataset = dataset_fn(self.dataset) + expected = [{ + "decoder_input_tokens": [0, 98, 121, 101, 1, 0], # Ends w/ EOS=1 + "decoder_loss_weights": [1, 1, 1, 1, 0, 0], + "decoder_target_tokens": [98, 121, 101, 1, 0, 0], + "encoder_input_tokens": [104, 105, 1, 0, 0, 0], + }] + seqio.test_utils.assert_dataset(dataset, expected) + + @absltest.skip("Need to disable FeatureConverter validator to run.") + def test_dataset_with_pack(self): + # Normally `pack` shouldn't be used during inference. We only show this test + # to warn the user about what happens if we do use packing. 
+ + dataset_fn = inference.SeqIOInferenceDatasetFn( + output_features=self.output_features, + feature_converter=seqio.feature_converters.EncDecFeatureConverter( + pack=True + ), + task_feature_lengths=self.task_feature_lengths, + ) + + dataset = dataset_fn(self.dataset) + expected = [{ + "decoder_input_tokens": [0, 98, 121, 101, 0, 0], # No EOS. + "decoder_loss_weights": [1, 1, 1, 1, 0, 0], + "decoder_positions": [0, 1, 2, 3, 0, 0], + "decoder_segment_ids": [1, 1, 1, 1, 0, 0], + "decoder_target_tokens": [98, 121, 101, 1, 0, 0], + "encoder_input_tokens": [104, 105, 1, 0, 0, 0], + "encoder_positions": [0, 1, 2, 0, 0, 0], + "encoder_segment_ids": [1, 1, 1, 0, 0, 0], + }] + seqio.test_utils.assert_dataset(dataset, expected) + + +class T5XInferenceDatasetTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.test_vocab = vocabs.AsciiVocab() + + def test_e2e(self): + inference_dataset_fn = inference.T5XInferenceDatasetFn( + featurizer=featurizers.IdentityFeaturizer(), + input_vocabulary=self.test_vocab, + output_vocabulary=self.test_vocab, + feature_converter=seqio.EncDecFeatureConverter(pack=False), + max_encoder_sequence_length=1024, + max_decoder_sequence_length=1024, + ) + buffer = [] + buffer_dataset = inference_dataset_fn(buffer) + + with self.assertRaises(Exception): + # Can't iterate an empty buffer. + next(buffer_dataset.as_numpy_iterator()) + + buffer.append("hello") + np_iterator = buffer_dataset.as_numpy_iterator() + next(np_iterator) + + +if __name__ == "__main__": + absltest.main() diff --git a/optformer/common/data/datasets/shuffling.py b/optformer/common/data/datasets/shuffling.py new file mode 100644 index 0000000..3f4e4fc --- /dev/null +++ b/optformer/common/data/datasets/shuffling.py @@ -0,0 +1,34 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shuffling-related dataset functions.""" + +from typing import Optional + +import attrs +from optformer.common.data.datasets import base +import tensorflow as tf + + +@attrs.define +class ShuffleDatasetFn(base.DatasetFn[tf.data.Dataset]): + """Customized shuffling API. Might get modified over time.""" + + # NOTE: Choose to be high enough to simulate IID sampling (when stuck with + # streaming data) but low enough it doesn't blow up RAM. + buffer_size: int = attrs.field(default=1000, kw_only=True) + seed: Optional[int] = attrs.field(default=None, kw_only=True) + + def __call__(self, source: tf.data.Dataset) -> tf.data.Dataset: + return source.shuffle(self.buffer_size, seed=self.seed) diff --git a/optformer/common/data/datasets/wrappers.py b/optformer/common/data/datasets/wrappers.py new file mode 100644 index 0000000..2de7d15 --- /dev/null +++ b/optformer/common/data/datasets/wrappers.py @@ -0,0 +1,49 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper classes which use our `DatasetFn` protocol to interact w/ other types.""" + +from typing import Optional, Sequence + +import attrs +from optformer.common.data.datasets import base +import seqio +import tensorflow as tf + + +@attrs.define +class SeqioDatasetFnFunctor: + """Applies our own `DatasetFn` protocols over SeqIO's `DatasetFnCallable`. + + Same as a 'functor' in functional programming terminology. + """ + + dataset_fns: Sequence[base.DatasetFn[tf.data.Dataset]] = attrs.field() + + def __call__( + self, seqio_dataset_fn: seqio.DatasetFnCallable + ) -> seqio.DatasetFnCallable: + """Returns new SeqIO `DatasetFnCallable` after applying our dataset maps.""" + + def new_dataset_fn( + split: str, + shuffle_files: bool, + seed: Optional[int] = None, + ) -> tf.data.Dataset: + dataset = seqio_dataset_fn(split, shuffle_files, seed) + for dataset_fn in self.dataset_fns: + dataset = dataset_fn(dataset) + return dataset + + return new_dataset_fn diff --git a/optformer/common/data/featurizers/__init__.py b/optformer/common/data/featurizers/__init__.py new file mode 100644 index 0000000..cbc41e1 --- /dev/null +++ b/optformer/common/data/featurizers/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public featurizer imports.""" + +from optformer.common.data.featurizers.base import Featurizer +from optformer.common.data.featurizers.testing import IdentityFeaturizer diff --git a/optformer/common/data/featurizers/base.py b/optformer/common/data/featurizers/base.py new file mode 100644 index 0000000..00aef98 --- /dev/null +++ b/optformer/common/data/featurizers/base.py @@ -0,0 +1,59 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Featurizers for creating TensorDicts, eventually to be used in data pipelines.""" +import abc +from typing import Dict, Generic, TypeVar + +import tensorflow as tf + + +_T = TypeVar('_T') + + +class Featurizer(Generic[_T], abc.ABC): + """Converts an object (ex: study) into features. 
+ + `to_features()` always returns a dictionary consistent with `output_shapes` + and `output_types`, regardless of the input. + """ + + @abc.abstractmethod + def to_features(self, obj: _T, /) -> Dict[str, tf.Tensor]: + """Returns features and raises ValueError in case of failure.""" + + @property + @abc.abstractmethod + def output_types(self) -> Dict[str, tf.DType]: + """Returns the dtypes of values returned by to_features().""" + + @property + @abc.abstractmethod + def output_shapes(self) -> Dict[str, tf.TensorShape]: + """Returns the shapes of values returned by to_features().""" + + @property + @abc.abstractmethod + def empty_output(self) -> Dict[str, tf.Tensor]: + """Empty output to use in case an error is raised. + + Example use: + featurizer: Featurizer[tf.Tensor] + def map_fn(entry: tf.Tensor) -> tuple[bool, Dict[str, tf.Tensor]]: + try: + return True, featurizer.to_features(entry) + except Exception: + return False, featurizer.empty_output + dataset.map(featurizer.to_features).filter(lambda x: x[0]) + """ diff --git a/optformer/common/data/featurizers/testing.py b/optformer/common/data/featurizers/testing.py new file mode 100644 index 0000000..ac1b5b0 --- /dev/null +++ b/optformer/common/data/featurizers/testing.py @@ -0,0 +1,50 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Useful featurizers for testing.""" + +import functools +from optformer.common.data.featurizers import base +import tensorflow as tf + + +class IdentityFeaturizer(base.Featurizer[str]): + """Simply returns identity of bytes string input.""" + + def to_features(self, obj: str, /) -> dict[str, tf.Tensor]: + return { + 'inputs': tf.constant(obj, dtype=tf.string), + 'targets': tf.constant(obj, dtype=tf.string), + } + + @functools.cached_property + def output_types(self) -> dict[str, tf.DType]: + return { + 'inputs': tf.string, + 'targets': tf.string, + } + + @functools.cached_property + def output_shapes(self) -> dict[str, tf.TensorShape]: + return { + 'inputs': tf.TensorShape([]), + 'targets': tf.TensorShape([]), + } + + @functools.cached_property + def empty_output(self) -> dict[str, tf.Tensor]: + return { + 'inputs': tf.constant('', dtype=tf.string), + 'targets': tf.constant('', dtype=tf.string), + } diff --git a/optformer/common/data/filters/__init__.py b/optformer/common/data/filters/__init__.py new file mode 100644 index 0000000..9c80952 --- /dev/null +++ b/optformer/common/data/filters/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public filter imports.""" + +from optformer.common.data.filters.base import Filter +from optformer.common.data.filters.features import TokenLengthFilter diff --git a/optformer/common/data/filters/base.py b/optformer/common/data/filters/base.py new file mode 100644 index 0000000..c8ce27e --- /dev/null +++ b/optformer/common/data/filters/base.py @@ -0,0 +1,39 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Filters which entirely reject the object.""" + +import abc +from typing import Generic, TypeVar + +_T = TypeVar('_T') + + +class Filter(Generic[_T], abc.ABC): + """Filter abstraction.""" + + @abc.abstractmethod + def __call__(self, obj: _T, /) -> bool: + """Filters an object. + + Args: + obj: Object to be filtered. + + Returns: + True if the object is useful. + + Raises: + ValueError: Instead of returning False, optionally + raise an Error to improve logging at the cost of + performance. + """ diff --git a/optformer/common/data/filters/features.py b/optformer/common/data/filters/features.py new file mode 100644 index 0000000..aa1f20b --- /dev/null +++ b/optformer/common/data/filters/features.py @@ -0,0 +1,48 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ways to filter features.""" + +from typing import Dict + +import attrs +from optformer.common.data.filters import base +import seqio +import tensorflow as tf + + +@attrs.define +class TokenLengthFilter(base.Filter[Dict[str, tf.Tensor]]): + """Checks that each feature string is within a maximum token length, given a vocabulary. + + If no vocabulary is provided, the raw character length is checked. + """ + + max_token_lengths: Dict[str, int] = attrs.field( + factory=lambda: {'inputs': 4096, 'targets': 4096} + ) + vocab: seqio.Vocabulary | None = attrs.field( + init=True, default=None, kw_only=True + ) + + def __call__(self, features: Dict[str, tf.Tensor]) -> bool: + for k, v in self.max_token_lengths.items(): + if self.vocab: + f_length = len(self.vocab.encode(features[k].numpy())) + else: + f_length = len(features[k].numpy()) + if f_length > v: + raise ValueError(f'Feature {k} has length {f_length} > {v}') + + return True diff --git a/optformer/common/data/filters/features_test.py b/optformer/common/data/filters/features_test.py new file mode 100644 index 0000000..c4679ee --- /dev/null +++ b/optformer/common/data/filters/features_test.py @@ -0,0 +1,41 @@ +# Copyright 2024 Google LLC.
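As a usage note, a minimal custom `Filter` might look like the following sketch (hypothetical, not part of the patch):

from optformer.common.data.filters import base


class NonEmptyFilter(base.Filter[str]):
  """Hypothetical filter: rejects empty strings."""

  def __call__(self, obj: str, /) -> bool:
    if not obj:
      # Raising (rather than returning False) improves logging upstream.
      raise ValueError('Empty string rejected.')
    return True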
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from optformer.common.data.filters import features +import tensorflow as tf +from absl.testing import absltest + + +class TokenLengthFilterTest(absltest.TestCase): + + def test_e2e(self): + filt = features.TokenLengthFilter( + max_token_lengths={'inputs': 10, 'targets': 10} + ) + good_features = { + 'inputs': tf.constant('hello', dtype=tf.string), + 'targets': tf.constant('world', dtype=tf.string), + } + self.assertTrue(filt(good_features)) + + bad_features = { + 'inputs': tf.constant('tooooooooooooooooooooo', dtype=tf.string), + 'targets': tf.constant('looooooooooooooooooong', dtype=tf.string), + } + with self.assertRaises(ValueError): + filt(bad_features) + + +if __name__ == '__main__': + absltest.main() diff --git a/optformer/common/data/processors/__init__.py b/optformer/common/data/processors/__init__.py new file mode 100644 index 0000000..1c0840c --- /dev/null +++ b/optformer/common/data/processors/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Public imports of Processors.""" + +from optformer.common.data.processors.base import Processor +from optformer.common.data.processors.masking import BetweenDelimitersMask +from optformer.common.data.processors.masking import ValueMask +from optformer.common.data.processors.partitioning import Partitioner diff --git a/optformer/common/data/processors/base.py b/optformer/common/data/processors/base.py new file mode 100644 index 0000000..9eb8fee --- /dev/null +++ b/optformer/common/data/processors/base.py @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base class for processing features.""" + +from typing import Protocol, TypeVar +import tensorflow as tf + + +_T = TypeVar('_T') + + +class Processor(Protocol[_T]): + """Atomic method for processing tensors. + + The output type should be compatible with `tf.data.Dataset.map()`, e.g. + 1. tf.Tensor + 2. Mapping[str, tf.Tensor] + 3.
Tuple[tf.Tensor, ...] + + There are utility functions such as @seqio.map_over_dataset to automatically + convert Processors to dataset mappers. + """ + + def __call__(self, features: tf.Tensor) -> _T: + """Processes the features.""" diff --git a/optformer/common/data/processors/masking.py b/optformer/common/data/processors/masking.py new file mode 100644 index 0000000..f96d0cd --- /dev/null +++ b/optformer/common/data/processors/masking.py @@ -0,0 +1,87 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Masking-based preprocessors.""" + +from typing import Sequence, Union + +import attrs +from optformer.common.data.processors import base +import tensorflow as tf + + +ValueType = Union[tf.Tensor, int, float, str, bool] + + +@attrs.define +class ValueMask(base.Processor[tf.Tensor]): + """Computes value-matched mask in the matrix. + + If our masked_values are * and |, then: + [*, |, x, y] + + will have mask: + [0, 0, 1, 1]. + """ + + masked_values: Sequence[ValueType] + + def __call__(self, features: tf.Tensor) -> tf.Tensor: + tensor = features + mask = tf.fill(tf.shape(tensor), True) + for v in self.masked_values: + mask = tf.logical_and(mask, tf.not_equal(tensor, v)) + + return mask + + +@attrs.define +class BetweenDelimitersMask(base.Processor[tf.Tensor]): + """Computes the mask for a tensor given two ordered delimiters. + + For example, if our left/right delimiters are '*' and '|', then the following + input: + [*, x, y, z, |, *, |] + + will have mask: + [0, 1, 1, 1, 0, 0, 0]. + """ + + left: ValueType + right: ValueType + + def __call__(self, features: tf.Tensor) -> tf.Tensor: + tensor = features + left_match = tf.cast(tensor == self.left, tf.int32) + right_match = tf.cast(tensor == self.right, tf.int32) + + # Check if count(left) == count(right) + left_count = tf.reduce_sum(left_match, axis=-1) + right_count = tf.reduce_sum(right_match, axis=-1) + tf.debugging.assert_equal(left_count, right_count) + + # If our example tensor is [x, *, y, |], then example outputs are commented: + left_cs = tf.math.cumsum(left_match, axis=-1) # [0, 1, 1, 1] + right_cs = tf.math.cumsum(right_match, axis=-1) # [0, 0, 0, 1] + left_cs_slice = left_cs[..., :-1] # [0, 1, 1] + zeros = tf.zeros(shape=left_cs_slice.shape[:-1] + [1], dtype=tf.int32) + shifted_left_cs = tf.concat([zeros, left_cs_slice], axis=-1) # [0, 0, 1, 1] + mask = shifted_left_cs - right_cs # [0, 0, 1, 0] + + # Check if there are no -1's (from wrong right -> left orderings). + all_ones_and_zeros = tf.reduce_all((mask == 0) | (mask == 1)) + tf.debugging.assert_equal(True, all_ones_and_zeros) + + mask = tf.cast(mask, tf.bool) + return mask diff --git a/optformer/common/data/processors/masking_test.py b/optformer/common/data/processors/masking_test.py new file mode 100644 index 0000000..63ca8d8 --- /dev/null +++ b/optformer/common/data/processors/masking_test.py @@ -0,0 +1,57 @@ +# Copyright 2024 Google LLC. 
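A small worked sketch of the two masks defined above (illustrative values only):

import tensorflow as tf
from optformer.common.data.processors import masking

vm = masking.ValueMask(masked_values=[0])
vm(tf.constant([0, 7, 0, 9]))  # -> [False, True, False, True]

bm = masking.BetweenDelimitersMask(left=10, right=20)
bm(tf.constant([10, 7, 8, 20, 5]))  # -> [False, True, True, False, False]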
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from optformer.common.data.processors import masking +import tensorflow as tf +from absl.testing import absltest +from absl.testing import parameterized + + +class ValueMaskTest(tf.test.TestCase): + + def test_basic(self): + tensor = tf.constant([1, 2, 3]) + preprocessor = masking.ValueMask(masked_values=[2, 3]) + out = preprocessor(tensor) + expected = tf.constant([1, 0, 0]) + self.assertAllEqual(out, expected) + + +class BetweenDelimitersMaskTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + self.preprocessor = masking.BetweenDelimitersMask(left=-1, right=1) + + @parameterized.parameters( + (tf.constant([-1, 2, 2, 1, -1, 1]), tf.constant([0, 1, 1, 0, 0, 0])), + (tf.constant([2, 2, 2, 2, 2, 2]), tf.constant([0, 0, 0, 0, 0, 0])), + ) + def test_basic(self, tensor: tf.Tensor, expected: tf.Tensor): + out = self.preprocessor(tensor) + self.assertAllEqual(out, expected) + + @parameterized.parameters( + (tf.constant([-1]),), + (tf.constant([1, -1]),), + (tf.constant([1]),), + ) + def test_error(self, tensor: tf.Tensor): + """Input must have immediately-matched left-right delimiters.""" + with self.assertRaises(tf.errors.OpError): + self.preprocessor(tensor) + + +if __name__ == '__main__': + absltest.main() diff --git a/optformer/common/data/processors/partitioning.py b/optformer/common/data/processors/partitioning.py new file mode 100644 index 0000000..038c27c --- /dev/null +++ b/optformer/common/data/processors/partitioning.py @@ -0,0 +1,46 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Processors which partition batches of data.""" + +from typing import Dict, Tuple + +import attrs +import numpy as np +from optformer.common.data.processors import base +import tensorflow as tf + + +@attrs.define +class Partitioner(base.Processor[Dict[str, tf.Tensor]]): + """Partitions a batch of features according to split ratios.""" + + split_ratios: Dict[str, float] = attrs.field(kw_only=True) + + def __attrs_post_init__(self): + if not np.isclose(sum(self.split_ratios.values()), 1.0): + raise ValueError(f'Ratios {self.split_ratios} do not sum to 1.0.') + + def __call__(self, features: tf.Tensor) -> Dict[str, tf.Tensor]: + def slice_fn(s: np.ndarray) -> Tuple[np.ndarray, ...]: + batch_size = s.shape[0] + + ratios = list(self.split_ratios.values()) + split_indices = (batch_size * np.cumsum(ratios)[:-1]).astype(np.int_) + return np.split(s, split_indices, axis=0) + + n_sp = len(self.split_ratios) + slices = tf.numpy_function(slice_fn, [features], Tout=n_sp * [tf.string]) + + return dict(zip(self.split_ratios.keys(), slices, strict=True)) diff --git a/optformer/common/data/processors/partitioning_test.py b/optformer/common/data/processors/partitioning_test.py new file mode 100644 index 0000000..8de5e7f --- /dev/null +++ b/optformer/common/data/processors/partitioning_test.py @@ -0,0 +1,36 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from optformer.common.data.processors import partitioning +import tensorflow as tf +from absl.testing import absltest + + +class PartitionerTest(absltest.TestCase): + + def test_e2e(self): + partitioner = partitioning.Partitioner( + split_ratios={'train': 0.8, 'validation': 0.1, 'test': 0.1} + ) + + features = tf.constant(list(range(10))) + partitioned_features = partitioner(features) + + self.assertLen(partitioned_features['train'], 8) + self.assertLen(partitioned_features['validation'], 1) + self.assertLen(partitioned_features['test'], 1) + + +if __name__ == '__main__': + absltest.main() diff --git a/optformer/common/data/vocabs/__init__.py b/optformer/common/data/vocabs/__init__.py new file mode 100644 index 0000000..0dbc902 --- /dev/null +++ b/optformer/common/data/vocabs/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""All imported vocabularies."""
+
+from optformer.common.data.vocabs.ascii import AsciiVocab
diff --git a/optformer/common/data/vocabs/ascii.py b/optformer/common/data/vocabs/ascii.py
new file mode 100644
index 0000000..4b519ac
--- /dev/null
+++ b/optformer/common/data/vocabs/ascii.py
@@ -0,0 +1,61 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Ascii Vocabulary."""
+
+from typing import Optional
+import seqio
+import tensorflow as tf
+
+
+class AsciiVocab(seqio.Vocabulary):
+  """Copied from seqio/vocabularies_test.py.
+
+  Useful for char-by-char tokenization and testing.
+  """
+
+  def __init__(
+      self, extra_ids: int = 0, use_eos: bool = True, use_unk: bool = True
+  ):
+    super().__init__(extra_ids=extra_ids)
+    self._extra_ids = extra_ids
+    self._use_eos = use_eos
+    self._use_unk = use_unk
+
+  @property
+  def eos_id(self) -> Optional[int]:
+    return 1 if self._use_eos else None
+
+  @property
+  def unk_id(self) -> Optional[int]:
+    return 2 if self._use_unk else None
+
+  @property
+  def _base_vocab_size(self) -> int:
+    return 128
+
+  def _encode(self, s: str) -> list[int]:
+    return [ord(c) for c in s]
+
+  def _decode(self, ids: list[int]) -> str:
+    return "".join("" if id == 1 else chr(id) for id in ids if id > 0)
+
+  def _encode_tf(self, s: str) -> tf.Tensor:
+    return tf.strings.unicode_decode(s, "UTF-8")
+
+  def _decode_tf(self, ids: list[int]) -> tf.Tensor:
+    s = tf.strings.unicode_encode(ids, "UTF-8")
+    s = tf.strings.regex_replace(s, chr(0), "")
+    s = tf.strings.regex_replace(s, chr(1), "")
+    return s
diff --git a/optformer/common/data/vocabs/ascii_test.py b/optformer/common/data/vocabs/ascii_test.py
new file mode 100644
index 0000000..566163b
--- /dev/null
+++ b/optformer/common/data/vocabs/ascii_test.py
@@ -0,0 +1,57 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from optformer.common.data.vocabs import ascii as ascii_lib
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class AsciiTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.vocab = ascii_lib.AsciiVocab()
+
+  @parameterized.parameters(
+      ('', []),
+      ('a', [97]),
+      ('abc', [97, 98, 99]),
+  )
+  def test_encode(self, s: str, e: List[int]):
+    self.assertEqual(self.vocab.encode(s), e)
+
+  @parameterized.parameters(
+      ([0], ''),
+      ([1], ''),
+      ([97], 'a'),
+      ([0, 97], 'a'),
+      ([1, 97], ''),  # NOTE: <EOS>=1 terminates decoding.
+ ([51, 52, 53], '345'), + ) + def test_decode(self, e: List[int], s: str): + self.assertEqual(self.vocab.decode(e), s) + + @parameterized.parameters(('a',), ('abc',), ('hello',)) + def test_reversibility(self, s: str): + # Reversibility will only occur for non-special (PAD/EOS/UNK) tokens. + encoded = self.vocab.encode(s) + decoded = self.vocab.decode(encoded) + self.assertEqual(decoded, s) + + +if __name__ == '__main__': + absltest.main() diff --git a/optformer/inference/decoding.py b/optformer/common/inference/decoding.py similarity index 97% rename from optformer/inference/decoding.py rename to optformer/common/inference/decoding.py index c3f046c..b6ac227 100644 --- a/optformer/inference/decoding.py +++ b/optformer/common/inference/decoding.py @@ -19,7 +19,7 @@ import jax.numpy as jnp from jaxtyping import Array, Float, Int # pylint: disable=g-multiple-import,g-importing-member -from optformer.inference import sequence_utils as seq_utils +from optformer.common.inference import sequence_utils as seq_utils from t5x import decoding diff --git a/optformer/inference/sequence_utils.py b/optformer/common/inference/sequence_utils.py similarity index 100% rename from optformer/inference/sequence_utils.py rename to optformer/common/inference/sequence_utils.py diff --git a/optformer/inference/sequence_utils_test.py b/optformer/common/inference/sequence_utils_test.py similarity index 99% rename from optformer/inference/sequence_utils_test.py rename to optformer/common/inference/sequence_utils_test.py index 9694031..575ad58 100644 --- a/optformer/inference/sequence_utils_test.py +++ b/optformer/common/inference/sequence_utils_test.py @@ -16,7 +16,7 @@ from jax.experimental import checkify import jax.numpy as jnp import numpy as np -from optformer.inference import sequence_utils +from optformer.common.inference import sequence_utils from optformer.validation import checkify as _checkify from absl.testing import absltest from absl.testing import parameterized diff --git a/optformer/common/serialization/__init__.py b/optformer/common/serialization/__init__.py new file mode 100644 index 0000000..c02a4a3 --- /dev/null +++ b/optformer/common/serialization/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Entryway to common serializers.""" + +from optformer.common.serialization.base import Deserializer +from optformer.common.serialization.base import Serializer +from optformer.common.serialization.base import SerializerFactory +from optformer.common.serialization.tokens import IntegerTokenSerializer +from optformer.common.serialization.tokens import OneToManyTokenSerializer +from optformer.common.serialization.tokens import StringTokenSerializer +from optformer.common.serialization.tokens import TokenSerializer +from optformer.common.serialization.tokens import UnitSequenceTokenSerializer +from optformer.common.serialization.tokens import UnitTokenSerializer diff --git a/optformer/serialization/base.py b/optformer/common/serialization/base.py similarity index 100% rename from optformer/serialization/base.py rename to optformer/common/serialization/base.py diff --git a/optformer/common/serialization/numeric/__init__.py b/optformer/common/serialization/numeric/__init__.py new file mode 100644 index 0000000..230c06e --- /dev/null +++ b/optformer/common/serialization/numeric/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All numeric-related serialization.""" + +from optformer.common.serialization.numeric.tokens import DigitByDigitFloatTokenSerializer diff --git a/optformer/common/serialization/numeric/tokens.py b/optformer/common/serialization/numeric/tokens.py new file mode 100644 index 0000000..cdb0e23 --- /dev/null +++ b/optformer/common/serialization/numeric/tokens.py @@ -0,0 +1,130 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General float serializers using dedicated tokens.""" + +import re +from typing import Sequence, Union + +import attrs +import gin +import numpy as np +from optformer.common.serialization import tokens as tokens_lib +import ordered_set + + +@gin.configurable +@attrs.define +class DigitByDigitFloatTokenSerializer( + tokens_lib.CartesianProductTokenSerializer[float] +): + """Serializes floats digit-by-digit using dedicated tokens. + + NOTE: It was experimentally verified this was the best serialization method. + + A float f can be represented as: + + `s * m * 10^e` + + where: + s: Positive/Negative sign (+, -) + m: Mantissa representing leading digits. + e: Exponent. + + Attributes: + num_digits: Number of digits in `m`. Each digit (even the leading) is + between <0> and <9>. + exponent_range: Controls number of exponent tokens, e.g. 
if 10, the exponent
+      token range will be [<E-10>, <E10>], affecting the range of
+      representable floats.
+  """
+
+  num_digits: int = attrs.field(default=3)
+  exponent_range: int = attrs.field(default=10)
+
+  tokens_serializer: tokens_lib.TokenSerializer[Sequence[Union[str, int]]] = (
+      attrs.field(
+          kw_only=True,
+          factory=tokens_lib.UnitSequenceTokenSerializer,
+      )
+  )
+
+  @property
+  def num_tokens_per_obj(self) -> int:
+    return 2 + self.num_digits
+
+  def tokens_used(self, index: int) -> ordered_set.OrderedSet[str]:
+    if index < 0 or index >= self.num_tokens_per_obj:
+      raise ValueError(f'Index {index} out of bounds.')
+
+    if index == 0:  # beginning
+      return ordered_set.OrderedSet(self.tokens_serializer.to_str(['+', '-']))
+
+    elif index == self.num_tokens_per_obj - 1:  # end
+      exps = [
+          f'E{i}' for i in range(-self.exponent_range, self.exponent_range + 1)
+      ]
+      return ordered_set.OrderedSet(self.tokens_serializer.to_str(exps))
+    else:  # middle (digit)
+      digits = list(range(0, 10))
+      return ordered_set.OrderedSet(self.tokens_serializer.to_str(digits))
+
+  @property
+  def _max_abs_val(self) -> float:
+    """Largest representable positive number."""
+    return float(self.num_digits * '9') * (10.0**self.exponent_range)
+
+  @property
+  def _min_abs_val(self) -> float:
+    """Smallest representable positive number."""
+    min_mantissa = float('1' + (self.num_digits - 1) * '0')
+    return min_mantissa * (10 ** (-self.exponent_range))
+
+  def _round_float(self, f: float) -> float:
+    """Rounds float to the closest in-range value."""
+    abs_f = abs(f)
+    abs_f = min(abs_f, self._max_abs_val)
+    if abs_f < self._min_abs_val:
+      # Decides whether to move to 0.0 or `min_abs_val`.
+      zero_or_min = round(abs_f / self._min_abs_val)
+      abs_f = self._min_abs_val * zero_or_min
+    return abs_f if f >= 0 else -abs_f
+
+  def to_str(self, f: float, /) -> str:
+    f = self._round_float(f)
+    s = np.format_float_scientific(
+        f,
+        precision=self.num_digits - 1,
+        min_digits=self.num_digits - 1,
+        sign=True,
+    )
+    # We expect numpy to produce scientific notation of the form `+2.123e+4`.
+    # It will round for us and ensure leading digit isn't zero, unless the
+    # number is zero.
+    m = re.fullmatch('([+-])([0-9.]*)e(.*)', s)
+    if not m:
+      raise RuntimeError(f'Unexpected numpy notation: {s}')
+    sign: str = m.group(1)
+    digits = list(m.group(2).replace('.', ''))
+    exp = int(m.group(3)) - len(digits) + 1 if f else 0
+    return self.tokens_serializer.to_str([sign] + digits + [f'E{exp}'])
+
+  def from_str(self, s: str, /) -> float:
+    tokens = self.tokens_serializer.from_str(s)
+
+    sign = -1 if tokens[0] == '-' else 1
+    mantissa = int(''.join(map(str, tokens[1:-1])))
+    exp = int(''.join(tokens[-1]).lstrip('E'))
+
+    return float(sign * mantissa * 10**exp)
diff --git a/optformer/common/serialization/numeric/tokens_test.py b/optformer/common/serialization/numeric/tokens_test.py
new file mode 100644
index 0000000..624dd8e
--- /dev/null
+++ b/optformer/common/serialization/numeric/tokens_test.py
@@ -0,0 +1,88 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from optformer.common.serialization.numeric import tokens
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class DigitByDigitFloatTokenSerializerTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      (123.4, '<+><1><2><3><E0>', 123.0),
+      (12345, '<+><1><2><3><E2>', 12300),
+      (0.1234, '<+><1><2><3><E-3>', 0.123),
+      (-123.4, '<-><1><2><3><E0>', -123.0),
+      (-12345, '<-><1><2><3><E2>', -12300),
+      (-0.1234, '<-><1><2><3><E-3>', -0.123),
+      (0.0, '<+><0><0><0><E0>', 0.0),
+      (-0.0, '<+><0><0><0><E0>', 0.0),  # in python, 0.0 == -0.0
+      (-0.4e-13, '<-><0><0><0><E0>', 0.0),  # notice the leading negative zero
+  )
+  def test_serialize(self, f: float, serialized: str, deserialized: float):
+    serializer = tokens.DigitByDigitFloatTokenSerializer()
+    self.assertEqual(serializer.to_str(f), serialized)
+    self.assertEqual(serializer.from_str(serialized), deserialized)
+
+  @parameterized.parameters((3, 10, 1.0e-8, 9.99e12), (1, 5, 1.0e-5, 9.0e5))
+  def test_representation_range(
+      self,
+      num_digits: int,
+      exponent_range: int,
+      min_val: float,
+      max_val: float,
+  ):
+    serializer = tokens.DigitByDigitFloatTokenSerializer(
+        num_digits=num_digits,
+        exponent_range=exponent_range,
+    )
+    self.assertEqual(serializer._max_abs_val, max_val)
+    self.assertEqual(serializer._min_abs_val, min_val)
+
+  @parameterized.parameters(
+      (1.0e13, 3, 10, '<+><9><9><9><E10>'),
+      (2.0e13, 3, 10, '<+><9><9><9><E10>'),
+      (-1.0e13, 3, 10, '<-><9><9><9><E10>'),
+      (-2.0e13, 3, 10, '<-><9><9><9><E10>'),
+      (9.9e12, 3, 10, '<+><9><9><0><E10>'),
+      (-9.9e12, 3, 10, '<-><9><9><0><E10>'),
+      (5.0e5, 3, 10, '<+><5><0><0><E3>'),
+      (1.1e-8, 3, 10, '<+><1><1><0><E-10>'),
+      (0.9e-8, 3, 10, '<+><1><0><0><E-10>'),
+      (0.5e-8, 3, 10, '<+><0><0><0><E0>'),
+      (0.51e-8, 3, 10, '<+><1><0><0><E-10>'),
+      (0.4e-8, 3, 10, '<+><0><0><0><E0>'),
+      # rounding up below creates a negative sign for 0
+      (-0.4e-8, 3, 10, '<-><0><0><0><E0>'),
+      (-0.5e-8, 3, 10, '<-><0><0><0><E0>'),
+      (-0.51e-8, 3, 10, '<-><1><0><0><E-10>'),
+      (-0.8e-8, 3, 10, '<-><1><0><0><E-10>'),
+      (-1.1e-8, 3, 10, '<-><1><1><0><E-10>'),
+  )
+  def test_tokenization_limit(
+      self,
+      f: float,
+      num_digits: int,
+      exponent_range: int,
+      serialized: str,
+  ):
+    serializer = tokens.DigitByDigitFloatTokenSerializer(
+        num_digits=num_digits,
+        exponent_range=exponent_range,
+    )
+    self.assertEqual(serializer.to_str(f), serialized)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/optformer/serialization/tokens.py b/optformer/common/serialization/tokens.py
similarity index 99%
rename from optformer/serialization/tokens.py
rename to optformer/common/serialization/tokens.py
index 98ec375..4604559 100644
--- a/optformer/serialization/tokens.py
+++ b/optformer/common/serialization/tokens.py
@@ -19,7 +19,7 @@
 from typing import Any, Generic, Sequence, Tuple, Type, TypeVar
 
 import attrs
-from optformer.serialization import base
+from optformer.common.serialization import base
 from optformer.validation import runtime
 import ordered_set
 
diff --git a/optformer/serialization/tokens_test.py b/optformer/common/serialization/tokens_test.py
similarity index 98%
rename from optformer/serialization/tokens_test.py
rename to optformer/common/serialization/tokens_test.py
index 10d9748..d1a7532 100644
--- a/optformer/serialization/tokens_test.py
+++ b/optformer/common/serialization/tokens_test.py
@@ -14,7 +14,7 @@
 
 from typing import Any, Sequence
 
-from optformer.serialization import tokens
+from optformer.common.serialization import tokens
 from absl.testing import absltest
 from absl.testing import
parameterized diff --git a/optformer/serialization/__init__.py b/optformer/serialization/__init__.py deleted file mode 100644 index 8fa5d08..0000000 --- a/optformer/serialization/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Entryway to common serializers.""" - -from optformer.serialization.base import Deserializer -from optformer.serialization.base import Serializer -from optformer.serialization.base import SerializerFactory -from optformer.serialization.tokens import IntegerTokenSerializer -from optformer.serialization.tokens import OneToManyTokenSerializer -from optformer.serialization.tokens import StringTokenSerializer -from optformer.serialization.tokens import TokenSerializer -from optformer.serialization.tokens import UnitSequenceTokenSerializer -from optformer.serialization.tokens import UnitTokenSerializer
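Finally, to make the new float serializer concrete: a worked round trip with the defaults (num_digits=3, exponent_range=10), derived from the unit tests above. Illustrative only, not additional patch content:

from optformer.common.serialization.numeric import tokens

serializer = tokens.DigitByDigitFloatTokenSerializer()

# 123.4 rounds to 3 significant digits: +1.23e2, i.e. mantissa 123 with
# exponent 0 relative to the last digit (123 * 10^0).
serializer.to_str(123.4)                 # '<+><1><2><3><E0>'

# Deserialization recovers the rounded value, not the original input.
serializer.from_str('<+><1><2><3><E0>')  # 123.0

# Out-of-range magnitudes saturate at the largest representable value,
# 999 * 10^10 = 9.99e12.
serializer.to_str(2.0e13)                # '<+><9><9><9><E10>'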