1
1
import unittest
2
2
from common_utils import TestCase , run_tests
3
+ from common_cuda import TEST_CUDA
3
4
import torch
5
+ import sys
4
6
5
7
6
8
def namedtensor_enabled ():
@@ -10,11 +12,98 @@ def namedtensor_enabled():
10
12
unittest .skipIf (not namedtensor_enabled (),
11
13
'PyTorch not compiled with namedtensor support' )
12
14
15
+ def pass_name_to_python_arg_parser (name ):
16
+ x = torch .empty (2 , names = (name ,))
17
+
18
+
13
19
class TestNamedTensor (TestCase ):
14
20
@skipIfNamedTensorDisabled
15
21
def test_trivial (self ):
16
22
pass
17
23
24
+ def _test_factory (self , factory , device ):
25
+ x = factory ([], device = device )
26
+ self .assertEqual (x .names , ())
27
+
28
+ x = factory (1 , 2 , 3 , device = device )
29
+ self .assertEqual (x .names , (None , None , None ))
30
+
31
+ x = factory (1 , 2 , 3 , names = None , device = device )
32
+ self .assertEqual (x .names , (None , None , None ))
33
+
34
+ x = factory (1 , 2 , 3 , names = ('N' , 'T' , 'D' ), device = device )
35
+ self .assertEqual (x .names , ('N' , 'T' , 'D' ))
36
+
37
+ x = factory (1 , 2 , 3 , names = ('N' , None , 'D' ), device = device )
38
+ self .assertEqual (x .names , ('N' , None , 'D' ))
39
+
40
+ with self .assertRaisesRegex (RuntimeError ,
41
+ 'must contain alphabetical characters and/or underscore' ):
42
+ x = factory (2 , names = ('?' ,), device = device )
43
+
44
+ with self .assertRaisesRegex (RuntimeError , 'Number of names' ):
45
+ x = factory (2 , 1 , names = ('N' ,), device = device )
46
+
47
+ with self .assertRaisesRegex (TypeError , 'invalid combination of arguments' ):
48
+ x = factory (2 , 1 , names = 'N' , device = device )
49
+
50
+
51
+ @skipIfNamedTensorDisabled
52
+ def test_empty (self ):
53
+ self ._test_factory (torch .empty , 'cpu' )
54
+
55
+ @skipIfNamedTensorDisabled
56
+ @unittest .skipIf (not TEST_CUDA , 'no CUDA' )
57
+ def test_empty_cuda (self ):
58
+ self ._test_factory (torch .empty , 'cuda' )
59
+
60
+ @skipIfNamedTensorDisabled
61
+ def test_using_seen_interned_string_doesnt_bump_refcount (self ):
62
+ def see_name ():
63
+ seen_name = 'N'
64
+ pass_name_to_python_arg_parser (seen_name )
65
+
66
+ see_name ()
67
+ seen_name = 'N'
68
+ old_refcnt = sys .getrefcount (seen_name )
69
+
70
+ pass_name_to_python_arg_parser (seen_name )
71
+
72
+ new_refcnt = sys .getrefcount (seen_name )
73
+ self .assertEqual (new_refcnt , old_refcnt )
74
+
75
+ @skipIfNamedTensorDisabled
76
+ def test_using_unseen_interned_string_bumps_refcount_permanently (self ):
77
+ # Please don't use this as a name in a different test.
78
+ unseen_name = 'abcdefghi'
79
+ old_refcnt = sys .getrefcount (unseen_name )
80
+
81
+ pass_name_to_python_arg_parser (unseen_name )
82
+
83
+ new_refcnt = sys .getrefcount (unseen_name )
84
+ self .assertEqual (new_refcnt , old_refcnt + 1 )
85
+
86
+ @skipIfNamedTensorDisabled
87
+ def test_using_unseen_uninterned_string_refcounts (self ):
88
+ # Please don't use this as a name in a different test.
89
+ # non-compile-time constants are not interned
90
+ unseen_name = '' .join (['abc' , 'def' , 'ghi' , 'jkl' ])
91
+ interned_unseen_name = 'abcdefghijkl'
92
+ self .assertFalse (unseen_name is interned_unseen_name )
93
+
94
+ old_uninterned_refcnt = sys .getrefcount (unseen_name )
95
+ old_interned_refcnt = sys .getrefcount (interned_unseen_name )
96
+
97
+ pass_name_to_python_arg_parser (unseen_name )
98
+
99
+ new_uninterned_refcnt = sys .getrefcount (unseen_name )
100
+ new_interned_refcnt = sys .getrefcount (interned_unseen_name )
101
+
102
+ # Internally, PyTorch should not hold a reference to the uninterned string
103
+ self .assertEqual (new_uninterned_refcnt , old_uninterned_refcnt )
104
+
105
+ # Instead, we should hold a new reference to the interned version.
106
+ self .assertEqual (new_interned_refcnt , old_interned_refcnt + 1 )
18
107
19
108
if __name__ == '__main__' :
20
109
run_tests ()
0 commit comments