Skip to content

Commit 25d5a89

Browse files
authored
[CI] Lint check_binary_symbols (#2007)
* [CI] Lint check_binary_symbols #2001 added trailing spaces to it, so let's ruff it by default * Fix lint issues
1 parent 9452ae2 commit 25d5a89

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

.lintrunner.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ merge_base_with = "origin/main"
22

33
[[linter]]
44
code = 'RUFF'
5-
include_patterns = ['test/smoke_test/*.py', 'aarch64_linux/*.py']
5+
include_patterns = ['test/smoke_test/*.py', 'aarch64_linux/*.py', 'test/check_binary_symbols.py']
66
command = [
77
'python3',
88
'tools/linter/adapters/ruff_linter.py',

test/check_binary_symbols.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,14 @@
3131
"torch::",
3232
)
3333

34-
LIBTORCH_CXX11_PATTERNS = [re.compile(f"{x}.*{y}") for (x,y) in itertools.product(LIBTORCH_NAMESPACE_LIST, CXX11_SYMBOLS)]
3534

36-
LIBTORCH_PRE_CXX11_PATTERNS = [re.compile(f"{x}.*{y}") for (x,y) in itertools.product(LIBTORCH_NAMESPACE_LIST, PRE_CXX11_SYMBOLS)]
35+
def _apply_libtorch_symbols(symbols):
36+
return [re.compile(f"{x}.*{y}") for (x,y) in itertools.product(LIBTORCH_NAMESPACE_LIST, symbols)]
37+
38+
39+
LIBTORCH_CXX11_PATTERNS = _apply_libtorch_symbols(CXX11_SYMBOLS)
40+
41+
LIBTORCH_PRE_CXX11_PATTERNS = _apply_libtorch_symbols(PRE_CXX11_SYMBOLS)
3742

3843
@functools.lru_cache(100)
3944
def get_symbols(lib :str ) -> List[Tuple[str, str, str]]:
@@ -45,7 +50,7 @@ def get_symbols(lib :str ) -> List[Tuple[str, str, str]]:
4550
def grep_symbols(lib: str, patterns: List[Any]) -> List[str]:
4651
def _grep_symbols(symbols: List[Tuple[str, str, str]], patterns: List[Any]) -> List[str]:
4752
rc = []
48-
for s_addr, s_type, s_name in symbols:
53+
for _s_addr, _s_type, s_name in symbols:
4954
for pattern in patterns:
5055
if pattern.match(s_name):
5156
rc.append(s_name)
@@ -54,9 +59,12 @@ def _grep_symbols(symbols: List[Tuple[str, str, str]], patterns: List[Any]) -> L
5459
all_symbols = get_symbols(lib)
5560
num_workers= 32
5661
chunk_size = (len(all_symbols) + num_workers - 1 ) // num_workers
62+
def _get_symbols_chunk(i):
63+
return all_symbols[i * chunk_size : (i + 1) * chunk_size]
64+
5765
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
58-
tasks = [executor.submit(_grep_symbols, all_symbols[i * chunk_size : (i + 1) * chunk_size], patterns) for i in range(num_workers)]
59-
return sum((x.result() for x in tasks), [])
66+
tasks = [executor.submit(_grep_symbols, _get_symbols_chunk(i), patterns) for i in range(num_workers)]
67+
return functools.reduce(list.__add__, (x.result() for x in tasks), [])
6068

6169
def check_lib_symbols_for_abi_correctness(lib: str, pre_cxx11_abi: bool = True) -> None:
6270
print(f"lib: {lib}")
@@ -79,13 +87,13 @@ def check_lib_symbols_for_abi_correctness(lib: str, pre_cxx11_abi: bool = True)
7987

8088
def main() -> None:
8189
if "install_root" in os.environ:
82-
install_root = Path(os.getenv("install_root"))
90+
install_root = Path(os.getenv("install_root")) # noqa: SIM112
8391
else:
8492
if os.getenv("PACKAGE_TYPE") == "libtorch":
8593
install_root = Path(os.getcwd())
8694
else:
8795
install_root = Path(distutils.sysconfig.get_python_lib()) / "torch"
88-
96+
8997
libtorch_cpu_path = install_root / "lib" / "libtorch_cpu.so"
9098
pre_cxx11_abi = "cxx11-abi" not in os.getenv("DESIRED_DEVTOOLSET", "")
9199
check_lib_symbols_for_abi_correctness(libtorch_cpu_path, pre_cxx11_abi)

0 commit comments

Comments
 (0)