Skip to content

Commit

Permalink
Internal only
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715646813
  • Loading branch information
xingyousong authored and copybara-github committed Jan 31, 2025
1 parent c2825a8 commit cc6ec66
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions optformer/decoding_regression/vocabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,30 +97,41 @@ def from_int(self, token_ids: list[int]) -> float:
return self.serializer.from_str(str_f)


# TODO: Add support for more bases.
@attrs.define
class NormalizedVocab(FloatVocab):
class NormalizedVocab:
"""Vocab which supports only numbers within [0,1]."""

num_bits: int = attrs.field(default=1)
base: int = attrs.field(default=2)
length: int = attrs.field(default=1)

@property
def size(self) -> int:
return 2
return self.base

@property
def token_length(self) -> int:
return self.num_bits
return self.length

def logit_mask(self, index: int):
del index
return np.ones(self.size, dtype=bool)

def to_int(self, f: float) -> list[int]:
assert f >= 0 and f <= 1
f_trunc = int(f * 2.0**self.num_bits)
f_trunc = bin(f_trunc)[2:].zfill(self.num_bits)
if not 0 <= f <= 1:
raise ValueError(f'f must be between 0 and 1, got {f}')

f_trunc = int(f * self.base**self.length)
if f_trunc == self.base**self.length:
f_trunc -= 1 # Adjust for the edge case when f is exactly 1
f_trunc = np.base_repr(f_trunc, base=self.base).zfill(self.length)
return [int(b) for b in f_trunc]

def from_int(self, token_ids: list[int]) -> float:
if len(token_ids) != self.length:
raise ValueError(f'Length {len(token_ids)} does not match {self.length}.')
if not all(0 <= tid < self.base for tid in token_ids):
raise ValueError(f'{token_ids} out of range(0, {self.base})')

x = np.asarray(token_ids)
return np.sum(x * 2.0 ** (-1 * np.arange(1, len(x) + 1)))
coeff = np.power(self.base, -1 * np.arange(1, len(x) + 1), dtype=np.float32)
return float(np.sum(x * coeff))

0 comments on commit cc6ec66

Please sign in to comment.