Skip to content

Commit 82c3ec4

Browse files
authored
Merge pull request #259 from honno/info-stub-fix
Add missing `xp.__array_namespace_info__()` stub
2 parents bc1e37e + 6ebe822 commit 82c3ec4

File tree

4 files changed

+21
-36
lines changed

4 files changed

+21
-36
lines changed

README.md

+15-30
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
This is the test suite for array libraries adopting the [Python Array API
44
standard](https://data-apis.org/array-api/latest).
55

6-
Note the suite is still a **work in progress**. Feedback and contributions are
7-
welcome!
6+
Keeping full coverage of the spec is an on-going priority as the Array API evolves.
7+
Feedback and contributions are welcome!
88

99
## Quickstart
1010

@@ -33,11 +33,23 @@ You need to specify the array library to test. It can be specified via the
3333
`ARRAY_API_TESTS_MODULE` environment variable, e.g.
3434

3535
```bash
36-
$ export ARRAY_API_TESTS_MODULE=numpy.array_api
36+
$ export ARRAY_API_TESTS_MODULE=array_api_strict
3737
```
3838

3939
Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`.
4040

41+
### Specifying the API version
42+
43+
You can specify the API version to use when testing via the
44+
`ARRAY_API_TESTS_VERSION` environment variable, e.g.
45+
46+
```bash
47+
$ export ARRAY_API_TESTS_VERSION="2023.12"
48+
```
49+
50+
Currently this defaults to the array module's `__array_api_version__` value, and
51+
if that attribute doesn't exist then we fallback to `"2021.12"`.
52+
4153
### Run the suite
4254

4355
Simply run `pytest` against the `array_api_tests/` folder to run the full suite.
@@ -154,13 +166,6 @@ library to fail.
154166

155167
### Configuration
156168

157-
#### API version
158-
159-
You can specify the API version to use when testing via the
160-
`ARRAY_API_TESTS_VERSION` environment variable. Currently this defaults to the
161-
array module's `__array_api_version__` value, and if that attribute doesn't
162-
exist then we fallback to `"2021.12"`.
163-
164169
#### Data-dependent shapes
165170

166171
Use the `--disable-data-dependent-shapes` flag to skip testing functions which have
@@ -349,26 +354,6 @@ into a release. If you want, you can add release notes, which GitHub can
349354
generate for you.
350355

351356

352-
## Future plans
353-
354-
Keeping full coverage of the spec is an on-going priority as the Array API
355-
evolves.
356-
357-
Additionally, we have features and general improvements planned. Work on such
358-
functionality is guided primarily by the concerete needs of developers
359-
implementing and using the Array API—be sure to [let us
360-
know](https://github.com/data-apis/array-api-tests/issues) any limitations you
361-
come across.
362-
363-
* A dependency graph for every test case, which could be used to modify pytest's
364-
collection so that low-dependency tests are run first, and tests with faulty
365-
dependencies would skip/xfail.
366-
367-
* In some tests we've found it difficult to find appropaite assertion parameters
368-
for output values (particularly epsilons for floating-point outputs), so we
369-
need to review these and either implement assertions or properly note the lack
370-
thereof.
371-
372357
---
373358

374359
<sup>1</sup>The only exceptions to having just one primary test per function are:

array_api_tests/stubs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
category_to_funcs: Dict[str, List[FunctionType]] = {}
4646
for name, mod in name_to_mod.items():
47-
if name.endswith("_functions"):
47+
if name.endswith("_functions") or name == "info": # info functions file just named info.py
4848
category = name.replace("_functions", "")
4949
objects = [getattr(mod, name) for name in mod.__all__]
5050
assert all(isinstance(o, FunctionType) for o in objects) # sanity check

array_api_tests/test_linalg.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from . import dtype_helpers as dh
3838
from . import pytest_helpers as ph
3939
from . import shape_helpers as sh
40+
from . import api_version
4041
from .typing import Array
4142

4243
from . import _array_module
@@ -873,7 +874,7 @@ def test_trace(x, kw):
873874
# See https://github.com/data-apis/array-api-tests/issues/160
874875
if x.dtype in dh.uint_dtypes:
875876
assert dh.is_int_dtype(res.dtype) # sanity check
876-
else:
877+
elif api_version < "2023.12": # TODO: update dtype assertion for >2023.12 - see #234
877878
ph.assert_dtype("trace", in_dtype=x.dtype, out_dtype=res.dtype, expected=expected_dtype)
878879

879880
n, m = x.shape[-2:]

array_api_tests/test_statistical_functions.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
14-
from . import xps
14+
from . import api_version, xps
1515
from ._array_module import _UndefinedStub
1616
from .typing import DataType
1717

@@ -148,7 +148,7 @@ def test_prod(x, data):
148148
# See https://github.com/data-apis/array-api-tests/issues/106
149149
if x.dtype in dh.uint_dtypes:
150150
assert dh.is_int_dtype(out.dtype) # sanity check
151-
else:
151+
elif api_version < "2023.12": # TODO: update dtype assertion for >2023.12 - see #234
152152
ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
153153
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
154154
ph.assert_keepdimable_shape(
@@ -207,7 +207,6 @@ def test_std(x, data):
207207

208208

209209
@pytest.mark.unvectorized
210-
@pytest.mark.skip("flaky") # TODO: fix!
211210
@given(
212211
x=hh.arrays(
213212
dtype=xps.numeric_dtypes(),
@@ -238,7 +237,7 @@ def test_sum(x, data):
238237
# See https://github.com/data-apis/array-api-tests/issues/160
239238
if x.dtype in dh.uint_dtypes:
240239
assert dh.is_int_dtype(out.dtype) # sanity check
241-
else:
240+
elif api_version < "2023.12": # TODO: update dtype assertion for >2023.12 - see #234
242241
ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
243242
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
244243
ph.assert_keepdimable_shape(

0 commit comments

Comments
 (0)