File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -263,12 +263,15 @@ def evaluate_batch(metric: Metric,
263
263
264
264
265
265
def unreduced_cross_entropy_loss (targets : jnp .ndarray ,
266
- preds : jnp .ndarray ) -> jnp .ndarray :
266
+ preds : jnp .ndarray ,
267
+ is_sparse_targets : bool = True ) -> jnp .ndarray :
267
268
"""Returns unreduced cross entropy loss."""
268
- num_classes = preds .shape [- 1 ]
269
269
log_preds = jax .nn .log_softmax (preds )
270
- one_hot_targets = jax .nn .one_hot (targets , num_classes )
271
- return - jnp .sum (one_hot_targets * log_preds , axis = - 1 )
270
+ if is_sparse_targets :
271
+ # If targets is sparse, convert to one hot representation.
272
+ num_classes = preds .shape [- 1 ]
273
+ targets = jax .nn .one_hot (targets , num_classes )
274
+ return - jnp .sum (targets * log_preds , axis = - 1 )
272
275
273
276
274
277
@dataclasses .dataclass
You can’t perform that action at this time.
0 commit comments