Skip to content

Commit dd50583

Browse files
authored
Merge pull request #221 from larsoner/np2
MAINT: Test against latest NumPy
2 parents 696b834 + 3e65a12 commit dd50583

File tree

8 files changed

+55
-26
lines changed

8 files changed

+55
-26
lines changed

.github/workflows/test.yml

+6
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ jobs:
1919
test:
2020

2121
runs-on: ubuntu-latest
22+
continue-on-error: true
2223
strategy:
2324
matrix:
25+
# We test NumPy dev on 3.11
2426
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
2527
requires: ['requirements.txt']
2628
include:
@@ -37,9 +39,13 @@ jobs:
3739
allow-prereleases: true
3840
- name: Install
3941
run: |
42+
set -eo pipefail
4043
python -m pip install --upgrade pip
4144
python -m pip install -r ${{ matrix.requires }}
4245
python -m pip install -r requirements-dev.txt
46+
if [[ "${{ matrix.python-version }}" == "3.11" ]]; then
47+
python -m pip install --only-binary numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple "numpy>=2.1.0.dev0"
48+
fi
4349
python -m pip install .
4450
- name: Lint
4551
run: |

nitime/algorithms/event_related.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
def fir(timeseries, design):
13-
"""
13+
r"""
1414
Calculate the FIR (finite impulse response) HRF, according to [Burock2000]_
1515
1616
Parameters

nitime/algorithms/tests/test_spectral.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,7 @@ def test_mtm_lin_combo():
246246
mtm_cross = tsa.mtm_cross_spectrum(
247247
spec1, spec2, (weights[0], weights[1]), sides=sides
248248
)
249-
npt.assert_(mtm_cross.dtype in np.sctypes['complex'],
250-
'Wrong dtype for crossspectrum')
249+
assert mtm_cross.dtype == np.complex128, 'Wrong dtype for crossspectrum'
251250
npt.assert_(len(mtm_cross) == 51,
252251
'Wrong length for halfband spectrum')
253252
sides = 'twosided'
@@ -260,8 +259,7 @@ def test_mtm_lin_combo():
260259
mtm_auto = tsa.mtm_cross_spectrum(
261260
spec1, spec1, weights[0], sides=sides
262261
)
263-
npt.assert_(mtm_auto.dtype in np.sctypes['float'],
264-
'Wrong dtype for autospectrum')
262+
assert mtm_auto.dtype == np.float64, 'Wrong dtype for autospectrum'
265263
npt.assert_(len(mtm_auto) == 51,
266264
'Wrong length for halfband spectrum')
267265
sides = 'twosided'

nitime/index_utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
'tril_indices_from', 'triu_indices', 'triu_indices_from',
77
]
88

9-
from numpy.core.numeric import asanyarray, subtract, arange, \
9+
from numpy import asanyarray, subtract, arange, \
1010
greater_equal, multiply, ones, asarray, where
1111

12-
# Need to import numpy for the doctests!
13-
import numpy as np
12+
# Need to import numpy for the doctests!
13+
import numpy as np
1414

1515
def tri(N, M=None, k=0, dtype=float):
1616
"""

nitime/tests/test_timeseries.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -916,16 +916,20 @@ def test_index_int64():
916916
assert repr(b[0]) == repr(b[np.int32(0)])
917917

918918

919-
def test_timearray_math_functions():
919+
@pytest.mark.parametrize('f', ['min', 'max', 'mean', 'ptp', 'sum'])
920+
@pytest.mark.parametrize('tu', ['s', 'ms', 'ps', 'D'])
921+
def test_timearray_math_functions(f, tu):
920922
"Calling TimeArray.min() .max(), mean() should return TimeArrays"
921923
a = np.arange(2, 11)
922-
for f in ['min', 'max', 'mean', 'ptp', 'sum']:
923-
for tu in ['s', 'ms', 'ps', 'D']:
924-
b = ts.TimeArray(a, time_unit=tu)
925-
npt.assert_(getattr(b, f)().__class__ == ts.TimeArray)
926-
npt.assert_(getattr(b, f)().time_unit == b.time_unit)
927-
# comparison with unitless should convert to the TimeArray's units
928-
npt.assert_(getattr(b, f)() == getattr(a, f)())
924+
b = ts.TimeArray(a, time_unit=tu)
925+
if f == "ptp" and ts._NP_2:
926+
want = np.ptp(a)
927+
else:
928+
want = getattr(a, f)()
929+
npt.assert_(getattr(b, f)().__class__ == ts.TimeArray)
930+
npt.assert_(getattr(b, f)().time_unit == b.time_unit)
931+
# comparison with unitless should convert to the TimeArray's units
932+
npt.assert_(getattr(b, f)() == want)
929933

930934

931935
def test_timearray_var_prod():

nitime/tests/test_utils.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def test_detect_lines():
230230
"""
231231
Tests detect_lines utility in the reliable low-SNR scenario.
232232
"""
233+
np.random.seed(0)
234+
233235
N = 1000
234236
fft_pow = int( np.ceil(np.log2(N) + 2) )
235237
NW = 4
@@ -286,19 +288,21 @@ def test_detect_lines_2dmode():
286288
Test multi-sequence operation
287289
"""
288290

