Skip to content

Commit 6c9409a

Browse files
xingyousong authored and copybara-github committed
Add more vocabs (Hamming and RepeatingErrorCorrection)
PiperOrigin-RevId: 718413874
1 parent cc6ec66 commit 6c9409a

File tree

1 file changed

+101
-1
lines changed

1 file changed

+101
-1
lines changed

optformer/decoding_regression/vocabs.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Vocabularies for encoding/decoding floats."""
1616

1717
import abc
18+
import itertools
1819
import re
1920

2021
import attrs
@@ -98,7 +99,7 @@ def from_int(self, token_ids: list[int]) -> float:
9899

99100

100101
@attrs.define
101-
class NormalizedVocab:
102+
class NormalizedVocab(FloatVocab):
102103
"""Vocab which supports only numbers within [0,1]."""
103104

104105
base: int = attrs.field(default=2)
@@ -135,3 +136,102 @@ def from_int(self, token_ids: list[int]) -> float:
135136
x = np.asarray(token_ids)
136137
coeff = np.power(self.base, -1 * np.arange(1, len(x) + 1), dtype=np.float32)
137138
return float(np.sum(x * coeff))
139+
140+
141+
@attrs.define
class HammingDistanceVocab(FloatVocab):
  """Based on https://ieeexplore.ieee.org/document/8437644.

  Orders the binary codewords by popcount so that corrupting a single token
  perturbs the decoded float's magnitude as little as possible. For now only
  works on base=2.
  """

  base: int = attrs.field(default=2)
  length: int = attrs.field(default=1)

  # All base**length codewords, ordered by (number of ones, lexicographic).
  all_binary_sequences = attrs.field(init=False)

  def __attrs_post_init__(self):
    if self.base != 2:
      raise ValueError('Only base=2 is supported.')

    # Enumerate every binary tuple of the configured length, then order them
    # so that codewords with fewer ones come first (ties broken
    # lexicographically) — identical to ranking by Hamming weight.
    codewords = itertools.product(range(self.base), repeat=self.length)
    self.all_binary_sequences = sorted(
        codewords, key=lambda seq: (sum(seq), seq)
    )

  @property
  def size(self) -> int:
    return self.base

  @property
  def token_length(self) -> int:
    return self.length

  def logit_mask(self, index: int):
    # Every token id is valid at every position.
    del index
    return np.ones(self.size, dtype=bool)

  def to_int(self, f: float) -> list[int]:
    if not 0 <= f <= 1:
      raise ValueError(f'f must be between 0 and 1, got {f}')

    num_codewords = self.base**self.length
    bucket = int(f * num_codewords)
    # f == 1.0 would land one past the end; clamp it onto the last codeword.
    bucket = min(bucket, num_codewords - 1)
    return list(self.all_binary_sequences[bucket])

  def from_int(self, token_ids: list[int]) -> float:
    if len(token_ids) != self.length:
      raise ValueError(f'Length {len(token_ids)} does not match {self.length}.')
    if not all(0 <= t < self.base for t in token_ids):
      raise ValueError(f'{token_ids} out of range(0, {self.base})')

    # Rank of the codeword in Hamming-weight order determines the float.
    rank = self.all_binary_sequences.index(tuple(token_ids))
    return float(rank) / (self.base**self.length)
196+
197+
198+
@attrs.define
class RepeatingVocab(FloatVocab):
  """Error-correcting wrapper over another vocab via repetition.

  Encoding concatenates `num_repeats` copies of the base vocab's tokens;
  decoding recovers each base-token position by per-position majority vote.
  """

  base_vocab: FloatVocab = attrs.field()
  num_repeats: int = attrs.field(default=1)

  @property
  def size(self) -> int:
    return self.base_vocab.size

  @property
  def token_length(self) -> int:
    return self.base_vocab.token_length * self.num_repeats

  def logit_mask(self, index: int):
    # Position `index` behaves like the matching position within one copy.
    return self.base_vocab.logit_mask(index % self.base_vocab.token_length)

  def to_int(self, f: float) -> list[int]:
    return self.base_vocab.to_int(f) * self.num_repeats

  def from_int(self, token_ids: list[int]) -> float:
    if len(token_ids) != self.token_length:
      raise ValueError(
          f'Expected {self.token_length} tokens, got {len(token_ids)}'
      )

    # One row per repeat; column j holds all candidates for base token j.
    grid = np.array(token_ids).reshape(
        self.num_repeats, self.base_vocab.token_length
    )

    # Majority vote per column; np.bincount/argmax breaks ties toward the
    # smaller token id, matching the original decode behavior.
    voted = [int(np.bincount(column).argmax()) for column in grid.T]

    # Delegate the final tokens-to-float conversion to the base vocabulary.
    return self.base_vocab.from_int(voted)

0 commit comments

Comments
 (0)