From ac126ffa2e3f2be69842243d7301177776a56eb2 Mon Sep 17 00:00:00 2001
From: Xingyou Song
Date: Fri, 24 Jan 2025 17:56:54 -0800
Subject: [PATCH] Internal change

PiperOrigin-RevId: 719493290
---
 optformer/decoding_regression/models.py      | 40 ++++++++++----------
 optformer/decoding_regression/models_test.py |  8 ++--
 2 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/optformer/decoding_regression/models.py b/optformer/decoding_regression/models.py
index 1459b10..4da3826 100644
--- a/optformer/decoding_regression/models.py
+++ b/optformer/decoding_regression/models.py
@@ -110,6 +110,8 @@ def decode(
       self,
       encoder_input: jt.Float[jt.Array, 'B F'],
       temperature: float = 1.0,
+      top_k: int | None = None,
+      top_p: float | None = None,
   ) -> jt.Float[jt.Array, 'B']:
     """Performs temperature sampling."""
     # pylint: disable=invalid-name
@@ -122,6 +124,24 @@ def decode(
       # Logit restriction
       current_logits[:, ~self._vocab.logit_mask(i)] = NEG_INF
 
+      # Apply top-k and top-p filtering.
+      if top_k is not None:
+        top_k = min(top_k, current_logits.shape[-1])  # Prevent index errors.
+        thresholds = np.sort(current_logits, axis=-1)[:, -top_k][:, np.newaxis]
+        current_logits = np.where(
+            current_logits < thresholds, NEG_INF, current_logits
+        )
+
+      if top_p is not None:
+        probs = sp.special.softmax(current_logits / temperature, axis=-1)
+        sorted_probs = np.sort(probs, axis=-1)[:, ::-1]
+        cumulative_probs = np.cumsum(sorted_probs, axis=-1)
+        cutoff_indices = np.argmax(cumulative_probs > top_p, axis=-1)
+        for batch_idx in range(B):
+          sorted_indices = np.argsort(current_logits[batch_idx, :])
+          removed_indices = sorted_indices[: -(cutoff_indices[batch_idx] + 1)]
+          current_logits[batch_idx, removed_indices] = NEG_INF
+
       # [B, V]
       probs = sp.special.softmax(current_logits / temperature, axis=-1)
 
@@ -131,26 +151,6 @@ def decode(
 
     return np.array([self._vocab.from_int(toks) for toks in token_ids])
 
-  def greedy_decode(
-      self, encoder_input: jt.Float[jt.Array, 'B F']
-  ) -> jt.Float[jt.Array, 'B']:
-    """Performs greedy decoding."""
-    # pylint: disable=invalid-name
-    B = encoder_input.shape[0]
-    token_ids = -1 * np.ones((B, self._vocab.token_length), dtype=int)
-
-    for i in range(self._vocab.token_length):
-      logits = self.predict([encoder_input, token_ids[:, :i]])
-      current_logits = logits[:, -1, :]
-      # Logit restriction
-      current_logits[:, ~self._vocab.logit_mask(i)] = NEG_INF
-
-      # Pick argmax instead.
-      sampled_ids = np.argmax(current_logits, axis=-1)  # [B]
-      token_ids[:, i] = np.array(sampled_ids)
-
-    return np.array([self._vocab.from_int(toks) for toks in token_ids])
-
 
 def weighted_sparse_categorical_crossentropy(labels, logits, weights=None):
   """Weighted version of sparse categorical cross entropy."""
diff --git a/optformer/decoding_regression/models_test.py b/optformer/decoding_regression/models_test.py
index c23e4ed..c221727 100644
--- a/optformer/decoding_regression/models_test.py
+++ b/optformer/decoding_regression/models_test.py
@@ -18,13 +18,15 @@
 from optformer.decoding_regression import vocabs
 import tensorflow as tf
 from absl.testing import absltest
+from absl.testing import parameterized
 
 keras = tf.keras
 
 
-class ModelTest(absltest.TestCase):
+class ModelTest(parameterized.TestCase):
 
-  def test_e2e(self):
+  @parameterized.parameters((None, None), (5, None), (None, 0.5), (3, 0.1))
+  def test_e2e(self, top_k, top_p):
     # pylint: disable=invalid-name
     encoder = tf.keras.models.Sequential([])
     vocab = vocabs.UnnormalizedVocab()
@@ -54,7 +56,7 @@ def test_e2e(self):
         validation_split=0.2,
     )
 
-    floats = decoder.decode(X[:10])
+    floats = decoder.decode(X[:10], top_k=top_k, top_p=top_p)
     self.assertLen(floats, 10)
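
Reviewer note: the sketch below restates the filtering logic this patch adds to decode, outside the model class, so the slicing can be checked in isolation. It is illustrative rather than repository code: the function names apply_top_k and apply_top_p, the NEG_INF constant, and the toy logits are assumptions made for this note, and it assumes 0 < top_p < 1.

import numpy as np
from scipy import special

NEG_INF = -1e7  # Stand-in for the masking constant in models.py.


def apply_top_k(logits, top_k):
  """Keeps the top_k largest logits per row and masks the rest."""
  top_k = min(top_k, logits.shape[-1])  # Prevent index errors.
  # Per-row threshold: the k-th largest logit.
  thresholds = np.sort(logits, axis=-1)[:, -top_k][:, np.newaxis]
  return np.where(logits < thresholds, NEG_INF, logits)


def apply_top_p(logits, top_p, temperature=1.0):
  """Nucleus filtering, assuming 0 < top_p < 1.

  Keeps the smallest prefix of probability-sorted tokens whose cumulative
  mass exceeds top_p, including the token that crosses the threshold.
  """
  probs = special.softmax(logits / temperature, axis=-1)
  sorted_probs = np.sort(probs, axis=-1)[:, ::-1]  # Descending.
  cumulative = np.cumsum(sorted_probs, axis=-1)
  # First position where the cumulative mass passes top_p.
  cutoff = np.argmax(cumulative > top_p, axis=-1)
  out = logits.copy()
  for b in range(logits.shape[0]):
    ascending = np.argsort(logits[b])  # Lowest logits first.
    num_kept = cutoff[b] + 1           # Tokens 0..cutoff survive.
    out[b, ascending[:-num_kept]] = NEG_INF
  return out


logits = np.log(np.array([[0.5, 0.3, 0.1, 0.1],
                          [0.25, 0.25, 0.25, 0.25]]))
print(apply_top_k(logits, top_k=2))    # At least the 2 largest per row survive (ties kept).
print(apply_top_p(logits, top_p=0.7))  # Row 1 keeps 2 tokens, row 2 keeps 3.

The cutoff[b] + 1 is the subtle part: cutoff is the index of the first token whose cumulative probability exceeds top_p, so that token itself must be kept. Slicing with ascending[:-cutoff[b]] instead would drop the crossing token, and for a row whose single best token already exceeds top_p (cutoff == 0) the slice would be empty and mask nothing, rather than keeping only that token.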