Skip to content

Commit d7d87e3

Browse files
authored
Fix KL when num_classes != 2 (#1820)
Fix to [issue 1724](#1724)
1 parent 6685fb8 commit d7d87e3

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

adversarial_text/adversarial_losses.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
212212

213213
# For softmax regression
214214
else:
215+
q = tf.nn.softmax(q_logits)
215216
kl = tf.reduce_sum(
216217
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
217218

0 commit comments

Comments
 (0)