Skip to content

Commit 976ec13

Browse files
committed
Add condition to avoid revising fastmath when JIT is disabled
1 parent a080495 commit 976ec13

File tree

1 file changed

+38
-31
lines changed

1 file changed

+38
-31
lines changed

tests/test_precision.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import patch
44

55
import naive
6+
import numba
67
import numpy as np
78
import numpy.testing as npt
89
import pytest
@@ -138,6 +139,7 @@ def test_snippets():
138139
) = naive.mpdist_snippets(
139140
T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func
140141
)
142+
141143
(
142144
cmp_snippets,
143145
cmp_indices,
@@ -147,37 +149,7 @@ def test_snippets():
147149
cmp_regimes,
148150
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
149151

150-
# Revise fastmath flag, recompile, and re-calculate snippets,
151-
# and then revert the changes
152-
config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}
153-
core._calculate_squared_distance.targetoptions["fastmath"] = (
154-
config.STUMPY_FASTMATH_FLAGS
155-
)
156-
njit_funcs = cache.get_njit_funcs()
157-
for module_name, func_name in njit_funcs:
158-
module = importlib.import_module(f".{module_name}", package="stumpy")
159-
func = getattr(module, func_name)
160-
func.recompile()
161-
162-
(
163-
cmp_snippets_NOreassoc,
164-
cmp_indices_NOreassoc,
165-
cmp_profiles_NOreassoc,
166-
cmp_fractions_NOreassoc,
167-
cmp_areas_NOreassoc,
168-
cmp_regimes_NOreassoc,
169-
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
170-
171-
config._reset("STUMPY_FASTMATH_FLAGS")
172-
core._calculate_squared_distance.targetoptions["fastmath"] = (
173-
config.STUMPY_FASTMATH_FLAGS
174-
)
175-
for module_name, func_name in njit_funcs:
176-
module = importlib.import_module(f".{module_name}", package="stumpy")
177-
func = getattr(module, func_name)
178-
func.recompile()
179-
180-
if np.allclose(ref_snippets, cmp_snippets):
152+
if np.allclose(ref_snippets, cmp_snippets) or numba.config.DISABLE_JIT:
181153
npt.assert_almost_equal(
182154
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
183155
)
@@ -195,6 +167,41 @@ def test_snippets():
195167
)
196168
npt.assert_almost_equal(ref_regimes, cmp_regimes)
197169
else:
170+
# Revise fastmath flag, recompile, and re-calculate snippets,
171+
# and then revert the changes
172+
173+
config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}
174+
core._calculate_squared_distance.targetoptions["fastmath"] = (
175+
config.STUMPY_FASTMATH_FLAGS
176+
)
177+
178+
njit_funcs = cache.get_njit_funcs()
179+
for module_name, func_name in njit_funcs:
180+
module = importlib.import_module(f".{module_name}", package="stumpy")
181+
func = getattr(module, func_name)
182+
func.recompile()
183+
184+
(
185+
cmp_snippets_NOreassoc,
186+
cmp_indices_NOreassoc,
187+
cmp_profiles_NOreassoc,
188+
cmp_fractions_NOreassoc,
189+
cmp_areas_NOreassoc,
190+
cmp_regimes_NOreassoc,
191+
) = stumpy.snippets(
192+
T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func
193+
)
194+
195+
config._reset("STUMPY_FASTMATH_FLAGS")
196+
197+
core._calculate_squared_distance.targetoptions["fastmath"] = (
198+
config.STUMPY_FASTMATH_FLAGS
199+
)
200+
for module_name, func_name in njit_funcs:
201+
module = importlib.import_module(f".{module_name}", package="stumpy")
202+
func = getattr(module, func_name)
203+
func.recompile()
204+
198205
npt.assert_almost_equal(
199206
ref_snippets, cmp_snippets_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
200207
)

0 commit comments

Comments
 (0)