@@ -343,17 +343,18 @@ def testZeroWeights(self):
343
343
compare_total = np .zeros ((self .batch_size , self .sequence_length ))
344
344
self .assertAllClose (compare_total , res )
345
345
346
- def testAmbiguousOrder (self ):
347
- with self .assertRaisesRegexp (ValueError , "because of ambiguous order" ):
348
- with self .cached_session (use_gpu = True ):
349
- self .setup ()
350
- seq_loss = loss .SequenceLoss (
351
- average_across_timesteps = False ,
352
- average_across_batch = True ,
353
- sum_over_timesteps = True ,
354
- sum_over_batch = False ,
355
- )
356
- self .evaluate (seq_loss (self .targets , self .logits , self .weights ))
346
+
347
+ @pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
348
+ def test_ambiguous_order ():
349
+ with pytest .raises (ValueError , match = "because of ambiguous order" ):
350
+ _ , _ , _ , logits , targets , weights , _ = get_test_data ()
351
+ seq_loss = loss .SequenceLoss (
352
+ average_across_timesteps = False ,
353
+ average_across_batch = True ,
354
+ sum_over_timesteps = True ,
355
+ sum_over_batch = False ,
356
+ )
357
+ seq_loss (targets , logits , weights ).numpy ()
357
358
358
359
359
360
@pytest .mark .xfail (tf .__version__ == "2.2.0-rc1" , reason = "TODO: Fix this test" )
0 commit comments