diff --git a/optformer/decoding_regression/models.py b/optformer/decoding_regression/models.py index 108f8ec..1459b10 100644 --- a/optformer/decoding_regression/models.py +++ b/optformer/decoding_regression/models.py @@ -123,7 +123,7 @@ def decode( current_logits[:, ~self._vocab.logit_mask(i)] = NEG_INF # [B, V] - probs = sp.special.softmax(temperature * current_logits, axis=-1) + probs = sp.special.softmax(current_logits / temperature, axis=-1) # Sample tokens. sampled_ids = vectorized_sample(probs)