diff --git a/optformer/common/serialization/numeric/__init__.py b/optformer/common/serialization/numeric/__init__.py index f686c91..c08dc61 100644 --- a/optformer/common/serialization/numeric/__init__.py +++ b/optformer/common/serialization/numeric/__init__.py @@ -20,3 +20,4 @@ from optformer.common.serialization.numeric.text import SimpleFloatTextSerializer from optformer.common.serialization.numeric.text import SimpleScientificFloatTextSerializer from optformer.common.serialization.numeric.tokens import DigitByDigitFloatTokenSerializer +from optformer.common.serialization.numeric.tokens import IEEEFloatTokenSerializer diff --git a/optformer/common/serialization/numeric/tokens.py b/optformer/common/serialization/numeric/tokens.py index 51e1b5a..bdd4b3e 100644 --- a/optformer/common/serialization/numeric/tokens.py +++ b/optformer/common/serialization/numeric/tokens.py @@ -14,6 +14,7 @@ """General float serializers using dedicated tokens.""" +import math import re from typing import Sequence, Union @@ -23,6 +24,8 @@ from optformer.common.serialization import tokens as tokens_lib import ordered_set +TokensSerializer = tokens_lib.TokenSerializer[Sequence[Union[str, int]]] + @gin.configurable @attrs.define @@ -53,11 +56,9 @@ class DigitByDigitFloatTokenSerializer( num_digits: int = attrs.field(default=4) exponent_range: int = attrs.field(default=10) - tokens_serializer: tokens_lib.TokenSerializer[Sequence[Union[str, int]]] = ( - attrs.field( - kw_only=True, - factory=tokens_lib.UnitSequenceTokenSerializer, - ) + tokens_serializer: TokensSerializer = attrs.field( + kw_only=True, + factory=tokens_lib.UnitSequenceTokenSerializer, ) @property @@ -127,3 +128,95 @@ def from_str(self, s: str, /) -> float: exp = int(''.join(tokens[-1]).lstrip('E')) return float(sign * mantissa * 10**exp) + + +@attrs.define(kw_only=True) +class IEEEFloatTokenSerializer( + tokens_lib.CartesianProductTokenSerializer[float] +): + """More official float serializer, minimizing the use of dedicated tokens. + + Follows IEEE-type standard. + + A float f = `s * b^e * m` can be represented as [s, e, m] from most to least + important, where: + s: Positive/Negative sign (+, -) + b: Base + e: Exponent (left-most is a sign, digits represented with base b) + m: Mantissa (represented with base b) + + For example, 1.23456789e-222 can be represented as: + + <+><-><2><2><2><1><2><3><4> + + if b=10, num_exponent_digits=3, and num_mantissa_digits=4. + """ + + base: int = attrs.field(default=10) + + num_exponent_digits: int = attrs.field(default=1) + num_mantissa_digits: int = attrs.field(default=4) + + tokens_serializer: TokensSerializer = attrs.field( + factory=tokens_lib.UnitSequenceTokenSerializer, + ) + + @property + def num_tokens_per_obj(self) -> int: + return 2 + self.num_exponent_digits + self.num_mantissa_digits + + def tokens_used(self, index: int) -> ordered_set.OrderedSet[str]: + if index < 0 or index >= self.num_tokens_per_obj: + raise ValueError(f'Index {index} out of bounds.') + + if index in [0, 1]: # beginning + tokens = [self.tokens_serializer.to_str([s]) for s in ['+', '-']] + else: # middle (digit) + tokens = [self.tokens_serializer.to_str([s]) for s in range(self.base)] + return ordered_set.OrderedSet(tokens) + + def to_str(self, f: float, /) -> str: + sign = '+' if f >= 0 else '-' + abs_f = abs(f) + exponent = math.floor(np.log(abs_f) / np.log(self.base)) if abs_f > 0 else 0 + + exponent_sign = '+' if exponent >= 0 else '-' + abs_exponent = abs(exponent) + + e = np.base_repr(abs_exponent, base=self.base) + if len(e) > self.num_exponent_digits: # Overflow, raise error for now. + # TODO: Should we round or add 'inf' token? + raise ValueError(f'Exponent {e} too large.') + e = e.zfill(self.num_exponent_digits) + + mantissa = np.base_repr( + abs_f * self.base ** (self.num_mantissa_digits - 1 - exponent), + base=self.base, + ) + + if len(mantissa) > self.num_mantissa_digits: + mantissa = mantissa[: self.num_mantissa_digits] + if len(mantissa) < self.num_mantissa_digits: # Right-pad with zeros. + mantissa += '0' * (self.num_mantissa_digits - len(mantissa)) + + raw_str = sign + exponent_sign + e + mantissa + return self.tokens_serializer.to_str(list(raw_str)) + + def from_str(self, s: str, /) -> float: + tokens = self.tokens_serializer.from_str(s) + + sign = -1 if tokens[0] == '-' else 1 + + exponent_sign = -1 if tokens[1] == '-' else 1 + + abs_exponent_str = ''.join( + map(str, tokens[2 : 2 + self.num_exponent_digits]) + ) + abs_exponent = int(abs_exponent_str, base=self.base) + exponent = exponent_sign * abs_exponent + + mantissa_str = ''.join(map(str, tokens[2 + self.num_exponent_digits :])) + mantissa_unscaled = int(mantissa_str, base=self.base) + mantissa = mantissa_unscaled / self.base ** (self.num_mantissa_digits - 1) + + return sign * (self.base**exponent) * mantissa diff --git a/optformer/common/serialization/numeric/tokens_test.py b/optformer/common/serialization/numeric/tokens_test.py index 709a5c5..db5b818 100644 --- a/optformer/common/serialization/numeric/tokens_test.py +++ b/optformer/common/serialization/numeric/tokens_test.py @@ -88,10 +88,44 @@ def test_all_tokens_used(self): out = serializer.all_tokens_used() signs = ['<+>', '<->'] - digits = [f'<{i}>' for i in range(0, 10)] + digits = [f'<{i}>' for i in range(10)] exponents = ['', '', '', '', ''] self.assertEqual(list(out), signs + digits + exponents) +class IEEEFloatTokenSerializerTest(parameterized.TestCase): + + @parameterized.parameters( + (123.4, '<+><+><2><1><2><3><4>', 123.4), + (12345, '<+><+><4><1><2><3><4>', 12340), + (0.1234, '<+><-><1><1><2><3><4>', 0.1234), + (-123.4, '<-><+><2><1><2><3><4>', -123.4), + (-12345, '<-><+><4><1><2><3><4>', -12340), + (-0.1234, '<-><-><1><1><2><3><4>', -0.1234), + (1.234e-9, '<+><-><9><1><2><3><4>', 1.234e-9), + (-1.234e-9, '<-><-><9><1><2><3><4>', 1.234e-9), + (1.2e-9, '<+><-><9><1><2><0><0>', 1.2e-9), + (-1.2e-9, '<-><-><9><1><2><0><0>', -1.2e-9), + (1.2e9, '<+><+><9><1><2><0><0>', 1.2e9), + (-1.2e9, '<-><+><9><1><2><0><0>', -1.2e9), + (1.2345e9, '<+><+><9><1><2><3><4>', 1.234e9), + (0.0, '<+><+><0><0><0><0><0>', 0.0), + (-0.0, '<+><+><0><0><0><0><0>', 0.0), # in python, 0.0 == -0.0 + ) + def test_serialize(self, f: float, serialized: str, deserialized: float): + serializer = tokens.IEEEFloatTokenSerializer() + self.assertEqual(serializer.to_str(f), serialized) + self.assertAlmostEqual(serializer.from_str(serialized), deserialized) + + @parameterized.parameters((3,), (10,), (18,)) + def test_all_tokens_used(self, base: int): + serializer = tokens.IEEEFloatTokenSerializer(base=base) + out = serializer.all_tokens_used() + + signs = ['<+>', '<->'] + digits = [f'<{i}>' for i in range(base)] + self.assertEqual(list(out), signs + digits) + + if __name__ == '__main__': absltest.main()