Skip to content

Commit 80544c7

Browse files
committed
Add second attempt for assertion
1 parent 2774302 commit 80544c7

File tree

1 file changed

+64
-13
lines changed

1 file changed

+64
-13
lines changed

tests/test_precision.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import importlib
23
from unittest.mock import patch
34

45
import naive
@@ -8,7 +9,7 @@
89
from numba import cuda
910

1011
import stumpy
11-
from stumpy import config, core
12+
from stumpy import cache, config, core
1213

1314
try:
1415
from numba.errors import NumbaPerformanceWarning
@@ -146,20 +147,70 @@ def test_snippets():
146147
cmp_regimes,
147148
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
148149

149-
npt.assert_almost_equal(
150-
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
151-
)
152-
npt.assert_almost_equal(
153-
ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION
150+
# Revise fastmath flag, recompile, and re-calculate snippets,
151+
# and then revert the changes
152+
config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}
153+
core._calculate_squared_distance.targetoptions["fastmath"] = (
154+
config.STUMPY_FASTMATH_FLAGS
154155
)
155-
npt.assert_almost_equal(
156-
ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION
157-
)
158-
npt.assert_almost_equal(
159-
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
156+
njit_funcs = cache.get_njit_funcs()
157+
for module_name, func_name in njit_funcs:
158+
module = importlib.import_module(f".{module_name}", package="stumpy")
159+
func = getattr(module, func_name)
160+
func.recompile()
161+
162+
(
163+
cmp_snippets_NOreassoc,
164+
cmp_indices_NOreassoc,
165+
cmp_profiles_NOreassoc,
166+
cmp_fractions_NOreassoc,
167+
cmp_areas_NOreassoc,
168+
cmp_regimes_NOreassoc,
169+
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
170+
171+
config._reset("STUMPY_FASTMATH_FLAGS")
172+
core._calculate_squared_distance.targetoptions["fastmath"] = (
173+
config.STUMPY_FASTMATH_FLAGS
160174
)
161-
npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION)
162-
npt.assert_almost_equal(ref_regimes, cmp_regimes)
175+
for module_name, func_name in njit_funcs:
176+
module = importlib.import_module(f".{module_name}", package="stumpy")
177+
func = getattr(module, func_name)
178+
func.recompile()
179+
180+
if np.allclose(ref_snippets, cmp_snippets):
181+
npt.assert_almost_equal(
182+
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
183+
)
184+
npt.assert_almost_equal(
185+
ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION
186+
)
187+
npt.assert_almost_equal(
188+
ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION
189+
)
190+
npt.assert_almost_equal(
191+
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
192+
)
193+
npt.assert_almost_equal(
194+
ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION
195+
)
196+
npt.assert_almost_equal(ref_regimes, cmp_regimes)
197+
else:
198+
npt.assert_almost_equal(
199+
ref_snippets, cmp_snippets_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
200+
)
201+
npt.assert_almost_equal(
202+
ref_indices, cmp_indices_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
203+
)
204+
npt.assert_almost_equal(
205+
ref_profiles, cmp_profiles_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
206+
)
207+
npt.assert_almost_equal(
208+
ref_fractions, cmp_fractions_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
209+
)
210+
npt.assert_almost_equal(
211+
ref_areas, cmp_areas_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
212+
)
213+
npt.assert_almost_equal(ref_regimes, cmp_regimes_NOreassoc)
163214

164215

165216
@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)

0 commit comments

Comments
 (0)