291+
# This seed affects not just the signal we generate below, but then also
292+
# detect_lines->dpss_windows->tridi_inverse_iteration
293+
np.random.seed(0)
294+
289295
N = 1000
290296

291297
sig = np.cos( 2*np.pi*np.arange(N) * 20./N ) + np.random.randn(N) * .01
292298

293-
sig2d = np.row_stack( (sig, sig, sig) )
299+
sig2d = np.vstack( (sig, sig, sig) )
294300

295301
lines = utils.detect_lines(sig2d, (4, 8), low_bias=True, NFFT=2**12)
296302

297303
npt.assert_(len(lines)==3, 'Detect lines failed multi-sequence mode')
298304

299-
consistent1 = (lines[0][0] == lines[1][0]).all() and \
300-
(lines[1][0] == lines[2][0]).all()
301-
consistent2 = (lines[0][1] == lines[1][1]).all() and \
302-
(lines[1][1] == lines[2][1]).all()
303-
304-
npt.assert_(consistent1 and consistent2, 'Inconsistent results')
305+
npt.assert_allclose(lines[0][0], lines[1][0])
306+
npt.assert_allclose(lines[0][0], lines[2][0])
307+
npt.assert_allclose(lines[0][1], lines[1][1])
308+
npt.assert_allclose(lines[0][1], lines[2][1])

nitime/timeseries.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@
3333
# Our own
3434
from nitime import descriptors as desc
3535

36+
try:
37+
_NP_2 = int(np.__version__.split(".")[0]) >= 2
38+
except Exception:
39+
_NP_2 = True
40+
3641
#-----------------------------------------------------------------------------
3742
# Module globals
3843
#-----------------------------------------------------------------------------
@@ -112,7 +117,9 @@ def __new__(cls, data, time_unit=None, copy=True):
112117
which are SI units of time. Default: 's'
113118
114119
copy : bool, optional
115-
Whether to create this instance by copy of a
120+
Whether to create this instance by copy of a. If False,
121+
a copy will not be forced but might still be required depending
122+
on the data array.
116123
117124
Note
118125
----
@@ -152,7 +159,7 @@ class instance, or an int64 array in the base unit of the module
152159
e_s += 'TimeArray in object, or int64 times, in %s' % base_unit
153160
raise ValueError(e_s)
154161

155-
time = np.array(data, copy=False)
162+
time = np.asarray(data)
156163
else:
157164
if isinstance(data, TimeInterface):
158165
time = data.copy()
@@ -309,7 +316,11 @@ def mean(self, *args, **kwargs):
309316
return ret
310317

311318
def ptp(self, *args, **kwargs):
312-
ret = TimeArray(np.ndarray.ptp(self, *args, **kwargs),
319+
if _NP_2:
320+
ptp = np.ptp
321+
else:
322+
ptp = np.ndarray.ptp
323+
ret = TimeArray(ptp(self, *args, **kwargs),
313324
time_unit=base_unit)
314325
ret.convert_unit(self.time_unit)
315326
return ret

pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ skip = "pp* cp38-*_aarch64 cp38-musllinux_*"
6262
# don't bother unless someone asks
6363
archs = ["native"]
6464

65+
test-requires = [
66+
"pytest",
67+
"nitime[full]", # Enable all optional behavior
68+
]
69+
test-command = "pytest -rsx --pyargs nitime"
70+
6571
[tool.cibuildwheel.linux]
6672
archs = ["x86_64", "aarch64"]
6773

0 commit comments

Comments
 (0)