Skip to content

Commit f905d8c

Browse files
authored
Merge pull request #170 from asmeurer/numpy-rewrap
Re-enable wrapping unconditionally for NumPy
2 parents 158622e + 29afe3a commit f905d8c

File tree

3 files changed

+6
-12
lines changed

3 files changed

+6
-12
lines changed

.github/workflows/tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jobs:
66
strategy:
77
matrix:
88
python-version: ['3.9', '3.10', '3.11', '3.12']
9-
numpy-version: ['1.21', '1.26', 'dev']
9+
numpy-version: ['1.21', '1.26', '2.0', 'dev']
1010
exclude:
1111
- python-version: '3.11'
1212
numpy-version: '1.21'

array_api_compat/common/_helpers.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def is_ndonnx_array(x):
145145

146146
import ndonnx as ndx
147147

148-
return isinstance(x, ndx.Array)
148+
return isinstance(x, ndx.Array)
149149

150150
def is_dask_array(x):
151151
"""
@@ -340,12 +340,9 @@ def your_function(x, y):
340340
elif use_compat is False:
341341
namespaces.add(np)
342342
else:
343-
# numpy 2.0 has __array_namespace__ and is fully array API
343+
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
344344
# compatible.
345-
if hasattr(np.empty(0), '__array_namespace__'):
346-
namespaces.add(np.empty(0).__array_namespace__(api_version=api_version))
347-
else:
348-
namespaces.add(numpy_namespace)
345+
namespaces.add(numpy_namespace)
349346
elif is_cupy_array(x):
350347
if _use_compat:
351348
_check_api_version(api_version)

tests/_helpers.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33

44
import pytest
55

6-
wrapped_libraries = ["cupy", "torch", "dask.array"]
7-
all_libraries = wrapped_libraries + ["numpy", "jax.numpy", "sparse"]
8-
import numpy as np
9-
if np.__version__[0] == '1':
10-
wrapped_libraries.append("numpy")
6+
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
7+
all_libraries = wrapped_libraries + ["jax.numpy"]
118

129
# `sparse` added array API support as of Python 3.10.
1310
if sys.version_info >= (3, 10):

0 commit comments

Comments
 (0)