Skip to content

Commit 7108218

Browse files
skrahfacebook-github-bot
authored andcommitted
Fix flaky nuclear_norm() test (pytorch#21638)
Summary: Try to fix a sporadic failure on some CIs. I've run this test hundreds of times on my machine (GeForce 1060, MAGMA) but I cannot reproduce this. Pull Request resolved: pytorch#21638 Differential Revision: D15827779 Pulled By: ezyang fbshipit-source-id: 3586075e48907b3b84a101c560a34cc733514a02
1 parent ff8c3fd commit 7108218

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

test/test_cuda.py

+2
Original file line numberDiff line numberDiff line change
@@ -2692,10 +2692,12 @@ def test_norm(self):
26922692

26932693
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
26942694
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
2695+
@skipCUDANonDefaultStreamIf(True)
26952696
def test_nuclear_norm_axes_small_brute_force(self):
26962697
_TestTorchMixin._test_nuclear_norm_axes(self, device='cuda')
26972698

26982699
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
2700+
@skipCUDANonDefaultStreamIf(True)
26992701
def test_nuclear_norm_exceptions(self):
27002702
_TestTorchMixin._test_nuclear_norm_exceptions(self, device='cuda')
27012703

test/test_torch.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -991,14 +991,14 @@ def check_single_nuclear_norm(x, axes):
991991
ans = torch.norm(x, "nuc", dim=axes)
992992
self.assertTrue(ans.is_contiguous())
993993
self.assertEqual(ans.shape, expected.shape)
994-
self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03))
994+
self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True))
995995

996996
out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device)
997997
ans = torch.norm(x, "nuc", dim=axes, out=out)
998998
self.assertIs(ans, out)
999999
self.assertTrue(ans.is_contiguous())
10001000
self.assertEqual(ans.shape, expected.shape)
1001-
self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03))
1001+
self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True))
10021002

10031003
for n in range(1, 3):
10041004
for m in range(1, 3):
@@ -1026,7 +1026,7 @@ def check_single_nuclear_norm(x, axes):
10261026
check_single_nuclear_norm(x, axes)
10271027

10281028
# 3d, inner dimensions Fortran
1029-
y = torch.randn(o, n, m, device=device).transpose(-1, -2)
1029+
x = torch.randn(o, m, n, device=device).transpose(-1, -2)
10301030
check_single_nuclear_norm(x, axes)
10311031

10321032
# 3d, inner dimensions non-contiguous

0 commit comments

Comments
 (0)