Skip to content

Commit 67f3c9f

Browse files
committed
add is_*_namespace helper functions
closes gh-156
1 parent f3145b0 commit 67f3c9f

File tree

1 file changed

+163
-1
lines changed

1 file changed

+163
-1
lines changed

array_api_compat/common/_helpers.py

+163-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ def is_jax_array(x):
202202

203203
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204204

205-
206205
def is_pydata_sparse_array(x) -> bool:
207206
"""
208207
Return True if `x` is an array from the `sparse` package.
@@ -255,6 +254,161 @@ def is_array_api_obj(x):
255254
or is_pydata_sparse_array(x) \
256255
or hasattr(x, '__array_namespace__')
257256

257+
def is_numpy_namespace(xp) -> bool:
258+
"""
259+
Returns True if `xp` is a NumPy namespace.
260+
261+
This includes both NumPy itself and the version wrapped by array-api-compat.
262+
263+
See Also
264+
--------
265+
266+
array_namespace
267+
is_cupy_namespace
268+
is_torch_namespace
269+
is_ndonnx_namespace
270+
is_dask_namespace
271+
is_jax_namespace
272+
is_pydata_sparse_namespace
273+
is_array_api_strict_namespace
274+
"""
275+
return xp.__name__ == 'numpy' or 'array_api_compat.numpy' in xp.__name__
276+
277+
def is_cupy_namespace(xp) -> bool:
278+
"""
279+
Returns True if `xp` is a CuPy namespace.
280+
281+
This includes both CuPy itself and the version wrapped by array-api-compat.
282+
283+
See Also
284+
--------
285+
286+
array_namespace
287+
is_numpy_namespace
288+
is_torch_namespace
289+
is_ndonnx_namespace
290+
is_dask_namespace
291+
is_jax_namespace
292+
is_pydata_sparse_namespace
293+
is_array_api_strict_namespace
294+
"""
295+
return xp.__name__ == 'cupy' or 'array_api_compat.cupy' in xp.__name__
296+
297+
def is_torch_namespace(xp) -> bool:
298+
"""
299+
Returns True if `xp` is a PyTorch namespace.
300+
301+
This includes both PyTorch itself and the version wrapped by array-api-compat.
302+
303+
See Also
304+
--------
305+
306+
array_namespace
307+
is_numpy_namespace
308+
is_cupy_namespace
309+
is_ndonnx_namespace
310+
is_dask_namespace
311+
is_jax_namespace
312+
is_pydata_sparse_namespace
313+
is_array_api_strict_namespace
314+
"""
315+
return xp.__name__ == 'torch' or 'array_api_compat.torch' in xp.__name__
316+
317+
def is_ndonnx_namespace(xp):
318+
"""
319+
Returns True if `xp` is an NDONNX namespace.
320+
321+
See Also
322+
--------
323+
324+
array_namespace
325+
is_numpy_namespace
326+
is_cupy_namespace
327+
is_torch_namespace
328+
is_dask_namespace
329+
is_jax_namespace
330+
is_pydata_sparse_namespace
331+
is_array_api_strict_namespace
332+
"""
333+
return xp.__name__ == 'ndonnx'
334+
335+
def is_dask_namespace(xp):
336+
"""
337+
Returns True if `xp` is a Dask namespace.
338+
339+
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
340+
341+
See Also
342+
--------
343+
344+
array_namespace
345+
is_numpy_namespace
346+
is_cupy_namespace
347+
is_torch_namespace
348+
is_ndonnx_namespace
349+
is_jax_namespace
350+
is_pydata_sparse_namespace
351+
is_array_api_strict_namespace
352+
"""
353+
return xp.__name__ == 'dask.array' or 'array_api_compat.dask.array' in xp.__name__
354+
355+
def is_jax_namespace(xp):
356+
"""
357+
Returns True if `xp` is a JAX namespace.
358+
359+
This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
360+
older versions of JAX.
361+
362+
See Also
363+
--------
364+
365+
array_namespace
366+
is_numpy_namespace
367+
is_cupy_namespace
368+
is_torch_namespace
369+
is_ndonnx_namespace
370+
is_dask_namespace
371+
is_pydata_sparse_namespace
372+
is_array_api_strict_namespace
373+
"""
374+
return xp.__name__ in ('jax.numpy', 'jax.experimental.array_api')
375+
376+
def is_pydata_sparse_namespace(xp):
377+
"""
378+
Returns True if `xp` is a pydata/sparse namespace.
379+
380+
See Also
381+
--------
382+
383+
array_namespace
384+
is_numpy_namespace
385+
is_cupy_namespace
386+
is_torch_namespace
387+
is_ndonnx_namespace
388+
is_dask_namespace
389+
is_jax_namespace
390+
is_array_api_strict_namespace
391+
"""
392+
return xp.__name__ == 'sparse'
393+
394+
def is_array_api_strict_namespace(xp):
395+
"""
396+
Returns True if `xp` is an array-api-strict namespace.
397+
398+
See Also
399+
--------
400+
401+
array_namespace
402+
is_numpy_namespace
403+
is_cupy_namespace
404+
is_torch_namespace
405+
is_ndonnx_namespace
406+
is_dask_namespace
407+
is_jax_namespace
408+
is_pydata_sparse_namespace
409+
"""
410+
return xp.__name__ == 'array_api_strict'
411+
258412
def _check_api_version(api_version):
259413
if api_version == '2021.12':
260414
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
@@ -643,13 +797,21 @@ def size(x):
643797
"device",
644798
"get_namespace",
645799
"is_array_api_obj",
800+
"is_array_api_strict_namespace",
646801
"is_cupy_array",
802+
"is_cupy_namespace",
647803
"is_dask_array",
804+
"is_dask_namespace",
648805
"is_jax_array",
806+
"is_jax_namespace",
649807
"is_numpy_array",
808+
"is_numpy_namespace",
650809
"is_torch_array",
810+
"is_torch_namespace",
651811
"is_ndonnx_array",
812+
"is_ndonnx_namespace",
652813
"is_pydata_sparse_array",
814+
"is_pydata_sparse_namespace",
653815
"size",
654816
"to_device",
655817
]

0 commit comments

Comments
 (0)