Skip to content

Commit 733d17c

Browse files
committed
add is_*_namespace helper functions
closes gh-156
1 parent f3145b0 commit 733d17c

File tree

1 file changed

+168
-1
lines changed

1 file changed

+168
-1
lines changed

array_api_compat/common/_helpers.py

+168-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ def is_jax_array(x):
202202

203203
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204204

205-
206205
def is_pydata_sparse_array(x) -> bool:
207206
"""
208207
Return True if `x` is an array from the `sparse` package.
@@ -255,6 +254,166 @@ def is_array_api_obj(x):
255254
or is_pydata_sparse_array(x) \
256255
or hasattr(x, '__array_namespace__')
257256

257+
def _compat_module_name():
258+
assert __name__.endswith('.common._helpers')
259+
return __name__.removesuffix('.common._helpers')
260+
261+
def is_numpy_namespace(xp) -> bool:
262+
"""
263+
Returns True if `xp` is a NumPy namespace.
264+
265+
This includes both NumPy itself and the version wrapped by array-api-compat.
266+
267+
See Also
268+
--------
269+
270+
array_namespace
271+
is_cupy_namespace
272+
is_torch_namespace
273+
is_ndonnx_namespace
274+
is_dask_namespace
275+
is_jax_namespace
276+
is_pydata_sparse_namespace
277+
is_array_api_strict_namespace
278+
"""
279+
return xp.__name__ in {'numpy', _compat_module_name + '.numpy'}
280+
281+
def is_cupy_namespace(xp) -> bool:
282+
"""
283+
Returns True if `xp` is a CuPy namespace.
284+
285+
This includes both CuPy itself and the version wrapped by array-api-compat.
286+
287+
See Also
288+
--------
289+
290+
array_namespace
291+
is_numpy_namespace
292+
is_torch_namespace
293+
is_ndonnx_namespace
294+
is_dask_namespace
295+
is_jax_namespace
296+
is_pydata_sparse_namespace
297+
is_array_api_strict_namespace
298+
"""
299+
return xp.__name__ in {'cupy', _compat_module_name + '.cupy'}
300+
301+
def is_torch_namespace(xp) -> bool:
302+
"""
303+
Returns True if `xp` is a PyTorch namespace.
304+
305+
This includes both PyTorch itself and the version wrapped by array-api-compat.
306+
307+
See Also
308+
--------
309+
310+
array_namespace
311+
is_numpy_namespace
312+
is_cupy_namespace
313+
is_ndonnx_namespace
314+
is_dask_namespace
315+
is_jax_namespace
316+
is_pydata_sparse_namespace
317+
is_array_api_strict_namespace
318+
"""
319+
return xp.__name__ in {'torch', _compat_module_name + '.torch'}
320+
321+
322+
def is_ndonnx_namespace(xp):
323+
"""
324+
Returns True if `xp` is an NDONNX namespace.
325+
326+
See Also
327+
--------
328+
329+
array_namespace
330+
is_numpy_namespace
331+
is_cupy_namespace
332+
is_torch_namespace
333+
is_dask_namespace
334+
is_jax_namespace
335+
is_pydata_sparse_namespace
336+
is_array_api_strict_namespace
337+
"""
338+
return xp.__name__ == 'ndonnx'
339+
340+
def is_dask_namespace(xp):
341+
"""
342+
Returns True if `xp` is a Dask namespace.
343+
344+
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
345+
346+
See Also
347+
--------
348+
349+
array_namespace
350+
is_numpy_namespace
351+
is_cupy_namespace
352+
is_torch_namespace
353+
is_ndonnx_namespace
354+
is_jax_namespace
355+
is_pydata_sparse_namespace
356+
is_array_api_strict_namespace
357+
"""
358+
return xp.__name__ in {'dask.array', _compat_module_name + '.dask.array'}
359+
360+
def is_jax_namespace(xp):
361+
"""
362+
Returns True if `xp` is a JAX namespace.
363+
364+
This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
365+
older versions of JAX.
366+
367+
See Also
368+
--------
369+
370+
array_namespace
371+
is_numpy_namespace
372+
is_cupy_namespace
373+
is_torch_namespace
374+
is_ndonnx_namespace
375+
is_dask_namespace
376+
is_pydata_sparse_namespace
377+
is_array_api_strict_namespace
378+
"""
379+
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
380+
381+
def is_pydata_sparse_namespace(xp):
382+
"""
383+
Returns True if `xp` is a pydata/sparse namespace.
384+
385+
See Also
386+
--------
387+
388+
array_namespace
389+
is_numpy_namespace
390+
is_cupy_namespace
391+
is_torch_namespace
392+
is_ndonnx_namespace
393+
is_dask_namespace
394+
is_jax_namespace
395+
is_array_api_strict_namespace
396+
"""
397+
return xp.__name__ == 'sparse'
398+
399+
def is_array_api_strict_namespace(xp):
400+
"""
401+
Returns True if `xp` is an array-api-strict namespace.
402+
403+
See Also
404+
--------
405+
406+
array_namespace
407+
is_numpy_namespace
408+
is_cupy_namespace
409+
is_torch_namespace
410+
is_ndonnx_namespace
411+
is_dask_namespace
412+
is_jax_namespace
413+
is_pydata_sparse_namespace
414+
"""
415+
return xp.__name__ == 'array_api_strict'
416+
258417
def _check_api_version(api_version):
259418
if api_version == '2021.12':
260419
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
@@ -643,13 +802,21 @@ def size(x):
643802
"device",
644803
"get_namespace",
645804
"is_array_api_obj",
805+
"is_array_api_strict_namespace",
646806
"is_cupy_array",
807+
"is_cupy_namespace",
647808
"is_dask_array",
809+
"is_dask_namespace",
648810
"is_jax_array",
811+
"is_jax_namespace",
649812
"is_numpy_array",
813+
"is_numpy_namespace",
650814
"is_torch_array",
815+
"is_torch_namespace",
651816
"is_ndonnx_array",
817+
"is_ndonnx_namespace",
652818
"is_pydata_sparse_array",
819+
"is_pydata_sparse_namespace",
653820
"size",
654821
"to_device",
655822
]

0 commit comments

Comments
 (0)