@@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
14021402 assert spec2 .zero ().shape == spec2 .shape
14031403
14041404 def test_non_tensor (self ):
1405- spec = NonTensor ((3 , 4 ), device = "cpu" )
1405+ spec = NonTensor ((3 , 4 ), device = "cpu" , example_data = "example_data" )
14061406 assert (
14071407 spec .expand (2 , 3 , 4 )
14081408 == spec .expand ((2 , 3 , 4 ))
1409- == NonTensor ((2 , 3 , 4 ), device = "cpu" )
1409+ == NonTensor ((2 , 3 , 4 ), device = "cpu" , example_data = "example_data" )
14101410 )
1411+ assert spec .expand (2 , 3 , 4 ).example_data == "example_data"
14111412
14121413 @pytest .mark .parametrize ("shape1" , [None , (), (5 ,)])
14131414 @pytest .mark .parametrize ("shape2" , [(), (10 ,)])
@@ -1607,9 +1608,10 @@ def test_multionehot(
16071608 assert spec is not spec .clone ()
16081609
16091610 def test_non_tensor (self ):
1610- spec = NonTensor (shape = (3 , 4 ), device = "cpu" )
1611+ spec = NonTensor (shape = (3 , 4 ), device = "cpu" , example_data = "example_data" )
16111612 assert spec .clone () == spec
16121613 assert spec .clone () is not spec
1614+ assert spec .clone ().example_data == "example_data"
16131615
16141616 @pytest .mark .parametrize ("shape1" , [None , (), (5 ,)])
16151617 def test_onehot (
@@ -1840,9 +1842,10 @@ def test_multionehot(
18401842 spec .unbind (- 1 )
18411843
18421844 def test_non_tensor (self ):
1843- spec = NonTensor (shape = (3 , 4 ), device = "cpu" )
1845+ spec = NonTensor (shape = (3 , 4 ), device = "cpu" , example_data = "example_data" )
18441846 assert spec .unbind (1 )[0 ] == spec [:, 0 ]
18451847 assert spec .unbind (1 )[0 ] is not spec [:, 0 ]
1848+ assert spec .unbind (1 )[0 ].example_data == "example_data"
18461849
18471850 @pytest .mark .parametrize ("shape1" , [(5 ,), (5 , 6 )])
18481851 def test_onehot (
@@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
20012004 assert spec .to (device ).device == device
20022005
20032006 def test_non_tensor (self , device ):
2004- spec = NonTensor (shape = (3 , 4 ), device = "cpu" )
2007+ spec = NonTensor (shape = (3 , 4 ), device = "cpu" , example_data = "example_data" )
20052008 assert spec .to (device ).device == device
2009+ assert spec .to (device ).example_data == "example_data"
20062010
20072011 @pytest .mark .parametrize ("shape1" , [(5 ,), (5 , 6 )])
20082012 def test_onehot (self , shape1 , device ):
@@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
22622266 assert r .shape == c .shape
22632267
22642268 def test_stack_non_tensor (self , shape , stack_dim ):
2265- spec0 = NonTensor (shape = shape , device = "cpu" )
2266- spec1 = NonTensor (shape = shape , device = "cpu" )
2269+ spec0 = NonTensor (shape = shape , device = "cpu" , example_data = "example_data" )
2270+ spec1 = NonTensor (shape = shape , device = "cpu" , example_data = "example_data" )
22672271 new_spec = torch .stack ([spec0 , spec1 ], stack_dim )
22682272 shape_insert = list (shape )
22692273 shape_insert .insert (stack_dim , 2 )
22702274 assert new_spec .shape == torch .Size (shape_insert )
22712275 assert new_spec .device == torch .device ("cpu" )
2276+ assert new_spec .example_data == "example_data"
22722277
22732278 def test_stack_onehot (self , shape , stack_dim ):
22742279 n = 5
@@ -3642,10 +3647,18 @@ def test_expand(self):
36423647
36433648class TestNonTensorSpec :
36443649 def test_sample (self ):
3645- nts = NonTensor (shape = (3 , 4 ))
3650+ nts = NonTensor (shape = (3 , 4 ), example_data = "example_data" )
36463651 assert nts .one ((2 ,)).shape == (2 , 3 , 4 )
36473652 assert nts .rand ((2 ,)).shape == (2 , 3 , 4 )
36483653 assert nts .zero ((2 ,)).shape == (2 , 3 , 4 )
3654+ assert nts .one ((2 ,)).data == "example_data"
3655+ assert nts .rand ((2 ,)).data == "example_data"
3656+ assert nts .zero ((2 ,)).data == "example_data"
3657+
3658+ def test_example_data_ineq (self ):
3659+ nts0 = NonTensor (shape = (3 , 4 ), example_data = "example_data" )
3660+ nts1 = NonTensor (shape = (3 , 4 ), example_data = "example_data 2" )
3661+ assert nts0 != nts1
36493662
36503663
36513664@pytest .mark .skipif (not torch .cuda .is_available (), reason = "not cuda device" )
0 commit comments