Modify cross entropy loss to allow dense targets.
PiperOrigin-RevId: 490251703
stheertha authored and fedjax authors committed Nov 22, 2022
1 parent aeab691 commit b47ce2c
Showing 1 changed file with 7 additions and 4 deletions: fedjax/core/metrics.py
@@ -263,12 +263,15 @@ def evaluate_batch(metric: Metric,
 
 
 def unreduced_cross_entropy_loss(targets: jnp.ndarray,
-                                 preds: jnp.ndarray) -> jnp.ndarray:
+                                 preds: jnp.ndarray,
+                                 is_sparse_targets: bool = True) -> jnp.ndarray:
   """Returns unreduced cross entropy loss."""
-  num_classes = preds.shape[-1]
   log_preds = jax.nn.log_softmax(preds)
-  one_hot_targets = jax.nn.one_hot(targets, num_classes)
-  return -jnp.sum(one_hot_targets * log_preds, axis=-1)
+  if is_sparse_targets:
+    # If targets is sparse, convert to one hot representation.
+    num_classes = preds.shape[-1]
+    targets = jax.nn.one_hot(targets, num_classes)
+  return -jnp.sum(targets * log_preds, axis=-1)
 
 
 @dataclasses.dataclass
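
For context, here is a minimal runnable sketch of the change, with the updated function reproduced from the diff above; the example logits, labels, and shapes are illustrative assumptions, not part of the commit:

import jax
import jax.numpy as jnp

# Updated function, copied from the diff above so this sketch runs standalone.
def unreduced_cross_entropy_loss(targets: jnp.ndarray,
                                 preds: jnp.ndarray,
                                 is_sparse_targets: bool = True) -> jnp.ndarray:
  """Returns unreduced cross entropy loss."""
  log_preds = jax.nn.log_softmax(preds)
  if is_sparse_targets:
    # If targets is sparse, convert to one hot representation.
    num_classes = preds.shape[-1]
    targets = jax.nn.one_hot(targets, num_classes)
  return -jnp.sum(targets * log_preds, axis=-1)

# Illustrative logits for a batch of 2 examples over 3 classes.
preds = jnp.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])

# Sparse integer class labels: the pre-existing default path.
sparse_targets = jnp.array([0, 2])
loss_sparse = unreduced_cross_entropy_loss(sparse_targets, preds)

# Dense targets (one-hot or soft label distributions): the new path.
dense_targets = jnp.array([[1.0, 0.0, 0.0],
                           [0.0, 0.1, 0.9]])
loss_dense = unreduced_cross_entropy_loss(dense_targets, preds,
                                          is_sparse_targets=False)

print(loss_sparse.shape, loss_dense.shape)  # both (2,): one loss per example

With is_sparse_targets=False the targets array is multiplied directly against the log-probabilities, so one-hot or other dense label distributions can be passed without an explicit conversion step.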
