
Commit ac126ff
Internal change
PiperOrigin-RevId: 719493290
xingyousong authored and copybara-github committed Jan 31, 2025
1 parent e98c8fe commit ac126ff
Showing 2 changed files with 25 additions and 23 deletions.
40 changes: 20 additions & 20 deletions optformer/decoding_regression/models.py
@@ -110,6 +110,8 @@ def decode(
       self,
       encoder_input: jt.Float[jt.Array, 'B F'],
       temperature: float = 1.0,
+      top_k: int | None = None,
+      top_p: float | None = None,
   ) -> jt.Float[jt.Array, 'B']:
     """Performs temperature sampling."""
     # pylint: disable=invalid-name
@@ -122,6 +124,24 @@ def decode(
       # Logit restriction
       current_logits[:, ~self._vocab.logit_mask(i)] = NEG_INF

+      # Apply top-k and top-p filtering.
+      if top_k is not None:
+        top_k = min(top_k, current_logits.shape[-1])  # Prevent index errors.
+        thresholds = np.sort(current_logits, axis=-1)[:, -top_k][:, np.newaxis]
+        current_logits = np.where(
+            current_logits < thresholds, NEG_INF, current_logits
+        )
+
+      if top_p is not None:
+        probs = sp.special.softmax(current_logits / temperature, axis=-1)
+        sorted_probs = np.sort(probs, axis=-1)[:, ::-1]  # Descending.
+        cumulative_probs = np.cumsum(sorted_probs, axis=-1)
+        cutoff_indices = np.argmax(cumulative_probs > top_p, axis=-1)
+        for batch_idx in range(B):
+          sorted_indices = np.argsort(current_logits[batch_idx, :])  # Ascending.
+          num_keep = cutoff_indices[batch_idx] + 1  # Smallest set with mass > top_p.
+          current_logits[batch_idx, sorted_indices[:-num_keep]] = NEG_INF
+
       # [B, V]
       probs = sp.special.softmax(current_logits / temperature, axis=-1)

@@ -131,26 +151,6 @@ def decode(

     return np.array([self._vocab.from_int(toks) for toks in token_ids])

-  def greedy_decode(
-      self, encoder_input: jt.Float[jt.Array, 'B F']
-  ) -> jt.Float[jt.Array, 'B']:
-    """Performs greedy decoding."""
-    # pylint: disable=invalid-name
-    B = encoder_input.shape[0]
-    token_ids = -1 * np.ones((B, self._vocab.token_length), dtype=int)
-
-    for i in range(self._vocab.token_length):
-      logits = self.predict([encoder_input, token_ids[:, :i]])
-      current_logits = logits[:, -1, :]
-      # Logit restriction
-      current_logits[:, ~self._vocab.logit_mask(i)] = NEG_INF
-
-      # Pick argmax instead.
-      sampled_ids = np.argmax(current_logits, axis=-1)  # [B]
-      token_ids[:, i] = np.array(sampled_ids)
-
-    return np.array([self._vocab.from_int(toks) for toks in token_ids])
-
 
 def weighted_sparse_categorical_crossentropy(labels, logits, weights=None):
   """Weighted version of sparse categorical cross entropy."""
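Aside: the filtering that the new top_k and top_p arguments implement can be sketched standalone. Below is a minimal NumPy/SciPy illustration of the same masking step, not code from the repository; the function name filter_logits, the NEG_INF stand-in, and the toy logits are assumptions made for the example.

import numpy as np
from scipy import special

NEG_INF = -1e9  # Stand-in for the module-level constant in models.py.


def filter_logits(logits, temperature=1.0, top_k=None, top_p=None):
  """Masks logits outside the top-k set and/or the top-p nucleus."""
  logits = np.array(logits, dtype=float)  # Copy; leave the caller's array intact.
  if top_k is not None:
    top_k = min(top_k, logits.shape[-1])
    # Per-row threshold: the k-th largest logit; everything below it is masked.
    thresholds = np.sort(logits, axis=-1)[:, -top_k][:, np.newaxis]
    logits = np.where(logits < thresholds, NEG_INF, logits)
  if top_p is not None:
    probs = special.softmax(logits / temperature, axis=-1)
    cumulative = np.cumsum(np.sort(probs, axis=-1)[:, ::-1], axis=-1)
    # First position where the running mass exceeds top_p: keep that many
    # tokens plus one per row, mask the rest.
    cutoffs = np.argmax(cumulative > top_p, axis=-1)
    for b in range(logits.shape[0]):
      ascending = np.argsort(logits[b])
      logits[b, ascending[: -(cutoffs[b] + 1)]] = NEG_INF
  return logits


print(filter_logits(np.log([[0.5, 0.25, 0.15, 0.1]]), top_p=0.6))

With top_p=0.6 the cumulative mass over the sorted probabilities is [0.5, 0.75, 0.9, 1.0], which first exceeds 0.6 at the second token, so exactly the 0.5 and 0.25 tokens survive and the rest are driven to NEG_INF.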
8 changes: 5 additions & 3 deletions optformer/decoding_regression/models_test.py
@@ -18,13 +18,15 @@
 from optformer.decoding_regression import vocabs
 import tensorflow as tf
 from absl.testing import absltest
+from absl.testing import parameterized
 
 keras = tf.keras
 
 
-class ModelTest(absltest.TestCase):
+class ModelTest(parameterized.TestCase):
 
-  def test_e2e(self):
+  @parameterized.parameters((None, None), (5, None), (None, 0.5), (3, 0.1))
+  def test_e2e(self, top_k, top_p):
     # pylint: disable=invalid-name
     encoder = tf.keras.models.Sequential([])
     vocab = vocabs.UnnormalizedVocab()
@@ -54,7 +56,7 @@ def test_e2e(self):
         validation_split=0.2,
     )
 
-    floats = decoder.decode(X[:10])
+    floats = decoder.decode(X[:10], top_k=top_k, top_p=top_p)
     self.assertLen(floats, 10)
 
 
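For downstream callers the two knobs compose in a single call. A hypothetical usage sketch, assuming a trained decoder and feature matrix X as in the test above:

# Nucleus sampling at temperature 0.7, restricted to the 50 highest logits.
samples = decoder.decode(X[:10], temperature=0.7, top_k=50, top_p=0.9)

# With greedy_decode removed, top_k=1 recovers greedy behavior: masking all
# but the largest logit collapses sampling onto the argmax token.
greedy = decoder.decode(X[:10], top_k=1)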
