diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 32fb0e70..bafe991a 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -235,7 +235,8 @@ def array_namespace(*xs, api_version=None, use_compat=None): Parameters ---------- xs: arrays - one or more arrays. + one or more arrays. xs can also be Python scalars (bool, int, float, + complex, or None), which are ignored. api_version: str The newest version of the spec that you need support for (currently @@ -298,7 +299,9 @@ def your_function(x, y): namespaces = set() for x in xs: - if is_numpy_array(x): + if isinstance(x, (bool, int, float, complex, type(None))): + continue + elif is_numpy_array(x): from .. import numpy as numpy_namespace import numpy as np if use_compat is True: diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 1f83a473..af0ac244 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -92,3 +92,19 @@ def test_api_version(): def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_api_compat.array_namespace + +def test_python_scalars(): + a = torch.asarray([1, 2]) + xp = import_("torch", wrapper=True) + + pytest.raises(TypeError, lambda: array_namespace(1)) + pytest.raises(TypeError, lambda: array_namespace(1.0)) + pytest.raises(TypeError, lambda: array_namespace(1j)) + pytest.raises(TypeError, lambda: array_namespace(True)) + pytest.raises(TypeError, lambda: array_namespace(None)) + + assert array_namespace(a, 1) == xp + assert array_namespace(a, 1.0) == xp + assert array_namespace(a, 1j) == xp + assert array_namespace(a, True) == xp + assert array_namespace(a, None) == xp