3
3
from unittest .mock import patch
4
4
5
5
import naive
6
+ import numba
6
7
import numpy as np
7
8
import numpy .testing as npt
8
9
import pytest
@@ -138,6 +139,7 @@ def test_snippets():
138
139
) = naive .mpdist_snippets (
139
140
T , m , k , s = s , mpdist_T_subseq_isconstant = isconstant_custom_func
140
141
)
142
+
141
143
(
142
144
cmp_snippets ,
143
145
cmp_indices ,
@@ -147,37 +149,7 @@ def test_snippets():
147
149
cmp_regimes ,
148
150
) = stumpy .snippets (T , m , k , s = s , mpdist_T_subseq_isconstant = isconstant_custom_func )
149
151
150
- # Revise fastmath flag, recompile, and re-calculate snippets,
151
- # and then revert the changes
152
- config .STUMPY_FASTMATH_FLAGS = {"nsz" , "arcp" , "contract" , "afn" }
153
- core ._calculate_squared_distance .targetoptions ["fastmath" ] = (
154
- config .STUMPY_FASTMATH_FLAGS
155
- )
156
- njit_funcs = cache .get_njit_funcs ()
157
- for module_name , func_name in njit_funcs :
158
- module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
159
- func = getattr (module , func_name )
160
- func .recompile ()
161
-
162
- (
163
- cmp_snippets_NOreassoc ,
164
- cmp_indices_NOreassoc ,
165
- cmp_profiles_NOreassoc ,
166
- cmp_fractions_NOreassoc ,
167
- cmp_areas_NOreassoc ,
168
- cmp_regimes_NOreassoc ,
169
- ) = stumpy .snippets (T , m , k , s = s , mpdist_T_subseq_isconstant = isconstant_custom_func )
170
-
171
- config ._reset ("STUMPY_FASTMATH_FLAGS" )
172
- core ._calculate_squared_distance .targetoptions ["fastmath" ] = (
173
- config .STUMPY_FASTMATH_FLAGS
174
- )
175
- for module_name , func_name in njit_funcs :
176
- module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
177
- func = getattr (module , func_name )
178
- func .recompile ()
179
-
180
- if np .allclose (ref_snippets , cmp_snippets ):
152
+ if np .allclose (ref_snippets , cmp_snippets ) or numba .config .DISABLE_JIT :
181
153
npt .assert_almost_equal (
182
154
ref_snippets , cmp_snippets , decimal = config .STUMPY_TEST_PRECISION
183
155
)
@@ -195,6 +167,41 @@ def test_snippets():
195
167
)
196
168
npt .assert_almost_equal (ref_regimes , cmp_regimes )
197
169
else :
170
+ # Revise fastmath flag, recompile, and re-calculate snippets,
171
+ # and then revert the changes
172
+
173
+ config .STUMPY_FASTMATH_FLAGS = {"nsz" , "arcp" , "contract" , "afn" }
174
+ core ._calculate_squared_distance .targetoptions ["fastmath" ] = (
175
+ config .STUMPY_FASTMATH_FLAGS
176
+ )
177
+
178
+ njit_funcs = cache .get_njit_funcs ()
179
+ for module_name , func_name in njit_funcs :
180
+ module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
181
+ func = getattr (module , func_name )
182
+ func .recompile ()
183
+
184
+ (
185
+ cmp_snippets_NOreassoc ,
186
+ cmp_indices_NOreassoc ,
187
+ cmp_profiles_NOreassoc ,
188
+ cmp_fractions_NOreassoc ,
189
+ cmp_areas_NOreassoc ,
190
+ cmp_regimes_NOreassoc ,
191
+ ) = stumpy .snippets (
192
+ T , m , k , s = s , mpdist_T_subseq_isconstant = isconstant_custom_func
193
+ )
194
+
195
+ config ._reset ("STUMPY_FASTMATH_FLAGS" )
196
+
197
+ core ._calculate_squared_distance .targetoptions ["fastmath" ] = (
198
+ config .STUMPY_FASTMATH_FLAGS
199
+ )
200
+ for module_name , func_name in njit_funcs :
201
+ module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
202
+ func = getattr (module , func_name )
203
+ func .recompile ()
204
+
198
205
npt .assert_almost_equal (
199
206
ref_snippets , cmp_snippets_NOreassoc , decimal = config .STUMPY_TEST_PRECISION
200
207
)
0 commit comments