|
15 | 15 | """Vocabularies for encoding/decoding floats."""
|
16 | 16 |
|
17 | 17 | import abc
|
| 18 | +import itertools |
18 | 19 | import re
|
19 | 20 |
|
20 | 21 | import attrs
|
@@ -98,7 +99,7 @@ def from_int(self, token_ids: list[int]) -> float:
|
98 | 99 |
|
99 | 100 |
|
100 | 101 | @attrs.define
|
101 |
| -class NormalizedVocab: |
| 102 | +class NormalizedVocab(FloatVocab): |
102 | 103 | """Vocab which supports only numbers within [0,1]."""
|
103 | 104 |
|
104 | 105 | base: int = attrs.field(default=2)
|
@@ -135,3 +136,102 @@ def from_int(self, token_ids: list[int]) -> float:
|
135 | 136 | x = np.asarray(token_ids)
|
136 | 137 | coeff = np.power(self.base, -1 * np.arange(1, len(x) + 1), dtype=np.float32)
|
137 | 138 | return float(np.sum(x * coeff))
|
| 139 | + |
| 140 | + |
@attrs.define
class HammingDistanceVocab(FloatVocab):
  """Based on https://ieeexplore.ieee.org/document/8437644.

  Minimizes magnitude change on the float if the tokenization is corrupted. For
  now only works on base=2.
  """

  base: int = attrs.field(default=2)
  length: int = attrs.field(default=1)

  # All 2**length binary sequences, ordered by (Hamming weight, lexicographic)
  # so that a single corrupted bit changes the decoded index as little as
  # possible. Built in __attrs_post_init__.
  all_binary_sequences = attrs.field(init=False)
  # Reverse lookup (sequence tuple -> index) so from_int decodes in O(1)
  # instead of an O(2**length) list.index scan.
  _seq_to_index = attrs.field(init=False)

  def __attrs_post_init__(self):
    if self.base != 2:
      raise ValueError('Only base=2 is supported.')

    self.all_binary_sequences = []
    for num_ones in range(self.length + 1):
      for one_positions in itertools.combinations(range(self.length), num_ones):
        binary_seq = [0] * self.length
        for pos in one_positions:
          binary_seq[pos] = 1
        self.all_binary_sequences.append(tuple(binary_seq))
    # Group by Hamming weight first, then lexicographically within a weight.
    self.all_binary_sequences.sort(key=lambda seq: (sum(seq), seq))
    self._seq_to_index = {
        seq: i for i, seq in enumerate(self.all_binary_sequences)
    }

  @property
  def size(self) -> int:
    """Number of distinct token ids (one per digit value)."""
    return self.base

  @property
  def token_length(self) -> int:
    """Number of tokens used to encode one float."""
    return self.length

  def logit_mask(self, index: int):
    """Every token id is valid at every position."""
    del index
    return np.ones(self.size, dtype=bool)

  def to_int(self, f: float) -> list[int]:
    """Encodes a float in [0, 1] as a weight-ordered binary sequence.

    Args:
      f: The float to encode; must lie in [0, 1].

    Returns:
      A list of `length` token ids in range(base).

    Raises:
      ValueError: If `f` is outside [0, 1].
    """
    if not 0 <= f <= 1:
      raise ValueError(f'f must be between 0 and 1, got {f}')

    f_int = int(f * self.base**self.length)
    if f_int == self.base**self.length:
      f_int -= 1  # Adjust for the edge case when f is exactly 1
    return list(self.all_binary_sequences[f_int])

  def from_int(self, token_ids: list[int]) -> float:
    """Decodes a token sequence back to a float in [0, 1).

    Args:
      token_ids: Exactly `length` token ids, each in range(base).

    Returns:
      The decoded float: (rank of the sequence) / base**length.

    Raises:
      ValueError: If the length or any token id is out of range.
    """
    if len(token_ids) != self.length:
      raise ValueError(f'Length {len(token_ids)} does not match {self.length}.')
    if not all(0 <= tid < self.base for tid in token_ids):
      raise ValueError(f'{token_ids} out of range(0, {self.base})')

    ind = self._seq_to_index[tuple(token_ids)]
    return float(ind) / (self.base**self.length)
| 196 | + |
| 197 | + |
@attrs.define
class RepeatingVocab(FloatVocab):
  """Performs error correction by majority voting on decoded tokens."""

  base_vocab: FloatVocab = attrs.field()
  num_repeats: int = attrs.field(default=1)

  @property
  def size(self) -> int:
    return self.base_vocab.size

  @property
  def token_length(self) -> int:
    return self.base_vocab.token_length * self.num_repeats

  def logit_mask(self, index: int):
    # Positions cycle through the underlying vocab once per repeat.
    return self.base_vocab.logit_mask(index % self.base_vocab.token_length)

  def to_int(self, f: float) -> list[int]:
    # Encode once with the underlying vocab, then repeat the whole sequence.
    return self.base_vocab.to_int(f) * self.num_repeats

  def from_int(self, token_ids: list[int]) -> float:
    if len(token_ids) != self.token_length:
      raise ValueError(
          f'Expected {self.token_length} tokens, got {len(token_ids)}'
      )

    # One row per repeat of the underlying token sequence.
    repeats = np.asarray(token_ids).reshape(
        self.num_repeats, self.base_vocab.token_length
    )

    # Majority vote per token position; bincount().argmax() breaks ties in
    # favor of the smaller token id, matching a column-wise mode.
    voted = [int(np.bincount(column).argmax()) for column in repeats.T]

    # Decode the corrected sequence with the underlying vocabulary.
    return self.base_vocab.from_int(voted)
0 commit comments