|
| 1 | +import math |
| 2 | + |
| 3 | +import pytest |
| 4 | +import numpy as np |
| 5 | +import array |
| 6 | +from numpy.testing import assert_allclose |
| 7 | + |
1 | 8 | from array_api_compat import ( # noqa: F401
|
2 | 9 | is_numpy_array, is_cupy_array, is_torch_array,
|
3 | 10 | is_dask_array, is_jax_array, is_pydata_sparse_array,
|
|
6 | 13 | )
|
7 | 14 |
|
8 | 15 | from array_api_compat import (
|
9 |
| - device, is_array_api_obj, is_writeable_array, size, to_device |
| 16 | + device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device |
10 | 17 | )
|
11 | 18 | from ._helpers import import_, wrapped_libraries, all_libraries
|
12 | 19 |
|
13 |
| -import pytest |
14 |
| -import numpy as np |
15 |
| -import array |
16 |
| -from numpy.testing import assert_allclose |
17 |
| - |
18 | 20 | is_array_functions = {
|
19 | 21 | 'numpy': 'is_numpy_array',
|
20 | 22 | 'cupy': 'is_cupy_array',
|
@@ -115,6 +117,45 @@ def test_size_none(library):
|
115 | 117 | assert size(x) in (None, 5)
|
116 | 118 |
|
117 | 119 |
|
| 120 | +@pytest.mark.parametrize("library", all_libraries) |
| 121 | +def test_is_lazy_array(library): |
| 122 | + lib = import_(library) |
| 123 | + x = lib.asarray([1, 2, 3]) |
| 124 | + assert isinstance(is_lazy_array(x), bool) |
| 125 | + |
| 126 | + |
| 127 | +@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)]) |
| 128 | +def test_is_lazy_array_nan_size(shape, monkeypatch): |
| 129 | + """Test is_lazy_array() on an unknown Array API compliant object |
| 130 | + with NaN (like Dask) or None (like ndonnx) in its shape |
| 131 | + """ |
| 132 | + xp = import_("array_api_strict") |
| 133 | + x = xp.asarray(1) |
| 134 | + assert not is_lazy_array(x) |
| 135 | + monkeypatch.setattr(type(x), "shape", shape) |
| 136 | + assert is_lazy_array(x) |
| 137 | + |
| 138 | + |
| 139 | +@pytest.mark.parametrize("exc", [TypeError, AssertionError]) |
| 140 | +def test_is_lazy_array_bool_raises(exc, monkeypatch): |
| 141 | + """Test is_lazy_array() on an unknown Array API compliant object |
| 142 | + where calling bool() raises: |
| 143 | + - TypeError: e.g. like jitted JAX. This is the proper exception which |
| 144 | + lazy arrays should raise as per the Array API specification |
| 145 | + - something else: e.g. like Dask, where bool() triggers compute() |
| 146 | + which can result in any kind of exception to be raised |
| 147 | + """ |
| 148 | + xp = import_("array_api_strict") |
| 149 | + x = xp.asarray(1) |
| 150 | + assert not is_lazy_array(x) |
| 151 | + |
| 152 | + def __bool__(self): |
| 153 | + raise exc("Hello world") |
| 154 | + |
| 155 | + monkeypatch.setattr(type(x), "__bool__", __bool__) |
| 156 | + assert is_lazy_array(x) |
| 157 | + |
| 158 | + |
118 | 159 | @pytest.mark.parametrize("library", all_libraries)
|
119 | 160 | def test_device(library):
|
120 | 161 | xp = import_(library, wrapper=True)
|
@@ -172,6 +213,7 @@ def test_asarray_cross_library(source_library, target_library, request):
|
172 | 213 |
|
173 | 214 | assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
|
174 | 215 |
|
| 216 | + |
175 | 217 | @pytest.mark.parametrize("library", wrapped_libraries)
|
176 | 218 | def test_asarray_copy(library):
|
177 | 219 | # Note, we have this test here because the test suite currently doesn't
|
|
0 commit comments