@@ -202,7 +202,6 @@ def is_jax_array(x):
202
202
203
203
return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204
204
205
-
206
205
def is_pydata_sparse_array (x ) -> bool :
207
206
"""
208
207
Return True if `x` is an array from the `sparse` package.
@@ -255,6 +254,166 @@ def is_array_api_obj(x):
255
254
or is_pydata_sparse_array (x ) \
256
255
or hasattr (x , '__array_namespace__' )
257
256
257
+ def _compat_module_name ():
258
+ assert __name__ .endswith ('.common._helpers' )
259
+ return __name__ .removesuffix ('.common._helpers' )
260
+
261
+ def is_numpy_namespace (xp ) -> bool :
262
+ """
263
+ Returns True if `xp` is a NumPy namespace.
264
+
265
+ This includes both NumPy itself and the version wrapped by array-api-compat.
266
+
267
+ See Also
268
+ --------
269
+
270
+ array_namespace
271
+ is_cupy_namespace
272
+ is_torch_namespace
273
+ is_ndonnx_namespace
274
+ is_dask_namespace
275
+ is_jax_namespace
276
+ is_pydata_sparse_namespace
277
+ is_array_api_strict_namespace
278
+ """
279
+ return xp .__name__ in {'numpy' , _compat_module_name + '.numpy' }
280
+
281
+ def is_cupy_namespace (xp ) -> bool :
282
+ """
283
+ Returns True if `xp` is a CuPy namespace.
284
+
285
+ This includes both CuPy itself and the version wrapped by array-api-compat.
286
+
287
+ See Also
288
+ --------
289
+
290
+ array_namespace
291
+ is_numpy_namespace
292
+ is_torch_namespace
293
+ is_ndonnx_namespace
294
+ is_dask_namespace
295
+ is_jax_namespace
296
+ is_pydata_sparse_namespace
297
+ is_array_api_strict_namespace
298
+ """
299
+ return xp .__name__ in {'cupy' , _compat_module_name + '.cupy' }
300
+
301
+ def is_torch_namespace (xp ) -> bool :
302
+ """
303
+ Returns True if `xp` is a PyTorch namespace.
304
+
305
+ This includes both PyTorch itself and the version wrapped by array-api-compat.
306
+
307
+ See Also
308
+ --------
309
+
310
+ array_namespace
311
+ is_numpy_namespace
312
+ is_cupy_namespace
313
+ is_ndonnx_namespace
314
+ is_dask_namespace
315
+ is_jax_namespace
316
+ is_pydata_sparse_namespace
317
+ is_array_api_strict_namespace
318
+ """
319
+ return xp .__name__ in {'torch' , _compat_module_name + '.torch' }
320
+
321
+
322
+ def is_ndonnx_namespace (xp ):
323
+ """
324
+ Returns True if `xp` is an NDONNX namespace.
325
+
326
+ See Also
327
+ --------
328
+
329
+ array_namespace
330
+ is_numpy_namespace
331
+ is_cupy_namespace
332
+ is_torch_namespace
333
+ is_dask_namespace
334
+ is_jax_namespace
335
+ is_pydata_sparse_namespace
336
+ is_array_api_strict_namespace
337
+ """
338
+ return xp .__name__ == 'ndonnx'
339
+
340
+ def is_dask_namespace (xp ):
341
+ """
342
+ Returns True if `xp` is a Dask namespace.
343
+
344
+ This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
345
+
346
+ See Also
347
+ --------
348
+
349
+ array_namespace
350
+ is_numpy_namespace
351
+ is_cupy_namespace
352
+ is_torch_namespace
353
+ is_ndonnx_namespace
354
+ is_jax_namespace
355
+ is_pydata_sparse_namespace
356
+ is_array_api_strict_namespace
357
+ """
358
+ return xp .__name__ in {'dask.array' , _compat_module_name + '.dask.array' }
359
+
360
+ def is_jax_namespace (xp ):
361
+ """
362
+ Returns True if `xp` is a JAX namespace.
363
+
364
+ This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
365
+ older versions of JAX.
366
+
367
+ See Also
368
+ --------
369
+
370
+ array_namespace
371
+ is_numpy_namespace
372
+ is_cupy_namespace
373
+ is_torch_namespace
374
+ is_ndonnx_namespace
375
+ is_dask_namespace
376
+ is_pydata_sparse_namespace
377
+ is_array_api_strict_namespace
378
+ """
379
+ return xp .__name__ in {'jax.numpy' , 'jax.experimental.array_api' }
380
+
381
+ def is_pydata_sparse_namespace (xp ):
382
+ """
383
+ Returns True if `xp` is a pydata/sparse namespace.
384
+
385
+ See Also
386
+ --------
387
+
388
+ array_namespace
389
+ is_numpy_namespace
390
+ is_cupy_namespace
391
+ is_torch_namespace
392
+ is_ndonnx_namespace
393
+ is_dask_namespace
394
+ is_jax_namespace
395
+ is_array_api_strict_namespace
396
+ """
397
+ return xp .__name__ == 'sparse'
398
+
399
+ def is_array_api_strict_namespace (xp ):
400
+ """
401
+ Returns True if `xp` is an array-api-strict namespace.
402
+
403
+ See Also
404
+ --------
405
+
406
+ array_namespace
407
+ is_numpy_namespace
408
+ is_cupy_namespace
409
+ is_torch_namespace
410
+ is_ndonnx_namespace
411
+ is_dask_namespace
412
+ is_jax_namespace
413
+ is_pydata_sparse_namespace
414
+ """
415
+ return xp .__name__ == 'array_api_strict'
416
+
258
417
def _check_api_version (api_version ):
259
418
if api_version == '2021.12' :
260
419
warnings .warn ("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12" )
@@ -643,13 +802,21 @@ def size(x):
643
802
"device" ,
644
803
"get_namespace" ,
645
804
"is_array_api_obj" ,
805
+ "is_array_api_strict_namespace" ,
646
806
"is_cupy_array" ,
807
+ "is_cupy_namespace" ,
647
808
"is_dask_array" ,
809
+ "is_dask_namespace" ,
648
810
"is_jax_array" ,
811
+ "is_jax_namespace" ,
649
812
"is_numpy_array" ,
813
+ "is_numpy_namespace" ,
650
814
"is_torch_array" ,
815
+ "is_torch_namespace" ,
651
816
"is_ndonnx_array" ,
817
+ "is_ndonnx_namespace" ,
652
818
"is_pydata_sparse_array" ,
819
+ "is_pydata_sparse_namespace" ,
653
820
"size" ,
654
821
"to_device" ,
655
822
]
0 commit comments