
Commit f68a4c9

xingyousong authored and copybara-github committed
Add IEEFloatTokenizer
PiperOrigin-RevId: 692753815
1 parent d8cbeb6 commit f68a4c9

File tree: 3 files changed (+134, -6 lines)


optformer/common/serialization/numeric/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@
 from optformer.common.serialization.numeric.text import SimpleFloatTextSerializer
 from optformer.common.serialization.numeric.text import SimpleScientificFloatTextSerializer
 from optformer.common.serialization.numeric.tokens import DigitByDigitFloatTokenSerializer
+from optformer.common.serialization.numeric.tokens import IEEEFloatTokenSerializer
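With this re-export in place, the new serializer should be importable straight from the numeric package rather than from the tokens module. A minimal sketch, assuming an environment where optformer including this commit is installed:

    # Illustrative only; assumes this commit is present in the installed optformer.
    from optformer.common.serialization.numeric import IEEEFloatTokenSerializer

    serializer = IEEEFloatTokenSerializer()  # defaults: base=10, 1 exponent digit, 4 mantissa digits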

optformer/common/serialization/numeric/tokens.py

Lines changed: 98 additions & 5 deletions
@@ -14,6 +14,7 @@
 
 """General float serializers using dedicated tokens."""
 
+import math
 import re
 from typing import Sequence, Union
 
@@ -23,6 +24,8 @@
 from optformer.common.serialization import tokens as tokens_lib
 import ordered_set
 
+TokensSerializer = tokens_lib.TokenSerializer[Sequence[Union[str, int]]]
+
 
 @gin.configurable
 @attrs.define
@@ -53,11 +56,9 @@ class DigitByDigitFloatTokenSerializer(
   num_digits: int = attrs.field(default=4)
   exponent_range: int = attrs.field(default=10)
 
-  tokens_serializer: tokens_lib.TokenSerializer[Sequence[Union[str, int]]] = (
-      attrs.field(
-          kw_only=True,
-          factory=tokens_lib.UnitSequenceTokenSerializer,
-      )
+  tokens_serializer: TokensSerializer = attrs.field(
+      kw_only=True,
+      factory=tokens_lib.UnitSequenceTokenSerializer,
   )
 
   @property
@@ -127,3 +128,95 @@ def from_str(self, s: str, /) -> float:
     exp = int(''.join(tokens[-1]).lstrip('E'))
 
     return float(sign * mantissa * 10**exp)
+
+
+@attrs.define(kw_only=True)
+class IEEEFloatTokenSerializer(
+    tokens_lib.CartesianProductTokenSerializer[float]
+):
+  """More official float serializer, minimizing the use of dedicated tokens.
+
+  Follows IEEE-type standard.
+
+  A float f = `s * b^e * m` can be represented as [s, e, m] from most to least
+  important, where:
+    s: Positive/Negative sign (+, -)
+    b: Base
+    e: Exponent (left-most is a sign, digits represented with base b)
+    m: Mantissa (represented with base b)
+
+  For example, 1.23456789e-222 can be represented as:
+
+  <+><-><2><2><2><1><2><3><4>
+
+  if b=10, num_exponent_digits=3, and num_mantissa_digits=4.
+  """
+
+  base: int = attrs.field(default=10)
+
+  num_exponent_digits: int = attrs.field(default=1)
+  num_mantissa_digits: int = attrs.field(default=4)
+
+  tokens_serializer: TokensSerializer = attrs.field(
+      factory=tokens_lib.UnitSequenceTokenSerializer,
+  )
+
+  @property
+  def num_tokens_per_obj(self) -> int:
+    return 2 + self.num_exponent_digits + self.num_mantissa_digits
+
+  def tokens_used(self, index: int) -> ordered_set.OrderedSet[str]:
+    if index < 0 or index >= self.num_tokens_per_obj:
+      raise ValueError(f'Index {index} out of bounds.')
+
+    if index in [0, 1]:  # beginning
+      tokens = [self.tokens_serializer.to_str([s]) for s in ['+', '-']]
+    else:  # middle (digit)
+      tokens = [self.tokens_serializer.to_str([s]) for s in range(self.base)]
+    return ordered_set.OrderedSet(tokens)
+
+  def to_str(self, f: float, /) -> str:
+    sign = '+' if f >= 0 else '-'
+    abs_f = abs(f)
+    exponent = math.floor(np.log(abs_f) / np.log(self.base)) if abs_f > 0 else 0
+
+    exponent_sign = '+' if exponent >= 0 else '-'
+    abs_exponent = abs(exponent)
+
+    e = np.base_repr(abs_exponent, base=self.base)
+    if len(e) > self.num_exponent_digits:  # Overflow, raise error for now.
+      # TODO: Should we round or add 'inf' token?
+      raise ValueError(f'Exponent {e} too large.')
+    e = e.zfill(self.num_exponent_digits)
+
+    mantissa = np.base_repr(
+        abs_f * self.base ** (self.num_mantissa_digits - 1 - exponent),
+        base=self.base,
+    )
+
+    if len(mantissa) > self.num_mantissa_digits:
+      mantissa = mantissa[: self.num_mantissa_digits]
+    if len(mantissa) < self.num_mantissa_digits:  # Right-pad with zeros.
+      mantissa += '0' * (self.num_mantissa_digits - len(mantissa))
+
+    raw_str = sign + exponent_sign + e + mantissa
+    return self.tokens_serializer.to_str(list(raw_str))
+
+  def from_str(self, s: str, /) -> float:
+    tokens = self.tokens_serializer.from_str(s)
+
+    sign = -1 if tokens[0] == '-' else 1
+
+    exponent_sign = -1 if tokens[1] == '-' else 1
+
+    abs_exponent_str = ''.join(
+        map(str, tokens[2 : 2 + self.num_exponent_digits])
+    )
+    abs_exponent = int(abs_exponent_str, base=self.base)
+    exponent = exponent_sign * abs_exponent
+
+    mantissa_str = ''.join(map(str, tokens[2 + self.num_exponent_digits :]))
+    mantissa_unscaled = int(mantissa_str, base=self.base)
+    mantissa = mantissa_unscaled / self.base ** (self.num_mantissa_digits - 1)
+
+    return sign * (self.base**exponent) * mantissa
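To make the docstring's layout concrete with the defaults (base=10, num_exponent_digits=1, num_mantissa_digits=4): 123.4 is +1.234 * 10^(+2), so the token stream is the sign, the exponent sign, one exponent digit, then four mantissa digits. A round-trip sketch mirroring the test cases added below (hypothetical session; assumes the module is importable as optformer.common.serialization.numeric.tokens):

    # Sketch; expected values are taken from the new unit tests rather than run here.
    from optformer.common.serialization.numeric import tokens

    serializer = tokens.IEEEFloatTokenSerializer()
    encoded = serializer.to_str(123.4)
    print(encoded)                       # '<+><+><2><1><2><3><4>'
    print(serializer.from_str(encoded))  # 123.4  (1234 / 10**3 * 10**2)

Note that the mantissa is truncated rather than rounded, so 12345 serializes to <+><+><4><1><2><3><4> and decodes back to 12340.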

optformer/common/serialization/numeric/tokens_test.py

Lines changed: 35 additions & 1 deletion
@@ -88,10 +88,44 @@ def test_all_tokens_used(self):
     out = serializer.all_tokens_used()
 
     signs = ['<+>', '<->']
-    digits = [f'<{i}>' for i in range(0, 10)]
+    digits = [f'<{i}>' for i in range(10)]
     exponents = ['<E-2>', '<E-1>', '<E0>', '<E1>', '<E2>']
     self.assertEqual(list(out), signs + digits + exponents)
 
 
+class IEEEFloatTokenSerializerTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      (123.4, '<+><+><2><1><2><3><4>', 123.4),
+      (12345, '<+><+><4><1><2><3><4>', 12340),
+      (0.1234, '<+><-><1><1><2><3><4>', 0.1234),
+      (-123.4, '<-><+><2><1><2><3><4>', -123.4),
+      (-12345, '<-><+><4><1><2><3><4>', -12340),
+      (-0.1234, '<-><-><1><1><2><3><4>', -0.1234),
+      (1.234e-9, '<+><-><9><1><2><3><4>', 1.234e-9),
+      (-1.234e-9, '<-><-><9><1><2><3><4>', -1.234e-9),
+      (1.2e-9, '<+><-><9><1><2><0><0>', 1.2e-9),
+      (-1.2e-9, '<-><-><9><1><2><0><0>', -1.2e-9),
+      (1.2e9, '<+><+><9><1><2><0><0>', 1.2e9),
+      (-1.2e9, '<-><+><9><1><2><0><0>', -1.2e9),
+      (1.2345e9, '<+><+><9><1><2><3><4>', 1.234e9),
+      (0.0, '<+><+><0><0><0><0><0>', 0.0),
+      (-0.0, '<+><+><0><0><0><0><0>', 0.0),  # in python, 0.0 == -0.0
+  )
+  def test_serialize(self, f: float, serialized: str, deserialized: float):
+    serializer = tokens.IEEEFloatTokenSerializer()
+    self.assertEqual(serializer.to_str(f), serialized)
+    self.assertAlmostEqual(serializer.from_str(serialized), deserialized)
+
+  @parameterized.parameters((3,), (10,), (18,))
+  def test_all_tokens_used(self, base: int):
+    serializer = tokens.IEEEFloatTokenSerializer(base=base)
+    out = serializer.all_tokens_used()
+
+    signs = ['<+>', '<->']
+    digits = [f'<{i}>' for i in range(base)]
+    self.assertEqual(list(out), signs + digits)
+
+
 if __name__ == '__main__':
   absltest.main()
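A behavior implied by the overflow branch in to_str but not exercised by these tests: with the default num_exponent_digits=1, any value whose base-10 exponent needs more than one digit raises instead of saturating (per the TODO about rounding or an 'inf' token). A hedged sketch of that assumption:

    # Assumption drawn from the overflow check in to_str; not part of the committed tests.
    from optformer.common.serialization.numeric import tokens

    serializer = tokens.IEEEFloatTokenSerializer()  # num_exponent_digits=1
    try:
      serializer.to_str(1.5e12)  # exponent 12 needs two digits
    except ValueError as err:
      print(err)  # 'Exponent 12 too large.'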
