File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed
Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -263,12 +263,15 @@ def evaluate_batch(metric: Metric,
263263
264264
265265def unreduced_cross_entropy_loss (targets : jnp .ndarray ,
266- preds : jnp .ndarray ) -> jnp .ndarray :
266+ preds : jnp .ndarray ,
267+ is_sparse_targets : bool = True ) -> jnp .ndarray :
267268 """Returns unreduced cross entropy loss."""
268- num_classes = preds .shape [- 1 ]
269269 log_preds = jax .nn .log_softmax (preds )
270- one_hot_targets = jax .nn .one_hot (targets , num_classes )
271- return - jnp .sum (one_hot_targets * log_preds , axis = - 1 )
270+ if is_sparse_targets :
271+ # If targets is sparse, convert to one hot representation.
272+ num_classes = preds .shape [- 1 ]
273+ targets = jax .nn .one_hot (targets , num_classes )
274+ return - jnp .sum (targets * log_preds , axis = - 1 )
272275
273276
274277@dataclasses .dataclass
You can’t perform that action at this time.
0 commit comments