31
31
"torch::" ,
32
32
)
33
33
34
- LIBTORCH_CXX11_PATTERNS = [re .compile (f"{ x } .*{ y } " ) for (x ,y ) in itertools .product (LIBTORCH_NAMESPACE_LIST , CXX11_SYMBOLS )]
35
34
36
- LIBTORCH_PRE_CXX11_PATTERNS = [re .compile (f"{ x } .*{ y } " ) for (x ,y ) in itertools .product (LIBTORCH_NAMESPACE_LIST , PRE_CXX11_SYMBOLS )]
35
+ def _apply_libtorch_symbols (symbols ):
36
+ return [re .compile (f"{ x } .*{ y } " ) for (x ,y ) in itertools .product (LIBTORCH_NAMESPACE_LIST , symbols )]
37
+
38
+
39
+ LIBTORCH_CXX11_PATTERNS = _apply_libtorch_symbols (CXX11_SYMBOLS )
40
+
41
+ LIBTORCH_PRE_CXX11_PATTERNS = _apply_libtorch_symbols (PRE_CXX11_SYMBOLS )
37
42
38
43
@functools .lru_cache (100 )
39
44
def get_symbols (lib :str ) -> List [Tuple [str , str , str ]]:
@@ -45,7 +50,7 @@ def get_symbols(lib :str ) -> List[Tuple[str, str, str]]:
45
50
def grep_symbols (lib : str , patterns : List [Any ]) -> List [str ]:
46
51
def _grep_symbols (symbols : List [Tuple [str , str , str ]], patterns : List [Any ]) -> List [str ]:
47
52
rc = []
48
- for s_addr , s_type , s_name in symbols :
53
+ for _s_addr , _s_type , s_name in symbols :
49
54
for pattern in patterns :
50
55
if pattern .match (s_name ):
51
56
rc .append (s_name )
@@ -54,9 +59,12 @@ def _grep_symbols(symbols: List[Tuple[str, str, str]], patterns: List[Any]) -> L
54
59
all_symbols = get_symbols (lib )
55
60
num_workers = 32
56
61
chunk_size = (len (all_symbols ) + num_workers - 1 ) // num_workers
62
+ def _get_symbols_chunk (i ):
63
+ return all_symbols [i * chunk_size : (i + 1 ) * chunk_size ]
64
+
57
65
with concurrent .futures .ThreadPoolExecutor (max_workers = 32 ) as executor :
58
- tasks = [executor .submit (_grep_symbols , all_symbols [ i * chunk_size : ( i + 1 ) * chunk_size ] , patterns ) for i in range (num_workers )]
59
- return sum ( (x .result () for x in tasks ), [])
66
+ tasks = [executor .submit (_grep_symbols , _get_symbols_chunk ( i ) , patterns ) for i in range (num_workers )]
67
+ return functools . reduce ( list . __add__ , (x .result () for x in tasks ), [])
60
68
61
69
def check_lib_symbols_for_abi_correctness (lib : str , pre_cxx11_abi : bool = True ) -> None :
62
70
print (f"lib: { lib } " )
@@ -79,13 +87,13 @@ def check_lib_symbols_for_abi_correctness(lib: str, pre_cxx11_abi: bool = True)
79
87
80
88
def main () -> None :
81
89
if "install_root" in os .environ :
82
- install_root = Path (os .getenv ("install_root" ))
90
+ install_root = Path (os .getenv ("install_root" )) # noqa: SIM112
83
91
else :
84
92
if os .getenv ("PACKAGE_TYPE" ) == "libtorch" :
85
93
install_root = Path (os .getcwd ())
86
94
else :
87
95
install_root = Path (distutils .sysconfig .get_python_lib ()) / "torch"
88
-
96
+
89
97
libtorch_cpu_path = install_root / "lib" / "libtorch_cpu.so"
90
98
pre_cxx11_abi = "cxx11-abi" not in os .getenv ("DESIRED_DEVTOOLSET" , "" )
91
99
check_lib_symbols_for_abi_correctness (libtorch_cpu_path , pre_cxx11_abi )
0 commit comments