Commit 880b533

1. Move base + token serialization to OSS.
2. Move runtime validation to OSS.

PiperOrigin-RevId: 596595013
xingyousong authored and copybara-github committed Jan 8, 2024
1 parent a573ce8 commit 880b533
Showing 6 changed files with 504 additions and 0 deletions.
25 changes: 25 additions & 0 deletions optformer/serialization/__init__.py
@@ -0,0 +1,25 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Entryway to common serializers."""

from optformer.serialization.base import Deserializer
from optformer.serialization.base import Serializer
from optformer.serialization.base import SerializerFactory
from optformer.serialization.tokens import IntegerTokenSerializer
from optformer.serialization.tokens import OneToManyTokenSerializer
from optformer.serialization.tokens import StringTokenSerializer
from optformer.serialization.tokens import TokenSerializer
from optformer.serialization.tokens import UnitSequenceTokenSerializer
from optformer.serialization.tokens import UnitTokenSerializer
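These re-exports define the package's public surface; a typical import (an illustrative snippet, not part of this commit) looks like:

    from optformer import serialization

    serializer = serialization.UnitSequenceTokenSerializer()
    assert serializer.to_str([42, 'x']) == '<42><x>'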
54 changes: 54 additions & 0 deletions optformer/serialization/base.py
@@ -0,0 +1,54 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base classes for serializers."""
import abc
from typing import Generic, Optional, Protocol, TypeVar

_T = TypeVar('_T')


class Serializer(abc.ABC, Generic[_T]):
  """Base class for stringifying objects.

  Should always have deterministic behavior (i.e. the same input value should
  always map to the same output string).
  """

  @abc.abstractmethod
  def to_str(self, obj: _T, /) -> str:
    """Turns an object into text."""


class SerializerFactory(Protocol[_T]):
  """Factory for creating serializers.

  Useful abstraction for simulating random serialization behavior.
  """

  @abc.abstractmethod
  def __call__(self, *, seed: Optional[int] = None) -> Serializer[_T]:
    """Creates a Serializer from a seed."""


class Deserializer(abc.ABC, Generic[_T]):
  """Base class for deserializing strings.

  Should always have deterministic behavior (i.e. the same input string should
  always map to the same output value).
  """

  @abc.abstractmethod
  def from_str(self, s: str, /) -> _T:
    """Turns the string back into the object."""
182 changes: 182 additions & 0 deletions optformer/serialization/tokens.py
@@ -0,0 +1,182 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for mappings between tokens and objects."""

import abc
import re
from typing import Any, Generic, Sequence, Tuple, Type, TypeVar

import attrs
from optformer.serialization import base
from optformer.validation import runtime
import ordered_set

_V = TypeVar('_V')


# TODO: Allow for different forward/backward types.
class TokenSerializer(base.Serializer[_V], base.Deserializer[_V]):
"""Base class for mapping an object to custom tokens."""

  DELIMITERS: Tuple[str, str] = ('<', '>')


class UnitTokenSerializer(TokenSerializer[_V]):
  """Bijective mapping between a single object and a single token."""

  def to_str(self, obj: _V, /) -> str:
    left_d, right_d = self.DELIMITERS
    return f'{left_d}{obj}{right_d}'

  def from_str(self, s: str, /) -> _V:
    left_d, right_d = self.DELIMITERS
    pattern = f'{left_d}{self.regex_type}{right_d}'
    m = re.fullmatch(pattern, s)
    if not m:
      raise ValueError(f'Input string {s} is not a valid token.')
    return self.type(m.group(1))

  @property
  @abc.abstractmethod
  def regex_type(self) -> str:
    """Regex capture pattern used for deserialization."""

  @property
  @abc.abstractmethod
  def type(self) -> Type[_V]:
    """Type of the token value, used for deserialization."""


class IntegerTokenSerializer(UnitTokenSerializer[int]):
  """Single-token serializer for signed integers, e.g. 42 <-> '<42>'."""

  @property
  def regex_type(self) -> str:
    return '([-+]?\\d+)'

  @property
  def type(self) -> Type[int]:
    return int


class StringTokenSerializer(UnitTokenSerializer[str]):
  """Single-token serializer for strings, e.g. 'x' <-> '<x>'."""

  @property
  def regex_type(self) -> str:
    return '(.*?)'

  @property
  def type(self) -> Type[str]:
    return str
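A quick round-trip check of the two unit serializers (an illustrative snippet, not part of the diff):

    int_ser = IntegerTokenSerializer()
    assert int_ser.to_str(42) == '<42>'
    assert int_ser.from_str('<-7>') == -7

    str_ser = StringTokenSerializer()
    assert str_ser.to_str('x') == '<x>'
    assert str_ser.from_str('<x>') == 'x'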


@attrs.define
class UnitSequenceTokenSerializer(Generic[_V], TokenSerializer[Sequence[_V]]):
  """Bijective mapping between sequences of objects and sequences of tokens.

  Uses type-specific tokenizers with ordered priority to handle every sequence
  element. By default, handles integers and strings,
  e.g. [42, 'x', -1] -> '<42><x><-1>'.
  """

  token_serializers: Sequence[UnitTokenSerializer[_V]] = attrs.field(
      factory=lambda: [IntegerTokenSerializer(), StringTokenSerializer()]
  )

  def to_str(self, obj: Sequence[Any], /) -> str:
    """Serializes each element using the first type-matching serializer."""
    out = []
    for o in obj:
      for token_serializer in self.token_serializers:
        if isinstance(o, token_serializer.type):
          out.append(token_serializer.to_str(o))
          break
      else:
        raise ValueError(f'Type {type(o)} is not supported.')

    return ''.join(out)

  def from_str(self, s: str, /) -> Sequence[Any]:
    left_d, right_d = self.DELIMITERS
    pattern = re.compile(f'{left_d}(.*?){right_d}')
    matches = pattern.finditer(s)

    # Best effort: try each unit serializer until one can parse the token.
    single_values = []
    for match in matches:
      for token_serializer in self.token_serializers:
        token_str = f'{left_d}{match.group(1)}{right_d}'
        try:
          v = token_serializer.from_str(token_str)
          single_values.append(v)
          break
        except ValueError:
          # TODO: Make dedicated `SerializationError`.
          pass
      else:
        raise ValueError(f'Could not deserialize `{token_str}`.')

    return single_values
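The sequence serializer composes the unit serializers in priority order; a round trip looks like this (illustrative):

    serializer = UnitSequenceTokenSerializer()
    s = serializer.to_str([42, 'x', -1])  # '<42><x><-1>'
    assert serializer.from_str(s) == [42, 'x', -1]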


class OneToManyTokenSerializer(TokenSerializer[_V]):
  """Maps one object to many (fixed count) tokens."""

  @property
  @abc.abstractmethod
  def num_tokens_per_obj(self) -> int:
    """Number of tokens used to represent each object."""


@attrs.define
class RepeatedUnitTokenSerializer(OneToManyTokenSerializer[_V]):
  """Simply outputs repeats of a unit token."""

  unit_token_serializer: UnitTokenSerializer[_V] = attrs.field()
  num_tokens_per_obj: int = attrs.field()

  def to_str(self, obj: _V, /) -> str:
    return self.num_tokens_per_obj * self.unit_token_serializer.to_str(obj)

  def from_str(self, s: str, /) -> _V:
    left_d, right_d = self.DELIMITERS
    pattern = re.compile(f'{left_d}(.*?){right_d}')
    matches = pattern.finditer(s)
    inner_strs = [match.group(1) for match in matches]

    runtime.assert_all_elements_same(inner_strs)

    # All repeats encode the same object, so deserializing one token suffices.
    s = f'{left_d}{inner_strs[0]}{right_d}'
    return self.unit_token_serializer.from_str(s)
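A round trip for the repeated serializer (illustrative):

    ser = RepeatedUnitTokenSerializer(
        unit_token_serializer=IntegerTokenSerializer(), num_tokens_per_obj=3
    )
    assert ser.to_str(5) == '<5><5><5>'
    assert ser.from_str('<5><5><5>') == 5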


# TODO: Use this to refactor `ScientificFloatTokenSerializer`.
class CartesianProductTokenSerializer(OneToManyTokenSerializer[Sequence[_V]]):
  """Maps an object to a fixed number of tokens via a cartesian product.

  Output will be of the form <a><b><c>..., where <a> is from set A, <b> is
  from set B, <c> is from set C, etc.
  """

  def all_tokens_used(self) -> ordered_set.OrderedSet[str]:
    """Returns ordered set of all tokens used."""
    out = []
    for i in range(self.num_tokens_per_obj):
      out.extend(self.tokens_used(i))
    return ordered_set.OrderedSet(out)

  @abc.abstractmethod
  def tokens_used(self, index: int) -> ordered_set.OrderedSet[str]:
    """Returns ordered set of tokens used at position `index`."""
(Diffs for the remaining three changed files not shown.)
