Add IEEEFloatTokenizer
PiperOrigin-RevId: 692753815
xingyousong authored and copybara-github committed Nov 4, 2024
1 parent d8cbeb6 commit f68a4c9
Showing 3 changed files with 134 additions and 6 deletions.
1 change: 1 addition & 0 deletions optformer/common/serialization/numeric/__init__.py
@@ -20,3 +20,4 @@
from optformer.common.serialization.numeric.text import SimpleFloatTextSerializer
from optformer.common.serialization.numeric.text import SimpleScientificFloatTextSerializer
from optformer.common.serialization.numeric.tokens import DigitByDigitFloatTokenSerializer
from optformer.common.serialization.numeric.tokens import IEEEFloatTokenSerializer
103 changes: 98 additions & 5 deletions optformer/common/serialization/numeric/tokens.py
@@ -14,6 +14,7 @@

"""General float serializers using dedicated tokens."""

import math
import re
from typing import Sequence, Union

@@ -23,6 +24,8 @@
from optformer.common.serialization import tokens as tokens_lib
import ordered_set

TokensSerializer = tokens_lib.TokenSerializer[Sequence[Union[str, int]]]


@gin.configurable
@attrs.define
@@ -53,11 +56,9 @@ class DigitByDigitFloatTokenSerializer(
num_digits: int = attrs.field(default=4)
exponent_range: int = attrs.field(default=10)

-  tokens_serializer: tokens_lib.TokenSerializer[Sequence[Union[str, int]]] = (
-      attrs.field(
-          kw_only=True,
-          factory=tokens_lib.UnitSequenceTokenSerializer,
-      )
+  tokens_serializer: TokensSerializer = attrs.field(
+      kw_only=True,
+      factory=tokens_lib.UnitSequenceTokenSerializer,
)

@property
@@ -127,3 +128,95 @@ def from_str(self, s: str, /) -> float:
exp = int(''.join(tokens[-1]).lstrip('E'))

return float(sign * mantissa * 10**exp)
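
The reconstruction above can be traced by hand. A minimal sketch of one round trip through DigitByDigitFloatTokenSerializer's token layout, assuming num_digits=4 and that the collapsed lines above build `mantissa` from the digit tokens (the token values are illustrative, not taken from the diff):

tokens = ['+', 1, 2, 3, 4, 'E-1']                # sign, digits, exponent token

sign = -1 if tokens[0] == '-' else 1             # +1
mantissa = int(''.join(map(str, tokens[1:-1])))  # 1234 (assumed construction)
exp = int(str(tokens[-1]).lstrip('E'))           # -1
value = float(sign * mantissa * 10**exp)         # 123.4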


@attrs.define(kw_only=True)
class IEEEFloatTokenSerializer(
tokens_lib.CartesianProductTokenSerializer[float]
):
"""More official float serializer, minimizing the use of dedicated tokens.
Follows IEEE-type standard.
A float f = `s * b^e * m` can be represented as [s, e, m] from most to least
important, where:
s: Positive/Negative sign (+, -)
b: Base
e: Exponent (left-most is a sign, digits represented with base b)
m: Mantissa (represented with base b)
For example, 1.23456789e-222 can be represented as:
<+><-><2><2><2><1><2><3><4>
if b=10, num_exponent_digits=3, and num_mantissa_digits=4.
"""

base: int = attrs.field(default=10)

num_exponent_digits: int = attrs.field(default=1)
num_mantissa_digits: int = attrs.field(default=4)

tokens_serializer: TokensSerializer = attrs.field(
factory=tokens_lib.UnitSequenceTokenSerializer,
)

@property
def num_tokens_per_obj(self) -> int:
return 2 + self.num_exponent_digits + self.num_mantissa_digits

def tokens_used(self, index: int) -> ordered_set.OrderedSet[str]:
if index < 0 or index >= self.num_tokens_per_obj:
raise ValueError(f'Index {index} out of bounds.')

if index in [0, 1]: # beginning
tokens = [self.tokens_serializer.to_str([s]) for s in ['+', '-']]
else: # middle (digit)
tokens = [self.tokens_serializer.to_str([s]) for s in range(self.base)]
return ordered_set.OrderedSet(tokens)

def to_str(self, f: float, /) -> str:
sign = '+' if f >= 0 else '-'
abs_f = abs(f)
exponent = math.floor(np.log(abs_f) / np.log(self.base)) if abs_f > 0 else 0

exponent_sign = '+' if exponent >= 0 else '-'
abs_exponent = abs(exponent)

e = np.base_repr(abs_exponent, base=self.base)
if len(e) > self.num_exponent_digits: # Overflow, raise error for now.
# TODO: Should we round or add 'inf' token?
raise ValueError(f'Exponent {e} too large.')
e = e.zfill(self.num_exponent_digits)

mantissa = np.base_repr(
abs_f * self.base ** (self.num_mantissa_digits - 1 - exponent),
base=self.base,
)

if len(mantissa) > self.num_mantissa_digits:
mantissa = mantissa[: self.num_mantissa_digits]
if len(mantissa) < self.num_mantissa_digits: # Right-pad with zeros.
mantissa += '0' * (self.num_mantissa_digits - len(mantissa))

raw_str = sign + exponent_sign + e + mantissa
return self.tokens_serializer.to_str(list(raw_str))

def from_str(self, s: str, /) -> float:
tokens = self.tokens_serializer.from_str(s)

sign = -1 if tokens[0] == '-' else 1

exponent_sign = -1 if tokens[1] == '-' else 1

abs_exponent_str = ''.join(
map(str, tokens[2 : 2 + self.num_exponent_digits])
)
abs_exponent = int(abs_exponent_str, base=self.base)
exponent = exponent_sign * abs_exponent

mantissa_str = ''.join(map(str, tokens[2 + self.num_exponent_digits :]))
mantissa_unscaled = int(mantissa_str, base=self.base)
mantissa = mantissa_unscaled / self.base ** (self.num_mantissa_digits - 1)

return sign * (self.base**exponent) * mantissa
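
The sign/exponent/mantissa decomposition that to_str performs can be reproduced outside the class. A minimal standalone sketch with the base-10 defaults; ieee_style_tokens is an illustrative helper (not part of the library), and the int() truncation mirrors the 12345 -> 12340 behavior exercised in the tests below:

import math

import numpy as np


def ieee_style_tokens(f, base=10, num_exponent_digits=1, num_mantissa_digits=4):
  """Illustrative re-derivation of the [sign, exp-sign, exponent, mantissa] layout."""
  sign = '+' if f >= 0 else '-'
  abs_f = abs(f)
  # Exponent of the leading digit (0 when f == 0).
  exponent = math.floor(math.log(abs_f, base)) if abs_f > 0 else 0
  exponent_sign = '+' if exponent >= 0 else '-'
  e = np.base_repr(abs(exponent), base=base).zfill(num_exponent_digits)
  # Shift so num_mantissa_digits digits land left of the point, then truncate.
  shifted = int(abs_f * base ** (num_mantissa_digits - 1 - exponent))
  m = np.base_repr(shifted, base=base)[:num_mantissa_digits].ljust(num_mantissa_digits, '0')
  return [sign, exponent_sign] + list(e) + list(m)


ieee_style_tokens(123.4)    # ['+', '+', '2', '1', '2', '3', '4']
ieee_style_tokens(-12345)   # ['-', '+', '4', '1', '2', '3', '4']
ieee_style_tokens(0.0)      # ['+', '+', '0', '0', '0', '0', '0']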
36 changes: 35 additions & 1 deletion optformer/common/serialization/numeric/tokens_test.py
@@ -88,10 +88,44 @@ def test_all_tokens_used(self):
out = serializer.all_tokens_used()

signs = ['<+>', '<->']
-    digits = [f'<{i}>' for i in range(0, 10)]
+    digits = [f'<{i}>' for i in range(10)]
exponents = ['<E-2>', '<E-1>', '<E0>', '<E1>', '<E2>']
self.assertEqual(list(out), signs + digits + exponents)


class IEEEFloatTokenSerializerTest(parameterized.TestCase):

@parameterized.parameters(
(123.4, '<+><+><2><1><2><3><4>', 123.4),
(12345, '<+><+><4><1><2><3><4>', 12340),
(0.1234, '<+><-><1><1><2><3><4>', 0.1234),
(-123.4, '<-><+><2><1><2><3><4>', -123.4),
(-12345, '<-><+><4><1><2><3><4>', -12340),
(-0.1234, '<-><-><1><1><2><3><4>', -0.1234),
(1.234e-9, '<+><-><9><1><2><3><4>', 1.234e-9),
(-1.234e-9, '<-><-><9><1><2><3><4>', -1.234e-9),
(1.2e-9, '<+><-><9><1><2><0><0>', 1.2e-9),
(-1.2e-9, '<-><-><9><1><2><0><0>', -1.2e-9),
(1.2e9, '<+><+><9><1><2><0><0>', 1.2e9),
(-1.2e9, '<-><+><9><1><2><0><0>', -1.2e9),
(1.2345e9, '<+><+><9><1><2><3><4>', 1.234e9),
(0.0, '<+><+><0><0><0><0><0>', 0.0),
(-0.0, '<+><+><0><0><0><0><0>', 0.0), # in python, 0.0 == -0.0
)
def test_serialize(self, f: float, serialized: str, deserialized: float):
serializer = tokens.IEEEFloatTokenSerializer()
self.assertEqual(serializer.to_str(f), serialized)
self.assertAlmostEqual(serializer.from_str(serialized), deserialized)

@parameterized.parameters((3,), (10,), (18,))
def test_all_tokens_used(self, base: int):
serializer = tokens.IEEEFloatTokenSerializer(base=base)
out = serializer.all_tokens_used()

signs = ['<+>', '<->']
digits = [f'<{i}>' for i in range(base)]
self.assertEqual(list(out), signs + digits)


if __name__ == '__main__':
absltest.main()
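
A minimal usage sketch based on the parameterized cases above; the import path follows the __init__.py change in this commit, and the round-trip value is approximate because only num_mantissa_digits significant digits survive serialization:

from optformer.common.serialization.numeric import IEEEFloatTokenSerializer

serializer = IEEEFloatTokenSerializer()  # defaults: base=10, 1 exponent digit, 4 mantissa digits

s = serializer.to_str(12345)   # '<+><+><4><1><2><3><4>'
f = serializer.from_str(s)     # ~12340.0: the fifth significant digit is dropped

# Exponents needing more than num_exponent_digits digits raise ValueError,
# e.g. serializer.to_str(1e15) with the defaults.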
