diff --git a/optformer/serialization/__init__.py b/optformer/serialization/__init__.py
new file mode 100644
index 0000000..d56436b
--- /dev/null
+++ b/optformer/serialization/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2022 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Entryway to common serializers."""
+
+from optformer.serialization.base import Deserializer
+from optformer.serialization.base import Serializer
+from optformer.serialization.base import SerializerFactory
+from optformer.serialization.tokens import IntegerTokenSerializer
+from optformer.serialization.tokens import OneToManyTokenSerializer
+from optformer.serialization.tokens import StringTokenSerializer
+from optformer.serialization.tokens import TokenSerializer
+from optformer.serialization.tokens import UnitSequenceTokenSerializer
+from optformer.serialization.tokens import UnitTokenSerializer
diff --git a/optformer/serialization/base.py b/optformer/serialization/base.py
new file mode 100644
index 0000000..0135b20
--- /dev/null
+++ b/optformer/serialization/base.py
@@ -0,0 +1,54 @@
+# Copyright 2022 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base classes for serializers."""
+import abc
+from typing import Generic, Optional, Protocol, TypeVar
+
+_T = TypeVar('_T')
+
+
+class Serializer(abc.ABC, Generic[_T]):
+  """Base class for stringifying objects.
+
+  Should always have deterministic behavior (i.e. the same input value should
+  always map to the same output value).
+  """
+
+  @abc.abstractmethod
+  def to_str(self, obj: _T, /) -> str:
+    """Turns an object into text."""
+
+
+class SerializerFactory(Protocol[_T]):
+  """Factory for creating serializers.
+
+  Useful abstraction for simulating random serialization behavior.
+  """
+
+  @abc.abstractmethod
+  def __call__(self, *, seed: Optional[int] = None) -> Serializer[_T]:
+    """Creates the Serializer from a seed."""
+
+
+class Deserializer(abc.ABC, Generic[_T]):
+  """Base class for deserializing strings.
+
+  Should always have deterministic behavior (i.e. the same input value should
+  always map to the same output value).
+  """
+
+  @abc.abstractmethod
+  def from_str(self, s: str, /) -> _T:
+    """Turns the string back into the object."""
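For reference, a minimal sketch of how these interfaces are meant to be subclassed (not part of this diff; `FloatSerializer` and its repr-based round trip are illustrative assumptions, not an API added here):

```python
# Illustrative only: a hypothetical concrete Serializer/Deserializer pair.
from optformer.serialization import base


class FloatSerializer(base.Serializer[float], base.Deserializer[float]):
  """Hypothetical example: round-trips floats via their repr."""

  def to_str(self, obj: float, /) -> str:
    return repr(obj)  # Deterministic: the same float always yields the same string.

  def from_str(self, s: str, /) -> float:
    return float(s)


serializer = FloatSerializer()
assert serializer.from_str(serializer.to_str(3.14)) == 3.14
```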
diff --git a/optformer/serialization/tokens.py b/optformer/serialization/tokens.py
new file mode 100644
index 0000000..1de64f4
--- /dev/null
+++ b/optformer/serialization/tokens.py
@@ -0,0 +1,182 @@
+# Copyright 2022 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base classes for mappings between tokens and objects."""
+
+import abc
+import re
+from typing import Any, Generic, Sequence, Tuple, Type, TypeVar
+
+import attrs
+from optformer.serialization import base
+from optformer.validation import runtime
+import ordered_set
+
+_V = TypeVar('_V')
+
+
+# TODO: Allow for different forward/backward types.
+class TokenSerializer(base.Serializer[_V], base.Deserializer[_V]):
+  """Base class for mapping an object to custom tokens."""
+
+  DELIMITERS: Tuple[str, str] = ('<', '>')
+
+
+class UnitTokenSerializer(TokenSerializer[_V]):
+  """Bijective mapping between a single object and a single token."""
+
+  def to_str(self, obj: _V) -> str:
+    left_d, right_d = self.DELIMITERS
+    return f'{left_d}{obj}{right_d}'
+
+  def from_str(self, s: str) -> _V:
+    left_d, right_d = self.DELIMITERS
+    pattern = f'{left_d}{self.regex_type}{right_d}'
+    m = re.fullmatch(pattern, s)
+    if not m:
+      raise ValueError(f'Input string {s} is not a valid token.')
+    return self.type(m.group(1))
+
+  @property
+  @abc.abstractmethod
+  def regex_type(self) -> str:
+    """Regex capture group used for deserialization."""
+
+  @property
+  @abc.abstractmethod
+  def type(self) -> Type[_V]:
+    """Type of the token value, used for deserialization."""
+
+
+class IntegerTokenSerializer(UnitTokenSerializer[int]):
+
+  @property
+  def regex_type(self) -> str:
+    return '([-+]?\\d+)'
+
+  @property
+  def type(self) -> Type[int]:
+    return int
+
+
+class StringTokenSerializer(UnitTokenSerializer[str]):
+
+  @property
+  def regex_type(self) -> str:
+    return '(.*?)'
+
+  @property
+  def type(self) -> Type[str]:
+    return str
+
+
+@attrs.define
+class UnitSequenceTokenSerializer(Generic[_V], TokenSerializer[Sequence[_V]]):
+  """Bijective mapping between a sequence of objects and a sequence of tokens.
+
+  Uses type-specific tokenizers with ordered priority to handle every sequence
+  element.
+
+  By default, handles integers and strings, e.g. [42, 'x', -1] -> '<42><x><-1>'.
+  """
+
+  token_serializers: Sequence[UnitTokenSerializer[_V]] = attrs.field(
+      factory=lambda: [IntegerTokenSerializer(), StringTokenSerializer()]
+  )
+
+  def to_str(self, obj: Sequence[Any], /) -> str:
+    """Performs string conversion on decoder-type inputs."""
+    out = []
+    for o in obj:
+      for token_serializer in self.token_serializers:
+        if isinstance(o, token_serializer.type):
+          out.append(token_serializer.to_str(o))
+          break
+      else:
+        raise ValueError(f'Type {type(o)} is not supported.')
+
+    return ''.join(out)
+
+  def from_str(self, s: str, /) -> Sequence[Any]:
+    left_d, right_d = self.DELIMITERS
+    pattern = re.compile(f'{left_d}(.*?){right_d}')
+    matches = pattern.finditer(s)
+
+    # Makes a best effort to deserialize each match with the unit tokenizers.
+    single_values = []
+    for match in matches:
+      token_str = f'{left_d}{match.group(1)}{right_d}'
+      for token_serializer in self.token_serializers:
+        try:
+          v = token_serializer.from_str(token_str)
+          single_values.append(v)
+          break
+        except ValueError:
+          # TODO: Make dedicated `SerializationError`.
+          pass
+      else:
+        raise ValueError(f'Could not deserialize `{token_str}`.')
+
+    return single_values
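As a usage sketch (mirroring the unit tests below, not part of the diff), the default serializer list gives integers priority over strings during deserialization:

```python
from optformer.serialization import tokens

serializer = tokens.UnitSequenceTokenSerializer()

s = serializer.to_str([0, 42, 'm'])
print(s)                                       # '<0><42><m>'
assert serializer.from_str(s) == [0, 42, 'm']  # '<42>' parses as int first.
```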
+
+
+class OneToManyTokenSerializer(TokenSerializer[_V]):
+  """Maps one object to many (fixed count) tokens."""
+
+  @property
+  @abc.abstractmethod
+  def num_tokens_per_obj(self) -> int:
+    """Number of tokens used to represent each object."""
+
+
+@attrs.define
+class RepeatedUnitTokenSerializer(OneToManyTokenSerializer[_V]):
+  """Simply outputs repeats of a unit token."""
+
+  unit_token_serializer: UnitTokenSerializer[_V] = attrs.field()
+  num_tokens_per_obj: int = attrs.field()
+
+  def to_str(self, obj: _V) -> str:
+    return self.num_tokens_per_obj * self.unit_token_serializer.to_str(obj)
+
+  def from_str(self, s: str) -> _V:
+    left_d, right_d = self.DELIMITERS
+    pattern = re.compile(f'{left_d}(.*?){right_d}')
+    matches = pattern.finditer(s)
+    inner_strs = [match.group(1) for match in matches]
+
+    runtime.assert_length(inner_strs, self.num_tokens_per_obj)
+    runtime.assert_all_elements_same(inner_strs)
+
+    # All repeats agree, so deserialize a single representative token.
+    token_str = f'{left_d}{inner_strs[0]}{right_d}'
+    return self.unit_token_serializer.from_str(token_str)
+
+
+# TODO: Use this to refactor `ScientificFloatTokenSerializer`.
+class CartesianProductTokenSerializer(OneToManyTokenSerializer[Sequence[_V]]):
+  """Maps an object to a fixed number of tokens based on a cartesian product.
+
+  Output will be of the form e.g. '<a><b><c>' where '<a>' is from set A,
+  '<b>' is from set B, '<c>' is from set C, etc.
+  """
+
+  def all_tokens_used(self) -> ordered_set.OrderedSet[str]:
+    """Returns ordered set of all tokens used."""
+    out = []
+    for i in range(self.num_tokens_per_obj):
+      out.extend(self.tokens_used(i))
+    return ordered_set.OrderedSet(out)
+
+  @abc.abstractmethod
+  def tokens_used(self, index: int) -> ordered_set.OrderedSet[str]:
+    """Returns ordered set of tokens used at position `index`."""
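A brief usage sketch of `RepeatedUnitTokenSerializer` (illustrative only; the values are chosen arbitrarily):

```python
from optformer.serialization import tokens

serializer = tokens.RepeatedUnitTokenSerializer(
    unit_token_serializer=tokens.IntegerTokenSerializer(),
    num_tokens_per_obj=3,
)

assert serializer.to_str(7) == '<7><7><7>'
assert serializer.from_str('<7><7><7>') == 7  # All repeats must agree.
```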
diff --git a/optformer/serialization/tokens_test.py b/optformer/serialization/tokens_test.py
new file mode 100644
index 0000000..91f16e6
--- /dev/null
+++ b/optformer/serialization/tokens_test.py
@@ -0,0 +1,144 @@
+# Copyright 2022 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Sequence
+
+from optformer.serialization import tokens
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class IntegerTokenTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.serializer = tokens.IntegerTokenSerializer()
+
+  @parameterized.parameters(
+      (42, '<42>'),
+      (0, '<0>'),
+      (-3, '<-3>'),
+  )
+  def test_serialize(self, x: int, expected: str):
+    self.assertEqual(self.serializer.to_str(x), expected)
+
+  @parameterized.parameters(
+      ('<42>', 42),
+      ('<0>', 0),
+      ('<-3>', -3),
+  )
+  def test_deserialize(self, y: str, expected: int):
+    self.assertEqual(self.serializer.from_str(y), expected)
+
+  @parameterized.parameters(
+      ('<42',),
+      ('42>',),
+      ('',),
+  )
+  def test_deserialize_error(self, y: str):
+    with self.assertRaises(ValueError):
+      self.serializer.from_str(y)
+
+  @parameterized.parameters(
+      (242,),
+      (0,),
+      (-356,),
+  )
+  def test_reversibility(self, x: int):
+    y = self.serializer.to_str(x)
+    self.assertEqual(self.serializer.from_str(y), x)
+
+
+class StringTokenTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.serializer = tokens.StringTokenSerializer()
+
+  @parameterized.parameters(
+      ('hi', '<hi>'),
+      ('', '<>'),
+      ('3', '<3>'),
+  )
+  def test_serialize(self, x: str, expected: str):
+    self.assertEqual(self.serializer.to_str(x), expected)
+
+  @parameterized.parameters(
+      ('<im spaced>', 'im spaced'),
+      ('<>', ''),
+      ('<-3>', '-3'),
+  )
+  def test_deserialize(self, y: str, expected: str):
+    self.assertEqual(self.serializer.from_str(y), expected)
+
+  @parameterized.parameters(
+      ('',),
+      ('hi',),
+  )
+  def test_deserialize_error(self, y: str):
+    with self.assertRaises(ValueError):
+      self.serializer.from_str(y)
+
+  @parameterized.parameters(
+      ('',),
+      ('<>',),
+      ('<-3>',),
+  )
+  def test_reversibility(self, x: str):
+    y = self.serializer.to_str(x)
+    self.assertEqual(self.serializer.from_str(y), x)
+
+
+class UnitSequenceTokenTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.serializer = tokens.UnitSequenceTokenSerializer()
+
+  @parameterized.parameters(
+      ([0, 42, 'm'], '<0><42><m>'),
+      ([5], '<5>'),
+      (['m'], '<m>'),
+      ([42], '<42>'),
+      ([42, 1], '<42><1>'),
+      ([], ''),
+  )
+  def test_serialization(self, obj: Sequence[Any], output: str):
+    self.assertEqual(self.serializer.to_str(obj), output)
+
+  @parameterized.parameters(
+      ('<0><42><m>', [0, 42, 'm']),
+      ('<-5>', [-5]),
+      ('<m>', ['m']),
+      ('<42>', [42]),
+      ('<42><1>', [42, 1]),
+      ('', []),
+  )
+  def test_deserialization(self, s: str, obj: Sequence[Any]):
+    self.assertEqual(self.serializer.from_str(s), obj)
+
+  @parameterized.parameters(
+      ([0, 242, 'm', -1],),
+      ([0],),
+      ([],),
+  )
+  def test_reversibility(self, x: Sequence[Any]):
+    y = self.serializer.to_str(x)
+    self.assertEqual(self.serializer.from_str(y), x)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/optformer/validation/runtime.py b/optformer/validation/runtime.py
new file mode 100644
index 0000000..48766ef
--- /dev/null
+++ b/optformer/validation/runtime.py
@@ -0,0 +1,55 @@
+# Copyright 2022 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common validator logic for runtime / output checking."""
+
+from typing import Any, Sequence, Tuple, Union
+import numpy as np
+
+FloatLike = Union[float, np.ndarray]
+IntLike = Union[int, np.ndarray]
+IntervalType = Tuple[float, float]
+
+
+def assert_in_interval(interval: IntervalType, x: FloatLike) -> None:
+  """Checks if all values in x are within the interval."""
+  low, high = interval
+  if not np.logical_and(x >= low, x <= high).all():
+    raise ValueError(f"Input {x} out of bounds from [{low}, {high}].")
+
+
+def assert_is_int_like(x: IntLike) -> None:
+  """Checks if an array has integer dtype."""
+  if isinstance(x, np.ndarray) and x.dtype not in [np.int32, np.int64]:
+    raise ValueError(f"Input {x} has non-integer type {x.dtype}.")
+
+
+def assert_length(x: Sequence[Any], length: int) -> None:
+  """Checks if the sequence has the expected length."""
+  if len(x) != length:
+    raise ValueError(f"Sequence length {len(x)} != expected {length}.")
+
+
+def assert_all_elements_same(x: Sequence[Any]) -> None:
+  """Checks if all elements in x are the same.
+
+  Args:
+    x: A sequence of elements.
+
+  NOTE: Be careful about checking a sequence of mutable objects.
+  """
+  if not x:
+    return
+
+  if not all(y == x[0] for y in x):
+    raise ValueError(f"Not all elements in {x} are the same.")
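And a short sketch of the validators above (illustrative inputs, not part of the diff; note that `assert_in_interval` broadcasts over NumPy arrays):

```python
import numpy as np
from optformer.validation import runtime

runtime.assert_in_interval((0.0, 1.0), np.array([0.2, 0.9]))  # Passes.
runtime.assert_all_elements_same(['<7>', '<7>'])              # Passes.

try:
  runtime.assert_length([1, 2, 3], length=2)
except ValueError as e:
  print(e)  # Sequence length 3 != expected 2.
```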
diff --git a/optformer/validation/runtime_test.py b/optformer/validation/runtime_test.py
new file mode 100644
index 0000000..220a7b9
--- /dev/null
+++ b/optformer/validation/runtime_test.py
@@ -0,0 +1,44 @@
+# Copyright 2022 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for runtime.py."""
+
+from typing import Any, Sequence
+
+from optformer.validation import runtime
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class RuntimeTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      ([], True),
+      ([1], True),
+      ([1, 2], False),
+      ([1, 1, 1], True),
+      (('a', 'a'), True),
+      (('a', 'b', 'c'), False),
+  )
+  def test_all_elements_same(self, x: Sequence[Any], is_same: bool):
+    if is_same:
+      runtime.assert_all_elements_same(x)
+    else:
+      with self.assertRaises(ValueError):  # pylint:disable=g-error-prone-assert-raises
+        runtime.assert_all_elements_same(x)
+
+
+if __name__ == '__main__':
+  absltest.main()