Skip to content

Commit d43e117

Browse files
committed
BUG: fix array_namespace for NumPy scalars
1 parent ff87838 commit d43e117

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

Diff for: array_api_compat/common/_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def your_function(x, y):
342342
else:
343343
# numpy 2.0 has __array_namespace__ and is fully array API
344344
# compatible.
345-
if hasattr(x, '__array_namespace__'):
345+
if hasattr(np.empty(0), '__array_namespace__'):
346346
namespaces.add(x.__array_namespace__(api_version=api_version))
347347
else:
348348
namespaces.add(numpy_namespace)

Diff for: tests/test_array_namespace.py

+7
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ def test_array_namespace(library, api_version, use_compat):
3636
else:
3737
assert namespace == getattr(array_api_compat, library)
3838

39+
if library == "numpy":
40+
# check that the same library is returned for NumPy scalars
41+
scalar_namespace = array_api_compat.array_namespace(
42+
xp.float64(0.0), api_version=api_version, use_compat=use_compat
43+
)
44+
assert scalar_namespace == namespace
45+
3946
# Check that array_namespace works even if jax.experimental.array_api
4047
# hasn't been imported yet (it monkeypatches __array_namespace__
4148
# onto JAX arrays, but we should support them regardless). The only way to

0 commit comments

Comments
 (0)