11
11
import unittest
12
12
13
13
import torch
14
+ from hypothesis import given , settings , strategies as st , Verbosity
14
15
from tensordict import TensorDict
15
16
from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
16
17
from torchrec .sparse .tensor_dict import maybe_td_to_kjt
17
- from torchrec .sparse .tests .utils import repeat_test
18
18
19
19
20
20
class TestTensorDIct (unittest .TestCase ):
21
- @repeat_test (device_str = ["cpu" , "cuda" , "meta" ])
22
- def test_kjt_input (self , device_str : str ) -> None :
23
- device = torch .device (device_str )
21
+ @unittest .skipIf (
22
+ torch .cuda .device_count () <= 0 ,
23
+ "CUDA is not available" ,
24
+ )
25
+ @given (device = st .sampled_from ([torch .device (d ) for d in ["cpu" , "cuda" , "meta" ]]))
26
+ @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
27
+ def test_kjt_input (self , device : torch .device ) -> None :
24
28
values = torch .tensor ([0 , 1 , 2 , 3 , 2 , 3 , 4 ], device = device )
25
29
kjt = KeyedJaggedTensor .from_offsets_sync (
26
30
keys = ["f1" , "f2" , "f3" ],
@@ -30,9 +34,13 @@ def test_kjt_input(self, device_str: str) -> None:
30
34
features = maybe_td_to_kjt (kjt )
31
35
self .assertEqual (features , kjt )
32
36
33
- @repeat_test (device_str = ["cpu" , "cuda" , "meta" ])
34
- def test_td_kjt (self , device_str : str ) -> None :
35
- device = torch .device (device_str )
37
+ @unittest .skipIf (
38
+ torch .cuda .device_count () <= 0 ,
39
+ "CUDA is not available" ,
40
+ )
41
+ @given (device = st .sampled_from ([torch .device (d ) for d in ["cpu" , "cuda" , "meta" ]]))
42
+ @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
43
+ def test_td_kjt (self , device : torch .device ) -> None :
36
44
values = torch .tensor ([0 , 1 , 2 , 3 , 2 , 3 , 4 ], device = device )
37
45
lengths = torch .tensor ([2 , 0 , 1 , 1 , 1 , 2 ], device = device )
38
46
data = {
0 commit comments