From b47ce2c7baef50d3b1cdcbc7ad2546bd0696850e Mon Sep 17 00:00:00 2001 From: Ananda Theertha Suresh Date: Tue, 22 Nov 2022 08:38:45 -0800 Subject: [PATCH] Modify cross entropy loss to allow dense targets. PiperOrigin-RevId: 490251703 --- fedjax/core/metrics.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/fedjax/core/metrics.py b/fedjax/core/metrics.py index c0bebdb..c03ca53 100644 --- a/fedjax/core/metrics.py +++ b/fedjax/core/metrics.py @@ -263,12 +263,15 @@ def evaluate_batch(metric: Metric, def unreduced_cross_entropy_loss(targets: jnp.ndarray, - preds: jnp.ndarray) -> jnp.ndarray: + preds: jnp.ndarray, + is_sparse_targets: bool = True) -> jnp.ndarray: """Returns unreduced cross entropy loss.""" - num_classes = preds.shape[-1] log_preds = jax.nn.log_softmax(preds) - one_hot_targets = jax.nn.one_hot(targets, num_classes) - return -jnp.sum(one_hot_targets * log_preds, axis=-1) + if is_sparse_targets: + # If targets is sparse, convert to one hot representation. + num_classes = preds.shape[-1] + targets = jax.nn.one_hot(targets, num_classes) + return -jnp.sum(targets * log_preds, axis=-1) @dataclasses.dataclass