|
1 | 1 | import functools
|
| 2 | +import importlib |
2 | 3 | from unittest.mock import patch
|
3 | 4 |
|
4 | 5 | import naive
|
|
8 | 9 | from numba import cuda
|
9 | 10 |
|
10 | 11 | import stumpy
|
11 |
| -from stumpy import config, core |
| 12 | +from stumpy import cache, config, core |
12 | 13 |
|
13 | 14 | try:
|
14 | 15 | from numba.errors import NumbaPerformanceWarning
|
@@ -146,20 +147,70 @@ def test_snippets():
|
146 | 147 | cmp_regimes,
|
147 | 148 | ) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
|
148 | 149 |
|
149 |
| - npt.assert_almost_equal( |
150 |
| - ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION |
151 |
| - ) |
152 |
| - npt.assert_almost_equal( |
153 |
| - ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION |
| 150 | + # Revise fastmath flag, recompile, and re-calculate snippets, |
| 151 | + # and then revert the changes |
| 152 | + config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"} |
| 153 | + core._calculate_squared_distance.targetoptions["fastmath"] = ( |
| 154 | + config.STUMPY_FASTMATH_FLAGS |
154 | 155 | )
|
155 |
| - npt.assert_almost_equal( |
156 |
| - ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION |
157 |
| - ) |
158 |
| - npt.assert_almost_equal( |
159 |
| - ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION |
| 156 | + njit_funcs = cache.get_njit_funcs() |
| 157 | + for module_name, func_name in njit_funcs: |
| 158 | + module = importlib.import_module(f".{module_name}", package="stumpy") |
| 159 | + func = getattr(module, func_name) |
| 160 | + func.recompile() |
| 161 | + |
| 162 | + ( |
| 163 | + cmp_snippets_NOreassoc, |
| 164 | + cmp_indices_NOreassoc, |
| 165 | + cmp_profiles_NOreassoc, |
| 166 | + cmp_fractions_NOreassoc, |
| 167 | + cmp_areas_NOreassoc, |
| 168 | + cmp_regimes_NOreassoc, |
| 169 | + ) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func) |
| 170 | + |
| 171 | + config._reset("STUMPY_FASTMATH_FLAGS") |
| 172 | + core._calculate_squared_distance.targetoptions["fastmath"] = ( |
| 173 | + config.STUMPY_FASTMATH_FLAGS |
160 | 174 | )
|
161 |
| - npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION) |
162 |
| - npt.assert_almost_equal(ref_regimes, cmp_regimes) |
| 175 | + for module_name, func_name in njit_funcs: |
| 176 | + module = importlib.import_module(f".{module_name}", package="stumpy") |
| 177 | + func = getattr(module, func_name) |
| 178 | + func.recompile() |
| 179 | + |
| 180 | + if np.allclose(ref_snippets, cmp_snippets): |
| 181 | + npt.assert_almost_equal( |
| 182 | + ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION |
| 183 | + ) |
| 184 | + npt.assert_almost_equal( |
| 185 | + ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION |
| 186 | + ) |
| 187 | + npt.assert_almost_equal( |
| 188 | + ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION |
| 189 | + ) |
| 190 | + npt.assert_almost_equal( |
| 191 | + ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION |
| 192 | + ) |
| 193 | + npt.assert_almost_equal( |
| 194 | + ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION |
| 195 | + ) |
| 196 | + npt.assert_almost_equal(ref_regimes, cmp_regimes) |
| 197 | + else: |
| 198 | + npt.assert_almost_equal( |
| 199 | + ref_snippets, cmp_snippets_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
| 200 | + ) |
| 201 | + npt.assert_almost_equal( |
| 202 | + ref_indices, cmp_indices_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
| 203 | + ) |
| 204 | + npt.assert_almost_equal( |
| 205 | + ref_profiles, cmp_profiles_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
| 206 | + ) |
| 207 | + npt.assert_almost_equal( |
| 208 | + ref_fractions, cmp_fractions_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
| 209 | + ) |
| 210 | + npt.assert_almost_equal( |
| 211 | + ref_areas, cmp_areas_NOreassoc, decimal=config.STUMPY_TEST_PRECISION |
| 212 | + ) |
| 213 | + npt.assert_almost_equal(ref_regimes, cmp_regimes_NOreassoc) |
163 | 214 |
|
164 | 215 |
|
165 | 216 | @pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)
|
|
0 commit comments