Skip to content

Commit 80dc3e9

Browse files
Fix fasthmath precision issue (#1048)
* removed reassoc flag from fastmath * Add reset feature to config * Revised config value * replaced fastmath flags with config var * fixed format * Removed bad f-string * Replaced Raised with Returns in docstring * Add second attempt for assertion * minor change * Add condition to avoid revising fastmath when JIT is disabled * Removed support for input with type list to simplify function * Refactored the recompile process * removed blank lines * fixed typo * replaced hardcoded fastmath value with config var * revised function * renamed variable to improve readability * fixed bug * rename config to improve readability * revise func clear * revise func to recompile all njit functions * Adapt to changes in test function * add test * resolve coverage * resolve missing lines in coverage * Add test function to improve coverage * add fastmath module * revise test function to use fastmath module * fix minor issue * minor change to improve readability * Add fastmath default flags to config default * add reset function * rename function * adapt recent changes in test function * minor fixes * Check if DISABLE_JIT before getting fastmath * ignore lines for coverage check * Editorial fix * avoid .get(key) to get KeyError if it does not exist * add function to save cache * Add note to function * fix format * replace fastmath flag with config variable * add test function to check backward compatibility * skip test when JIT is disabled * rename test function * add conditional deprecation warning * add test function to check if cache can be saved after cache._clear() * remove old warning * add test for cache._clear * add wrapper around private functions * Raise OSError when NUMBA JIT is disabled during cache save * move warnings to public API * fix warning message * improved warning message * Add commit about addition config variables that are defined in __init__ * Revise test function to improve readability * Add test function for fastmath * Revised test functions * skip test if numba JIT is disabled * omit test functions that require NUMBA JIT * Removed the trivial test function * Raise warning instead of error to avoid interrupting the program * improve readability * remove intermediate variable * minor fixes * Add shell script code to check for harcoded fastmath flags * minor fix on indention
1 parent dfe4f63 commit 80dc3e9

17 files changed

+420
-45
lines changed

fastmath.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def get_njit_funcs(pkg_dir):
1313
Parameters
1414
----------
1515
pkg_dir : str
16-
The path to the directory containing some .py files
16+
The path to the directory containing some .py files
1717
1818
Returns
1919
-------

stumpy/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import importlib
12
import os.path
23
from importlib.metadata import distribution
34
from site import getsitepackages
45

6+
import numba
57
from numba import cuda
68

9+
from . import cache, config
710
from .aamp import aamp # noqa: F401
811
from .aamp_mmotifs import aamp_mmotifs # noqa: F401
912
from .aamp_motifs import aamp_match, aamp_motifs # noqa: F401
@@ -32,6 +35,18 @@
3235
from .stumped import stumped # noqa: F401
3336
from .stumpi import stumpi # noqa: F401
3437

38+
# Get the default fastmath flags for all njit functions
39+
# and update the _STUMPY_DEFAULTS dictionary
40+
41+
if not numba.config.DISABLE_JIT: # pragma: no cover
42+
njit_funcs = cache.get_njit_funcs()
43+
for module_name, func_name in njit_funcs:
44+
module = importlib.import_module(f".{module_name}", package="stumpy")
45+
func = getattr(module, func_name)
46+
key = module_name + "." + func_name # e.g., core._mass
47+
key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS
48+
config._STUMPY_DEFAULTS[key] = func.targetoptions["fastmath"]
49+
3550
if cuda.is_available():
3651
from .gpu_aamp import gpu_aamp # noqa: F401
3752
from .gpu_aamp_ostinato import gpu_aamp_ostinato # noqa: F401

stumpy/aamp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
@njit(
1414
# "(f8[:], f8[:], i8, b1[:], b1[:], f8, i8[:], i8, i8, i8, f8[:, :, :],"
1515
# "f8[:, :], f8[:, :], i8[:, :, :], i8[:, :], i8[:, :], b1)",
16-
fastmath=True,
16+
fastmath=config.STUMPY_FASTMATH_TRUE,
1717
)
1818
def _compute_diagonal(
1919
T_A,
@@ -186,7 +186,7 @@ def _compute_diagonal(
186186
@njit(
187187
# "(f8[:], f8[:], i8, b1[:], b1[:], i8[:], b1, i8)",
188188
parallel=True,
189-
fastmath=True,
189+
fastmath=config.STUMPY_FASTMATH_TRUE,
190190
)
191191
def _aamp(
192192
T_A,

stumpy/cache.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44

55
import ast
66
import importlib
7+
import inspect
78
import pathlib
89
import site
910
import warnings
1011

12+
import numba
13+
1114
CACHE_WARNING = "Caching `numba` functions is purely for experimental purposes "
1215
CACHE_WARNING += "and should never be used or depended upon as it is not supported! "
1316
CACHE_WARNING += "All caching capabilities are not tested and may be removed/changed "
@@ -74,7 +77,15 @@ def _enable():
7477
-------
7578
None
7679
"""
77-
warnings.warn(CACHE_WARNING)
80+
frame = inspect.currentframe()
81+
caller_name = inspect.getouterframes(frame)[1].function
82+
if caller_name != "_save":
83+
msg = (
84+
"The 'cache._enable()' function is deprecated and no longer supported. "
85+
+ "Please use 'cache.save()' instead"
86+
)
87+
warnings.warn(msg, DeprecationWarning, stacklevel=2)
88+
7889
njit_funcs = get_njit_funcs()
7990
for module_name, func_name in njit_funcs:
8091
module = importlib.import_module(f".{module_name}", package="stumpy")
@@ -94,12 +105,29 @@ def _clear():
94105
-------
95106
None
96107
"""
97-
warnings.warn(CACHE_WARNING)
98108
site_pkg_dir = site.getsitepackages()[0]
99109
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
100110
[f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
101111

102112

113+
def clear():
114+
"""
115+
Clear numba cache directory
116+
117+
Parameters
118+
----------
119+
None
120+
121+
Returns
122+
-------
123+
None
124+
"""
125+
warnings.warn(CACHE_WARNING)
126+
_clear()
127+
128+
return
129+
130+
103131
def _get_cache():
104132
"""
105133
Retrieve a list of cached numba functions
@@ -117,3 +145,69 @@ def _get_cache():
117145
site_pkg_dir = site.getsitepackages()[0]
118146
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
119147
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
148+
149+
150+
def _recompile():
151+
"""
152+
Recompile all njit functions
153+
154+
Parameters
155+
----------
156+
None
157+
158+
Returns
159+
-------
160+
None
161+
162+
Notes
163+
-----
164+
If the `numba` cache is enabled, this results in saving (and/or overwriting)
165+
the cached numba functions to disk.
166+
"""
167+
for module_name, func_name in get_njit_funcs():
168+
module = importlib.import_module(f".{module_name}", package="stumpy")
169+
func = getattr(module, func_name)
170+
func.recompile()
171+
172+
return
173+
174+
175+
def _save():
176+
"""
177+
Save all njit functions
178+
179+
Parameters
180+
----------
181+
None
182+
183+
Returns
184+
-------
185+
None
186+
"""
187+
_enable()
188+
_recompile()
189+
190+
return
191+
192+
193+
def save():
194+
"""
195+
Save/overwrite all the cache data files of
196+
all-so-far compiled njit functions.
197+
198+
Parameters
199+
----------
200+
None
201+
202+
Returns
203+
-------
204+
None
205+
"""
206+
if numba.config.DISABLE_JIT:
207+
msg = "Could not save/cache function because NUMBA JIT is disabled"
208+
warnings.warn(msg)
209+
else:
210+
warnings.warn(CACHE_WARNING)
211+
_save()
212+
213+
return

stumpy/config.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,72 @@
22
# Copyright 2019 TD Ameritrade. Released under the terms of the 3-Clause BSD license.
33
# STUMPY is a trademark of TD Ameritrade IP Company, Inc. All rights reserved.
44

5+
import warnings
6+
57
import numpy as np
68

7-
STUMPY_THREADS_PER_BLOCK = 512
8-
STUMPY_MEAN_STD_NUM_CHUNKS = 1
9-
STUMPY_MEAN_STD_MAX_ITER = 10
10-
STUMPY_DENOM_THRESHOLD = 1e-14
11-
STUMPY_STDDEV_THRESHOLD = 1e-7
12-
STUMPY_P_NORM_THRESHOLD = 1e-14
13-
STUMPY_TEST_PRECISION = 5
14-
STUMPY_MAX_P_NORM_DISTANCE = np.finfo(np.float64).max
15-
STUMPY_MAX_DISTANCE = np.sqrt(STUMPY_MAX_P_NORM_DISTANCE)
16-
STUMPY_EXCL_ZONE_DENOM = 4
9+
_STUMPY_DEFAULTS = {
10+
"STUMPY_THREADS_PER_BLOCK": 512,
11+
"STUMPY_MEAN_STD_NUM_CHUNKS": 1,
12+
"STUMPY_MEAN_STD_MAX_ITER": 10,
13+
"STUMPY_DENOM_THRESHOLD": 1e-14,
14+
"STUMPY_STDDEV_THRESHOLD": 1e-7,
15+
"STUMPY_P_NORM_THRESHOLD": 1e-14,
16+
"STUMPY_TEST_PRECISION": 5,
17+
"STUMPY_MAX_P_NORM_DISTANCE": np.finfo(np.float64).max,
18+
"STUMPY_MAX_DISTANCE": np.sqrt(np.finfo(np.float64).max),
19+
"STUMPY_EXCL_ZONE_DENOM": 4,
20+
"STUMPY_FASTMATH_TRUE": True,
21+
"STUMPY_FASTMATH_FLAGS": {"nsz", "arcp", "contract", "afn", "reassoc"},
22+
}
23+
24+
# In addition to these configuration variables, there exist config variables
25+
# that have the default value of the fastmath flag of the njit functions. The
26+
# name of this config variable has the following format:
27+
# STUMPY_FASTMATH_<module_name>.<function_name>
28+
# See __init__.py for more details
29+
30+
STUMPY_THREADS_PER_BLOCK = _STUMPY_DEFAULTS["STUMPY_THREADS_PER_BLOCK"]
31+
STUMPY_MEAN_STD_NUM_CHUNKS = _STUMPY_DEFAULTS["STUMPY_MEAN_STD_NUM_CHUNKS"]
32+
STUMPY_MEAN_STD_MAX_ITER = _STUMPY_DEFAULTS["STUMPY_MEAN_STD_MAX_ITER"]
33+
STUMPY_DENOM_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_DENOM_THRESHOLD"]
34+
STUMPY_STDDEV_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_STDDEV_THRESHOLD"]
35+
STUMPY_P_NORM_THRESHOLD = _STUMPY_DEFAULTS["STUMPY_P_NORM_THRESHOLD"]
36+
STUMPY_TEST_PRECISION = _STUMPY_DEFAULTS["STUMPY_TEST_PRECISION"]
37+
STUMPY_MAX_P_NORM_DISTANCE = _STUMPY_DEFAULTS["STUMPY_MAX_P_NORM_DISTANCE"]
38+
STUMPY_MAX_DISTANCE = _STUMPY_DEFAULTS["STUMPY_MAX_DISTANCE"]
39+
STUMPY_EXCL_ZONE_DENOM = _STUMPY_DEFAULTS["STUMPY_EXCL_ZONE_DENOM"]
40+
STUMPY_FASTMATH_TRUE = _STUMPY_DEFAULTS["STUMPY_FASTMATH_TRUE"]
41+
STUMPY_FASTMATH_FLAGS = _STUMPY_DEFAULTS["STUMPY_FASTMATH_FLAGS"]
42+
43+
44+
def _reset(var=None):
45+
"""
46+
Reset the value of a configuration variable(s) to their default value(s)
47+
48+
Parameters
49+
----------
50+
var : str, default None
51+
The name of the configuration variable. If None, then all
52+
configuration variables are reset to their default values.
53+
54+
Returns
55+
-------
56+
None
57+
"""
58+
config_vars = [
59+
k for k, _ in globals().items() if k.isupper() and k.startswith("STUMPY")
60+
]
61+
62+
if var is None:
63+
for config_var in config_vars:
64+
globals()[config_var] = _STUMPY_DEFAULTS[config_var]
65+
elif var in config_vars:
66+
globals()[var] = _STUMPY_DEFAULTS[var]
67+
else: # pragma: no cover
68+
msg = (
69+
f"Configuration reset was skipped for unrecognized '_STUMPY_DEFAULT[{var}]'"
70+
)
71+
warnings.warn(msg)
72+
73+
return

0 commit comments

Comments
 (0)