Skip to content

Commit 7ec0b82

Browse files
authored
Merge pull request #128 from crusaderky/compat_all
2 parents f9b4dcf + 6997c91 commit 7ec0b82

File tree

3 files changed

+34
-0
lines changed

3 files changed

+34
-0
lines changed

src/array_api_extra/_lib/_utils/_compat.py

+15
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,19 @@
66
from ...._array_api_compat_vendor import (
77
array_namespace,
88
device,
9+
is_array_api_obj,
910
is_array_api_strict_namespace,
11+
is_cupy_array,
1012
is_cupy_namespace,
1113
is_dask_array,
1214
is_dask_namespace,
1315
is_jax_array,
1416
is_jax_namespace,
17+
is_numpy_array,
1518
is_numpy_namespace,
19+
is_pydata_sparse_array,
1620
is_pydata_sparse_namespace,
21+
is_torch_array,
1722
is_torch_namespace,
1823
is_writeable_array,
1924
size,
@@ -22,14 +27,19 @@
2227
from array_api_compat import (
2328
array_namespace,
2429
device,
30+
is_array_api_obj,
2531
is_array_api_strict_namespace,
32+
is_cupy_array,
2633
is_cupy_namespace,
2734
is_dask_array,
2835
is_dask_namespace,
2936
is_jax_array,
3037
is_jax_namespace,
38+
is_numpy_array,
3139
is_numpy_namespace,
40+
is_pydata_sparse_array,
3241
is_pydata_sparse_namespace,
42+
is_torch_array,
3343
is_torch_namespace,
3444
is_writeable_array,
3545
size,
@@ -38,14 +48,19 @@
3848
__all__ = [
3949
"array_namespace",
4050
"device",
51+
"is_array_api_obj",
4152
"is_array_api_strict_namespace",
53+
"is_cupy_array",
4254
"is_cupy_namespace",
4355
"is_dask_array",
4456
"is_dask_namespace",
4557
"is_jax_array",
4658
"is_jax_namespace",
59+
"is_numpy_array",
4760
"is_numpy_namespace",
61+
"is_pydata_sparse_array",
4862
"is_pydata_sparse_namespace",
63+
"is_torch_array",
4964
"is_torch_namespace",
5065
"is_writeable_array",
5166
"size",

src/array_api_extra/_lib/_utils/_compat.pyi

+5
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,19 @@ def array_namespace(
1818
use_compat: bool | None = None,
1919
) -> ArrayModule: ...
2020
def device(x: Array, /) -> Device: ...
21+
def is_array_api_obj(x: object, /) -> bool: ...
2122
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
2223
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
2324
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
2425
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
2526
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
2627
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
2728
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
29+
def is_cupy_array(x: object, /) -> bool: ...
2830
def is_dask_array(x: object, /) -> bool: ...
2931
def is_jax_array(x: object, /) -> bool: ...
32+
def is_numpy_array(x: object, /) -> bool: ...
33+
def is_pydata_sparse_array(x: object, /) -> bool: ...
34+
def is_torch_array(x: object, /) -> bool: ...
3035
def is_writeable_array(x: object, /) -> bool: ...
3136
def size(x: Array, /) -> int | None: ...

vendor_tests/test_vendor.py

+14
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,19 @@ def test_vendor_compat():
66
from ._array_api_compat_vendor import ( # type: ignore[attr-defined]
77
array_namespace,
88
device,
9+
is_array_api_obj,
10+
is_array_api_strict_namespace,
11+
is_cupy_array,
912
is_cupy_namespace,
1013
is_dask_array,
1114
is_dask_namespace,
1215
is_jax_array,
1316
is_jax_namespace,
17+
is_numpy_array,
18+
is_numpy_namespace,
19+
is_pydata_sparse_array,
1420
is_pydata_sparse_namespace,
21+
is_torch_array,
1522
is_torch_namespace,
1623
is_writeable_array,
1724
size,
@@ -20,12 +27,19 @@ def test_vendor_compat():
2027
x = xp.asarray([1, 2, 3])
2128
assert array_namespace(x) is xp
2229
device(x)
30+
assert is_array_api_obj(x)
31+
assert is_array_api_strict_namespace(xp)
32+
assert not is_cupy_array(x)
2333
assert not is_cupy_namespace(xp)
2434
assert not is_dask_array(x)
2535
assert not is_dask_namespace(xp)
2636
assert not is_jax_array(x)
2737
assert not is_jax_namespace(xp)
38+
assert not is_numpy_array(x)
39+
assert not is_numpy_namespace(xp)
40+
assert not is_pydata_sparse_array(x)
2841
assert not is_pydata_sparse_namespace(xp)
42+
assert not is_torch_array(x)
2943
assert not is_torch_namespace(xp)
3044
assert is_writeable_array(x)
3145
assert size(x) == 3

0 commit comments

Comments
 (0)