@@ -202,7 +202,6 @@ def is_jax_array(x):
202
202
203
203
return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204
204
205
-
206
205
def is_pydata_sparse_array (x ) -> bool :
207
206
"""
208
207
Return True if `x` is an array from the `sparse` package.
@@ -255,6 +254,161 @@ def is_array_api_obj(x):
255
254
or is_pydata_sparse_array (x ) \
256
255
or hasattr (x , '__array_namespace__' )
257
256
257
+ def is_numpy_namespace (xp ) -> bool :
258
+ """
259
+ Returns True if `xp` is a NumPy namespace.
260
+
261
+ This includes both NumPy itself and the version wrapped by array-api-compat.
262
+
263
+ See Also
264
+ --------
265
+
266
+ array_namespace
267
+ is_cupy_namespace
268
+ is_torch_namespace
269
+ is_ndonnx_namespace
270
+ is_dask_namespace
271
+ is_jax_namespace
272
+ is_pydata_sparse_namespace
273
+ is_array_api_strict_namespace
274
+ """
275
+ return xp .__name__ == 'numpy' or 'array_api_compat.numpy' in xp .__name__
276
+
277
+ def is_cupy_namespace (xp ) -> bool :
278
+ """
279
+ Returns True if `xp` is a CuPy namespace.
280
+
281
+ This includes both CuPy itself and the version wrapped by array-api-compat.
282
+
283
+ See Also
284
+ --------
285
+
286
+ array_namespace
287
+ is_numpy_namespace
288
+ is_torch_namespace
289
+ is_ndonnx_namespace
290
+ is_dask_namespace
291
+ is_jax_namespace
292
+ is_pydata_sparse_namespace
293
+ is_array_api_strict_namespace
294
+ """
295
+ return xp .__name__ == 'cupy' or 'array_api_compat.cupy' in xp .__name__
296
+
297
+ def is_torch_namespace (xp ) -> bool :
298
+ """
299
+ Returns True if `xp` is a PyTorch namespace.
300
+
301
+ This includes both PyTorch itself and the version wrapped by array-api-compat.
302
+
303
+ See Also
304
+ --------
305
+
306
+ array_namespace
307
+ is_numpy_namespace
308
+ is_cupy_namespace
309
+ is_ndonnx_namespace
310
+ is_dask_namespace
311
+ is_jax_namespace
312
+ is_pydata_sparse_namespace
313
+ is_array_api_strict_namespace
314
+ """
315
+ return xp .__name__ == 'torch' or 'array_api_compat.torch' in xp .__name__
316
+
317
+ def is_ndonnx_namespace (xp ):
318
+ """
319
+ Returns True if `xp` is an NDONNX namespace.
320
+
321
+ See Also
322
+ --------
323
+
324
+ array_namespace
325
+ is_numpy_namespace
326
+ is_cupy_namespace
327
+ is_torch_namespace
328
+ is_dask_namespace
329
+ is_jax_namespace
330
+ is_pydata_sparse_namespace
331
+ is_array_api_strict_namespace
332
+ """
333
+ return xp .__name__ == 'ndonnx'
334
+
335
+ def is_dask_namespace (xp ):
336
+ """
337
+ Returns True if `xp` is a Dask namespace.
338
+
339
+ This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
340
+
341
+ See Also
342
+ --------
343
+
344
+ array_namespace
345
+ is_numpy_namespace
346
+ is_cupy_namespace
347
+ is_torch_namespace
348
+ is_ndonnx_namespace
349
+ is_jax_namespace
350
+ is_pydata_sparse_namespace
351
+ is_array_api_strict_namespace
352
+ """
353
+ return xp .__name__ == 'dask.array' or 'array_api_compat.dask.array' in xp .__name__
354
+
355
+ def is_jax_namespace (xp ):
356
+ """
357
+ Returns True if `xp` is a JAX namespace.
358
+
359
+ This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
360
+ older versions of JAX.
361
+
362
+ See Also
363
+ --------
364
+
365
+ array_namespace
366
+ is_numpy_namespace
367
+ is_cupy_namespace
368
+ is_torch_namespace
369
+ is_ndonnx_namespace
370
+ is_dask_namespace
371
+ is_pydata_sparse_namespace
372
+ is_array_api_strict_namespace
373
+ """
374
+ return xp .__name__ in ('jax.numpy' , 'jax.experimental.array_api' )
375
+
376
+ def is_pydata_sparse_namespace (xp ):
377
+ """
378
+ Returns True if `xp` is a pydata/sparse namespace.
379
+
380
+ See Also
381
+ --------
382
+
383
+ array_namespace
384
+ is_numpy_namespace
385
+ is_cupy_namespace
386
+ is_torch_namespace
387
+ is_ndonnx_namespace
388
+ is_dask_namespace
389
+ is_jax_namespace
390
+ is_array_api_strict_namespace
391
+ """
392
+ return xp .__name__ == 'sparse'
393
+
394
+ def is_array_api_strict_namespace (xp ):
395
+ """
396
+ Returns True if `xp` is an array-api-strict namespace.
397
+
398
+ See Also
399
+ --------
400
+
401
+ array_namespace
402
+ is_numpy_namespace
403
+ is_cupy_namespace
404
+ is_torch_namespace
405
+ is_ndonnx_namespace
406
+ is_dask_namespace
407
+ is_jax_namespace
408
+ is_pydata_sparse_namespace
409
+ """
410
+ return xp .__name__ == 'array_api_strict'
411
+
258
412
def _check_api_version (api_version ):
259
413
if api_version == '2021.12' :
260
414
warnings .warn ("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12" )
0 commit comments