Skip to content

Commit 711e725

Browse files
Test ambiguous order in eager and tf.function. (#1377)
1 parent 3b41cfe commit 711e725

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

tensorflow_addons/seq2seq/loss_test.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -343,17 +343,18 @@ def testZeroWeights(self):
343343
compare_total = np.zeros((self.batch_size, self.sequence_length))
344344
self.assertAllClose(compare_total, res)
345345

346-
def testAmbiguousOrder(self):
347-
with self.assertRaisesRegexp(ValueError, "because of ambiguous order"):
348-
with self.cached_session(use_gpu=True):
349-
self.setup()
350-
seq_loss = loss.SequenceLoss(
351-
average_across_timesteps=False,
352-
average_across_batch=True,
353-
sum_over_timesteps=True,
354-
sum_over_batch=False,
355-
)
356-
self.evaluate(seq_loss(self.targets, self.logits, self.weights))
346+
347+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
348+
def test_ambiguous_order():
349+
with pytest.raises(ValueError, match="because of ambiguous order"):
350+
_, _, _, logits, targets, weights, _ = get_test_data()
351+
seq_loss = loss.SequenceLoss(
352+
average_across_timesteps=False,
353+
average_across_batch=True,
354+
sum_over_timesteps=True,
355+
sum_over_batch=False,
356+
)
357+
seq_loss(targets, logits, weights).numpy()
357358

358359

359360
@pytest.mark.xfail(tf.__version__ == "2.2.0-rc1", reason="TODO: Fix this test")

0 commit comments

Comments
 (0)