Skip to content

add is_*_namespace helper functions #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 168 additions & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.
Expand Down Expand Up @@ -255,6 +254,166 @@ def is_array_api_obj(x):
or is_pydata_sparse_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name():
assert __name__.endswith('.common._helpers')
return __name__.removesuffix('.common._helpers')

def is_numpy_namespace(xp) -> bool:
"""
Returns True if `xp` is a NumPy namespace.

This includes both NumPy itself and the version wrapped by array-api-compat.

See Also
--------

array_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'numpy', _compat_module_name + '.numpy'}

def is_cupy_namespace(xp) -> bool:
"""
Returns True if `xp` is a CuPy namespace.

This includes both CuPy itself and the version wrapped by array-api-compat.

See Also
--------

array_namespace
is_numpy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'cupy', _compat_module_name + '.cupy'}

def is_torch_namespace(xp) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.

This includes both PyTorch itself and the version wrapped by array-api-compat.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'torch', _compat_module_name + '.torch'}


def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'ndonnx'

def is_dask_namespace(xp):
"""
Returns True if `xp` is a Dask namespace.

This includes both ``dask.array`` itself and the version wrapped by array-api-compat.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'dask.array', _compat_module_name + '.dask.array'}

def is_jax_namespace(xp):
"""
Returns True if `xp` is a JAX namespace.

This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
older versions of JAX.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}

def is_pydata_sparse_namespace(xp):
"""
Returns True if `xp` is a pydata/sparse namespace.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'sparse'

def is_array_api_strict_namespace(xp):
"""
Returns True if `xp` is an array-api-strict namespace.

See Also
--------

array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
"""
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):
if api_version == '2021.12':
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
Expand Down Expand Up @@ -643,13 +802,21 @@ def size(x):
"device",
"get_namespace",
"is_array_api_obj",
"is_array_api_strict_namespace",
"is_cupy_array",
"is_cupy_namespace",
"is_dask_array",
"is_dask_namespace",
"is_jax_array",
"is_jax_namespace",
"is_numpy_array",
"is_numpy_namespace",
"is_torch_array",
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
"to_device",
]
Expand Down
Loading