Skip to content

Commit 70e4e70

Browse files
Fixed #1059 missing fastmath (#1060)
* add function to check fastmath * revise fastmath script and add support for reading arg from command line * minor fixes * rename function to improve readability * simplify code by passing boolean value * fix to catch njit functions with decorator * use regex to find njit functions * minor change * revise code to detect bare njit decorator * minor fix * fix path * add missing fastmath * revise fastmath flag * Improve ValueError msg * fix format * enable function to accept path as input * pass param via CLI, and some minor changes * adapt changes in test script * use type str for the param pkg_dir * minor changes * Revised string concatenation in error message
1 parent ce0cd8c commit 70e4e70

File tree

4 files changed

+137
-12
lines changed

4 files changed

+137
-12
lines changed

fastmath.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python
2+
3+
import argparse
4+
import ast
5+
import importlib
6+
import pathlib
7+
8+
9+
def get_njit_funcs(pkg_dir):
10+
"""
11+
Identify all njit functions
12+
13+
Parameters
14+
----------
15+
pkg_dir : str
16+
The path to the directory containing some .py files
17+
18+
Returns
19+
-------
20+
njit_funcs : list
21+
A list of all njit functions, where each element is a tuple of the form
22+
(module_name, func_name)
23+
"""
24+
ignore_py_files = ["__init__", "__pycache__"]
25+
pkg_dir = pathlib.Path(pkg_dir)
26+
27+
module_names = []
28+
for fname in pkg_dir.iterdir():
29+
if fname.stem not in ignore_py_files and not fname.stem.startswith("."):
30+
module_names.append(fname.stem)
31+
32+
njit_funcs = []
33+
for module_name in module_names:
34+
filepath = pkg_dir / f"{module_name}.py"
35+
file_contents = ""
36+
with open(filepath, encoding="utf8") as f:
37+
file_contents = f.read()
38+
module = ast.parse(file_contents)
39+
for node in module.body:
40+
if isinstance(node, ast.FunctionDef):
41+
func_name = node.name
42+
for decorator in node.decorator_list:
43+
decorator_name = None
44+
if isinstance(decorator, ast.Name):
45+
# Bare decorator
46+
decorator_name = decorator.id
47+
if isinstance(decorator, ast.Call) and isinstance(
48+
decorator.func, ast.Name
49+
):
50+
# Decorator is a function
51+
decorator_name = decorator.func.id
52+
53+
if decorator_name == "njit":
54+
njit_funcs.append((module_name, func_name))
55+
56+
return njit_funcs
57+
58+
59+
def check_fastmath(pkg_dir, pkg_name):
60+
"""
61+
Check if all njit functions have the `fastmath` flag set
62+
63+
Parameters
64+
----------
65+
pkg_dir : str
66+
The path to the directory containing some .py files
67+
68+
pkg_name : str
69+
The name of the package
70+
71+
Returns
72+
-------
73+
None
74+
"""
75+
missing_fastmath = [] # list of njit functions with missing fastmath flags
76+
for module_name, func_name in get_njit_funcs(pkg_dir):
77+
module = importlib.import_module(f".{module_name}", package=pkg_name)
78+
func = getattr(module, func_name)
79+
if "fastmath" not in func.targetoptions.keys():
80+
missing_fastmath.append(f"{module_name}.{func_name}")
81+
82+
if len(missing_fastmath) > 0:
83+
msg = (
84+
"Found one or more `@njit` functions that are missing the `fastmath` flag. "
85+
+ f"The functions are:\n {missing_fastmath}\n"
86+
)
87+
raise ValueError(msg)
88+
89+
return
90+
91+
92+
if __name__ == "__main__":
93+
parser = argparse.ArgumentParser()
94+
parser.add_argument("--check", dest="pkg_dir")
95+
args = parser.parse_args()
96+
97+
if args.pkg_dir:
98+
pkg_dir = pathlib.Path(args.pkg_dir)
99+
pkg_name = pkg_dir.name
100+
check_fastmath(str(pkg_dir), pkg_name)

stumpy/cache.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import ast
66
import importlib
77
import pathlib
8-
import pkgutil
98
import site
109
import warnings
1110

