Skip to content

Commit b250f5d

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
fix test in OSS env without CUDA device (#2688)
Summary: # context * to fix OSS CPU test failure due to lack of CUDA device. Reviewed By: dstaay-fb Differential Revision: D68340773
1 parent 33168a1 commit b250f5d

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

torchrec/sparse/tests/test_tensor_dict.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,20 @@
1111
import unittest
1212

1313
import torch
14+
from hypothesis import given, settings, strategies as st, Verbosity
1415
from tensordict import TensorDict
1516
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
1617
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
17-
from torchrec.sparse.tests.utils import repeat_test
1818

1919

2020
class TestTensorDIct(unittest.TestCase):
21-
@repeat_test(device_str=["cpu", "cuda", "meta"])
22-
def test_kjt_input(self, device_str: str) -> None:
23-
device = torch.device(device_str)
21+
@unittest.skipIf(
22+
torch.cuda.device_count() <= 0,
23+
"CUDA is not available",
24+
)
25+
@given(device=st.sampled_from([torch.device(d) for d in ["cpu", "cuda", "meta"]]))
26+
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
27+
def test_kjt_input(self, device: torch.device) -> None:
2428
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
2529
kjt = KeyedJaggedTensor.from_offsets_sync(
2630
keys=["f1", "f2", "f3"],
@@ -30,9 +34,13 @@ def test_kjt_input(self, device_str: str) -> None:
3034
features = maybe_td_to_kjt(kjt)
3135
self.assertEqual(features, kjt)
3236

33-
@repeat_test(device_str=["cpu", "cuda", "meta"])
34-
def test_td_kjt(self, device_str: str) -> None:
35-
device = torch.device(device_str)
37+
@unittest.skipIf(
38+
torch.cuda.device_count() <= 0,
39+
"CUDA is not available",
40+
)
41+
@given(device=st.sampled_from([torch.device(d) for d in ["cpu", "cuda", "meta"]]))
42+
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
43+
def test_td_kjt(self, device: torch.device) -> None:
3644
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
3745
lengths = torch.tensor([2, 0, 1, 1, 1, 2], device=device)
3846
data = {

0 commit comments

Comments
 (0)