1111import unittest
1212
1313import torch
14+ from hypothesis import given , settings , strategies as st , Verbosity
1415from tensordict import TensorDict
1516from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
1617from torchrec .sparse .tensor_dict import maybe_td_to_kjt
17- from torchrec .sparse .tests .utils import repeat_test
1818
1919
2020class TestTensorDIct (unittest .TestCase ):
21- @repeat_test (device_str = ["cpu" , "cuda" , "meta" ])
22- def test_kjt_input (self , device_str : str ) -> None :
23- device = torch .device (device_str )
21+ @unittest .skipIf (
22+ torch .cuda .device_count () <= 0 ,
23+ "CUDA is not available" ,
24+ )
25+ @given (device = st .sampled_from ([torch .device (d ) for d in ["cpu" , "cuda" , "meta" ]]))
26+ @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
27+ def test_kjt_input (self , device : torch .device ) -> None :
2428 values = torch .tensor ([0 , 1 , 2 , 3 , 2 , 3 , 4 ], device = device )
2529 kjt = KeyedJaggedTensor .from_offsets_sync (
2630 keys = ["f1" , "f2" , "f3" ],
@@ -30,9 +34,13 @@ def test_kjt_input(self, device_str: str) -> None:
3034 features = maybe_td_to_kjt (kjt )
3135 self .assertEqual (features , kjt )
3236
33- @repeat_test (device_str = ["cpu" , "cuda" , "meta" ])
34- def test_td_kjt (self , device_str : str ) -> None :
35- device = torch .device (device_str )
37+ @unittest .skipIf (
38+ torch .cuda .device_count () <= 0 ,
39+ "CUDA is not available" ,
40+ )
41+ @given (device = st .sampled_from ([torch .device (d ) for d in ["cpu" , "cuda" , "meta" ]]))
42+ @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
43+ def test_td_kjt (self , device : torch .device ) -> None :
3644 values = torch .tensor ([0 , 1 , 2 , 3 , 2 , 3 , 4 ], device = device )
3745 lengths = torch .tensor ([2 , 0 , 1 , 1 , 1 , 2 ], device = device )
3846 data = {
0 commit comments