@@ -28,13 +27,17 @@ def get_njit_funcs():
2827
out : list
2928
A list of (`module_name`, `func_name`) pairs
3029
"""
30+
ignore_py_files = ["__init__", "__pycache__"]
31+
3132
pkg_dir = pathlib.Path(__file__).parent
32-
module_names = [name for _, name, _ in pkgutil.iter_modules([str(pkg_dir)])]
33+
module_names = []
34+
for fname in pkg_dir.iterdir():
35+
if fname.stem not in ignore_py_files and not fname.stem.startswith("."):
36+
module_names.append(fname.stem)
3337

3438
njit_funcs = []
35-
3639
for module_name in module_names:
37-
filepath = pathlib.Path(__file__).parent / f"{module_name}.py"
40+
filepath = pkg_dir / f"{module_name}.py"
3841
file_contents = ""
3942
with open(filepath, encoding="utf8") as f:
4043
file_contents = f.read()
@@ -43,11 +46,18 @@ def get_njit_funcs():
4346
if isinstance(node, ast.FunctionDef):
4447
func_name = node.name
4548
for decorator in node.decorator_list:
49+
decorator_name = None
50+
if isinstance(decorator, ast.Name):
51+
# Bare decorator
52+
decorator_name = decorator.id
4653
if isinstance(decorator, ast.Call) and isinstance(
4754
decorator.func, ast.Name
4855
):
49-
if decorator.func.id == "njit":
50-
njit_funcs.append((module_name, func_name))
56+
# Decorator is a function
57+
decorator_name = decorator.func.id
58+
59+
if decorator_name == "njit":
60+
njit_funcs.append((module_name, func_name))
5161

5262
return njit_funcs
5363

stumpy/core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,6 +2356,7 @@ def _count_diagonal_ndist(diags, m, n_A, n_B):
23562356

23572357
@njit(
23582358
# "i8[:, :](i8[:], i8, b1)"
2359+
fastmath=True
23592360
)
23602361
def _get_array_ranges(a, n_chunks, truncate):
23612362
"""
@@ -2404,6 +2405,7 @@ def _get_array_ranges(a, n_chunks, truncate):
24042405

24052406
@njit(
24062407
# "i8[:, :](i8, i8, b1)"
2408+
fastmath=True
24072409
)
24082410
def _get_ranges(size, n_chunks, truncate):
24092411
"""
@@ -3256,7 +3258,7 @@ def _select_P_ABBA_value(P_ABBA, k, custom_func=None):
32563258
return MPdist
32573259

32583260

3259-
@njit()
3261+
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
32603262
def _merge_topk_PI(PA, PB, IA, IB):
32613263
"""
32623264
Merge two top-k matrix profiles `PA` and `PB`, and update `PA` (in place).
@@ -3329,7 +3331,7 @@ def _merge_topk_PI(PA, PB, IA, IB):
33293331
IA[i] = tmp_I
33303332

33313333

3332-
@njit()
3334+
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
33333335
def _merge_topk_ρI(ρA, ρB, IA, IB):
33343336
"""
33353337
Merge two top-k pearson profiles `ρA` and `ρB`, and update `ρA` (in place).
@@ -3403,7 +3405,7 @@ def _merge_topk_ρI(ρA, ρB, IA, IB):
34033405
IA[i] = tmp_I
34043406

34053407

3406-
@njit()
3408+
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
34073409
def _shift_insert_at_index(a, idx, v, shift="right"):
34083410
"""
34093411
If `shift=right` (default), all elements in `a[idx:]` are shifted to the right by
@@ -4379,7 +4381,7 @@ def get_ray_nworkers(ray_client):
43794381
return int(ray_client.cluster_resources().get("CPU"))
43804382

43814383

4382-
@njit
4384+
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
43834385
def _update_incremental_PI(D, P, I, excl_zone, n_appended=0):
43844386
"""
43854387
Given the 1D array distance profile, `D`, of the last subsequence of T,

test.sh

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ check_print()
9393
fi
9494
}
9595

96+
check_fastmath()
97+
{
98+
echo "Checking Missing fastmath flags in njit functions"
99+
./fastmath.py --check stumpy
100+
check_errs $?
101+
}
102+
96103
check_naive()
97104
{
98105
# Check if there are any naive implementations not at start of test file
@@ -146,14 +153,14 @@ set_ray_coveragerc()
146153
show_coverage_report()
147154
{
148155
set_ray_coveragerc
149-
coverage report -m --fail-under=100 --skip-covered --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
156+
coverage report -m --fail-under=100 --skip-covered --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
150157
}
151158

152159
gen_coverage_xml_report()
153160
{
154161
# This function saves the coverage report in Cobertura XML format, which is compatible with codecov
155162
set_ray_coveragerc
156-
coverage xml -o $fcoveragexml --fail-under=100 --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
163+
coverage xml -o $fcoveragexml --fail-under=100 --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
157164
}
158165

159166
test_custom()
@@ -333,6 +340,12 @@ check_print
333340
check_naive
334341
check_ray
335342

343+
344+
if [[ -z $NUMBA_DISABLE_JIT || $NUMBA_DISABLE_JIT -eq 0 ]]; then
345+
check_fastmath
346+
fi
347+
348+
336349
if [[ $test_mode == "notebooks" ]]; then
337350
echo "Executing Tutorial Notebooks Only"
338351
convert_notebooks

0 commit comments

Comments
 (0)