Skip to content

Commit 9dfdfb8

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
fix test in OSS env without CUDA device (#2688)
Summary: Pull Request resolved: #2688 # context * to fix OSS CPU test failure due to lack of CUDA device. Reviewed By: dstaay-fb Differential Revision: D68340773 fbshipit-source-id: 93b7dd03a16df13be1e333622b6f0346189778a3
1 parent 33168a1 commit 9dfdfb8

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

torchrec/sparse/tests/test_tensor_dict.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,20 @@
1111
import unittest
1212

1313
import torch
14+
from hypothesis import given, settings, strategies as st, Verbosity
1415
from tensordict import TensorDict
1516
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
1617
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
17-
from torchrec.sparse.tests.utils import repeat_test
1818

1919

2020
class TestTensorDIct(unittest.TestCase):
21-
@repeat_test(device_str=["cpu", "cuda", "meta"])
21+
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
22+
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
23+
# pyre-ignore[56]
24+
@unittest.skipIf(
25+
torch.cuda.device_count() <= 0,
26+
"CUDA is not available",
27+
)
2228
def test_kjt_input(self, device_str: str) -> None:
2329
device = torch.device(device_str)
2430
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
@@ -30,7 +36,13 @@ def test_kjt_input(self, device_str: str) -> None:
3036
features = maybe_td_to_kjt(kjt)
3137
self.assertEqual(features, kjt)
3238

33-
@repeat_test(device_str=["cpu", "cuda", "meta"])
39+
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
40+
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
41+
# pyre-ignore[56]
42+
@unittest.skipIf(
43+
torch.cuda.device_count() <= 0,
44+
"CUDA is not available",
45+
)
3446
def test_td_kjt(self, device_str: str) -> None:
3547
device = torch.device(device_str)
3648
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)

0 commit comments

Comments
 (0)