Skip to content

Commit 3b41cfe

Browse files
run keras compatibility test with v2 behavior. (#1374)
1 parent 480a8ee commit 3b41cfe

File tree

1 file changed

+79
-53
lines changed

1 file changed

+79
-53
lines changed

tensorflow_addons/seq2seq/loss_test.py

Lines changed: 79 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,37 @@
2424
from tensorflow_addons.utils import test_utils
2525

2626

27+
def get_test_data():
28+
batch_size = 2
29+
sequence_length = 3
30+
number_of_classes = 5
31+
logits = [
32+
tf.constant(i + 0.5, shape=[batch_size, number_of_classes])
33+
for i in range(sequence_length)
34+
]
35+
logits = tf.stack(logits, axis=1)
36+
targets = [
37+
tf.constant(i, tf.int32, shape=[batch_size]) for i in range(sequence_length)
38+
]
39+
targets = tf.stack(targets, axis=1)
40+
41+
weights = [tf.constant(1.0, shape=[batch_size]) for _ in range(sequence_length)]
42+
weights = tf.stack(weights, axis=1)
43+
# expected_loss = sparse_softmax_cross_entropy_with_logits(targets,
44+
# logits) where targets = [0, 1, 2],
45+
# and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5]
46+
expected_loss = 1.60944
47+
return (
48+
batch_size,
49+
sequence_length,
50+
number_of_classes,
51+
logits,
52+
targets,
53+
weights,
54+
expected_loss,
55+
)
56+
57+
2758
@test_utils.run_all_in_graph_and_eager_modes
2859
class LossTest(tf.test.TestCase):
2960
def setup(self):
@@ -325,59 +356,54 @@ def testAmbiguousOrder(self):
325356
self.evaluate(seq_loss(self.targets, self.logits, self.weights))
326357

327358

328-
@test_utils.run_all_in_graph_and_eager_modes
329-
class DenseTargetLossTest(LossTest):
330-
def setup(self):
331-
super().setup()
332-
self.targets = tf.one_hot(self.targets, depth=self.number_of_classes)
333-
334-
@pytest.mark.xfail(tf.__version__ == "2.2.0-rc1", reason="TODO: Fix this test")
335-
def testKerasCompatibility(self):
336-
"""To test the compatibility of SequenceLoss with Keras's built-in
337-
training loops, we create a fake model which always outputs a pre-
338-
defined set of logits.
339-
340-
Then we check the calculated loss to be equal to the expected
341-
loss. Note that since the fake model doesn't have any trainable
342-
parameters, no matter how many steps we train it, it always
343-
outputs the same loss value.
344-
"""
345-
with self.cached_session(use_gpu=True):
346-
self.setup()
347-
348-
def return_logits(x):
349-
batch_size = tf.shape(x)[0]
350-
logits_single_row = self.logits[0, :, :]
351-
logits_batch = tf.tile(
352-
tf.expand_dims(logits_single_row, 0), [batch_size, 1, 1]
353-
)
354-
return logits_batch
355-
356-
inp = tf.keras.layers.Input(shape=(self.sequence_length,))
357-
out = tf.keras.layers.Lambda(
358-
return_logits,
359-
output_shape=(self.sequence_length, self.number_of_classes),
360-
)(inp)
361-
model = tf.keras.models.Model(inp, out)
362-
363-
loss_obj = loss.SequenceLoss()
364-
model.compile(
365-
optimizer="adam", loss=loss_obj, sample_weight_mode="temporal"
366-
)
367-
368-
# This is a fake input.
369-
x = tf.ones(shape=(self.batch_size, self.sequence_length))
370-
371-
h = model.fit(
372-
x,
373-
self.targets,
374-
sample_weight=self.weights,
375-
batch_size=self.batch_size,
376-
steps_per_epoch=1,
377-
)
378-
379-
calculated_loss = h.history["loss"][0]
380-
self.assertAllClose(calculated_loss, self.expected_loss)
359+
@pytest.mark.xfail(tf.__version__ == "2.2.0-rc1", reason="TODO: Fix this test")
360+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
361+
def test_keras_compatibility():
362+
"""To test the compatibility of SequenceLoss with Keras's built-in
363+
training loops, we create a fake model which always outputs a pre-
364+
defined set of logits.
365+
366+
Then we check the calculated loss to be equal to the expected
367+
loss. Note that since the fake model doesn't have any trainable
368+
parameters, no matter how many steps we train it, it always
369+
outputs the same loss value.
370+
"""
371+
(
372+
batch_size,
373+
sequence_length,
374+
number_of_classes,
375+
logits,
376+
targets,
377+
weights,
378+
expected_loss,
379+
) = get_test_data()
380+
targets = tf.one_hot(targets, depth=number_of_classes)
381+
382+
def return_logits(x):
383+
logits_single_row = logits[0, :, :]
384+
logits_batch = tf.tile(
385+
tf.expand_dims(logits_single_row, 0), [tf.shape(x)[0], 1, 1]
386+
)
387+
return logits_batch
388+
389+
inp = tf.keras.layers.Input(shape=(sequence_length,))
390+
out = tf.keras.layers.Lambda(
391+
return_logits, output_shape=(sequence_length, number_of_classes),
392+
)(inp)
393+
model = tf.keras.models.Model(inp, out)
394+
395+
loss_obj = loss.SequenceLoss()
396+
model.compile(optimizer="adam", loss=loss_obj, sample_weight_mode="temporal")
397+
398+
# This is a fake input.
399+
x = tf.ones(shape=(batch_size, sequence_length))
400+
401+
h = model.fit(
402+
x, targets, sample_weight=weights, batch_size=batch_size, steps_per_epoch=1,
403+
)
404+
405+
calculated_loss = h.history["loss"][0]
406+
np.testing.assert_allclose(calculated_loss, expected_loss, rtol=1e-6, atol=1e-6)
381407

382408

383409
if __name__ == "__main__":

0 commit comments

Comments
 (0)