Skip to content

Commit 2d30901

Browse files
committed
Don't bypass numpy compat just because it has __array_namespace__
data-apis#77 (comment)
1 parent d235910 commit 2d30901

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

array_api_compat/common/_helpers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def your_function(x, y):
7171
"""
7272
namespaces = set()
7373
for x in xs:
74-
if hasattr(x, '__array_namespace__'):
75-
namespaces.add(x.__array_namespace__(api_version=api_version))
76-
elif _is_numpy_array(x):
74+
if _is_numpy_array(x):
7775
_check_api_version(api_version)
7876
if _use_compat:
7977
from .. import numpy as numpy_namespace
@@ -97,6 +95,8 @@ def your_function(x, y):
9795
else:
9896
import torch
9997
namespaces.add(torch)
98+
elif hasattr(x, '__array_namespace__'):
99+
namespaces.add(x.__array_namespace__(api_version=api_version))
100100
else:
101101
# TODO: Support Python scalars?
102102
raise TypeError(f"{type(x).__name__} is not a supported array type")

0 commit comments

Comments
 (0)