Skip to content

Commit d7b4111

Browse files
authored
Merge pull request #191 from asmeurer/2023.12
Update __array_api_version__ to 2023.12
2 parents 522a608 + 273d54e commit d7b4111

File tree

7 files changed

+23
-16
lines changed

7 files changed

+23
-16
lines changed

array_api_compat/common/_helpers.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def is_torch_namespace(xp) -> bool:
317317
is_array_api_strict_namespace
318318
"""
319319
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320-
320+
321321

322322
def is_ndonnx_namespace(xp):
323323
"""
@@ -415,10 +415,11 @@ def is_array_api_strict_namespace(xp):
415415
return xp.__name__ == 'array_api_strict'
416416

417417
def _check_api_version(api_version):
418-
if api_version == '2021.12':
419-
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
420-
elif api_version is not None and api_version != '2022.12':
421-
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
418+
if api_version in ['2021.12', '2022.12']:
419+
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
420+
elif api_version is not None and api_version not in ['2021.12', '2022.12',
421+
'2023.12']:
422+
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")
422423

423424
def array_namespace(*xs, api_version=None, use_compat=None):
424425
"""
@@ -431,7 +432,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):
431432
432433
api_version: str
433434
The newest version of the spec that you need support for (currently
434-
the compat library wrapped APIs support v2022.12).
435+
the compat library wrapped APIs support v2023.12).
435436
436437
use_compat: bool or None
437438
If None (the default), the native namespace will be returned if it is

array_api_compat/cupy/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313

1414
from ..common._helpers import * # noqa: F401,F403
1515

16-
__array_api_version__ = '2022.12'
16+
__array_api_version__ = '2023.12'

array_api_compat/dask/array/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# These imports may overwrite names from the import * above.
44
from ._aliases import * # noqa: F403
55

6-
__array_api_version__ = '2022.12'
6+
__array_api_version__ = '2023.12'
77

88
__import__(__package__ + '.linalg')
99
__import__(__package__ + '.fft')

array_api_compat/numpy/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@
2727
except ImportError:
2828
pass
2929

30-
__array_api_version__ = '2022.12'
30+
__array_api_version__ = '2023.12'

array_api_compat/torch/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@
2121

2222
from ..common._helpers import * # noqa: F403
2323

24-
__array_api_version__ = '2022.12'
24+
__array_api_version__ = '2023.12'

docs/index.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ each array library itself fully compatible with the array API, but this
1212
requires making backwards incompatible changes in many cases, so this will
1313
take some time.
1414

15-
Currently all libraries here are implemented against the [2022.12
16-
version](https://data-apis.org/array-api/2022.12/) of the standard.
15+
Currently all libraries here are implemented against the [2023.12
16+
version](https://data-apis.org/array-api/2023.12/) of the standard.
1717

1818
## Installation
1919

tests/test_array_namespace.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ._helpers import import_, all_libraries, wrapped_libraries
1414

1515
@pytest.mark.parametrize("use_compat", [True, False, None])
16-
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
16+
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
1717
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
1818
def test_array_namespace(library, api_version, use_compat):
1919
xp = import_(library)
@@ -94,14 +94,20 @@ def test_array_namespace_errors_torch():
9494
def test_api_version():
9595
x = torch.asarray([1, 2])
9696
torch_ = import_("torch", wrapper=True)
97-
assert array_namespace(x, api_version="2022.12") == torch_
97+
assert array_namespace(x, api_version="2023.12") == torch_
9898
assert array_namespace(x, api_version=None) == torch_
9999
assert array_namespace(x) == torch_
100100
# Should issue a warning
101101
with warnings.catch_warnings(record=True) as w:
102102
assert array_namespace(x, api_version="2021.12") == torch_
103-
assert len(w) == 1
104-
assert "2021.12" in str(w[0].message)
103+
assert len(w) == 1
104+
assert "2021.12" in str(w[0].message)
105+
106+
# Should issue a warning
107+
with warnings.catch_warnings(record=True) as w:
108+
assert array_namespace(x, api_version="2022.12") == torch_
109+
assert len(w) == 1
110+
assert "2022.12" in str(w[0].message)
105111

106112
pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
107113

0 commit comments

Comments
 (0)