Skip to content

Commit b47ce2c

Browse files
stheerthafedjax authors
authored andcommitted
Modify cross entropy loss to allow dense targets.
PiperOrigin-RevId: 490251703
1 parent aeab691 commit b47ce2c

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

fedjax/core/metrics.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,12 +263,15 @@ def evaluate_batch(metric: Metric,
263263

264264

265265
def unreduced_cross_entropy_loss(targets: jnp.ndarray,
266-
preds: jnp.ndarray) -> jnp.ndarray:
266+
preds: jnp.ndarray,
267+
is_sparse_targets: bool = True) -> jnp.ndarray:
267268
"""Returns unreduced cross entropy loss."""
268-
num_classes = preds.shape[-1]
269269
log_preds = jax.nn.log_softmax(preds)
270-
one_hot_targets = jax.nn.one_hot(targets, num_classes)
271-
return -jnp.sum(one_hot_targets * log_preds, axis=-1)
270+
if is_sparse_targets:
271+
# If targets is sparse, convert to one hot representation.
272+
num_classes = preds.shape[-1]
273+
targets = jax.nn.one_hot(targets, num_classes)
274+
return -jnp.sum(targets * log_preds, axis=-1)
272275

273276

274277
@dataclasses.dataclass

0 commit comments

Comments
 (0)