Skip to content

Commit 5ef0e18

Browse files
authored
Merge pull request #228 from crusaderky/is_lazy_array
ENH: is_lazy_array()
2 parents e5dd419 + 7950eaa commit 5ef0e18

File tree

3 files changed

+107
-6
lines changed

3 files changed

+107
-6
lines changed

array_api_compat/common/_helpers.py

+58
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,63 @@ def is_writeable_array(x) -> bool:
824824
return True
825825

826826

827+
def is_lazy_array(x) -> bool:
828+
"""Return True if x is potentially a future or it may be otherwise impossible or
829+
expensive to eagerly read its contents, regardless of their size, e.g. by
830+
calling ``bool(x)`` or ``float(x)``.
831+
832+
Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
833+
cheap as long as the array has the right dtype and size.
834+
835+
Note
836+
----
837+
This function errs on the side of caution for array types that may or may not be
838+
lazy, e.g. JAX arrays, by always returning True for them.
839+
"""
840+
if (
841+
is_numpy_array(x)
842+
or is_cupy_array(x)
843+
or is_torch_array(x)
844+
or is_pydata_sparse_array(x)
845+
):
846+
return False
847+
848+
# **JAX note:** while it is possible to determine if you're inside or outside
849+
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
850+
# as we do below for unknown arrays, this is not recommended by JAX best practices.
851+
852+
# **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
853+
# This behaviour, while impossible to change without breaking backwards
854+
# compatibility, is highly detrimental to performance as the whole graph will end
855+
# up being computed multiple times.
856+
857+
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
858+
return True
859+
860+
# Unknown Array API compatible object. Note that this test may have dire consequences
861+
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
862+
# on __bool__ (dask is one such example, which however is special-cased above).
863+
864+
# Select a single point of the array
865+
s = size(x)
866+
if s is None:
867+
return True
868+
xp = array_namespace(x)
869+
if s > 1:
870+
x = xp.reshape(x, (-1,))[0]
871+
# Cast to dtype=bool and deal with size 0 arrays
872+
x = xp.any(x)
873+
874+
try:
875+
bool(x)
876+
return False
877+
# The Array API standard dictactes that __bool__ should raise TypeError if the
878+
# output cannot be defined.
879+
# Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
880+
except Exception:
881+
return True
882+
883+
827884
__all__ = [
828885
"array_namespace",
829886
"device",
@@ -845,6 +902,7 @@ def is_writeable_array(x) -> bool:
845902
"is_pydata_sparse_array",
846903
"is_pydata_sparse_namespace",
847904
"is_writeable_array",
905+
"is_lazy_array",
848906
"size",
849907
"to_device",
850908
]

docs/helper-functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ yet.
5252
.. autofunction:: is_pydata_sparse_array
5353
.. autofunction:: is_ndonnx_array
5454
.. autofunction:: is_writeable_array
55+
.. autofunction:: is_lazy_array
5556
.. autofunction:: is_numpy_namespace
5657
.. autofunction:: is_cupy_namespace
5758
.. autofunction:: is_torch_namespace

tests/test_common.py

+48-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
import math
2+
3+
import pytest
4+
import numpy as np
5+
import array
6+
from numpy.testing import assert_allclose
7+
18
from array_api_compat import ( # noqa: F401
29
is_numpy_array, is_cupy_array, is_torch_array,
310
is_dask_array, is_jax_array, is_pydata_sparse_array,
@@ -6,15 +13,10 @@
613
)
714

815
from array_api_compat import (
9-
device, is_array_api_obj, is_writeable_array, size, to_device
16+
device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
1017
)
1118
from ._helpers import import_, wrapped_libraries, all_libraries
1219

13-
import pytest
14-
import numpy as np
15-
import array
16-
from numpy.testing import assert_allclose
17-
1820
is_array_functions = {
1921
'numpy': 'is_numpy_array',
2022
'cupy': 'is_cupy_array',
@@ -115,6 +117,45 @@ def test_size_none(library):
115117
assert size(x) in (None, 5)
116118

117119

120+
@pytest.mark.parametrize("library", all_libraries)
121+
def test_is_lazy_array(library):
122+
lib = import_(library)
123+
x = lib.asarray([1, 2, 3])
124+
assert isinstance(is_lazy_array(x), bool)
125+
126+
127+
@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan), (None, ), (1, None)])
128+
def test_is_lazy_array_nan_size(shape, monkeypatch):
129+
"""Test is_lazy_array() on an unknown Array API compliant object
130+
with NaN (like Dask) or None (like ndonnx) in its shape
131+
"""
132+
xp = import_("array_api_strict")
133+
x = xp.asarray(1)
134+
assert not is_lazy_array(x)
135+
monkeypatch.setattr(type(x), "shape", shape)
136+
assert is_lazy_array(x)
137+
138+
139+
@pytest.mark.parametrize("exc", [TypeError, AssertionError])
140+
def test_is_lazy_array_bool_raises(exc, monkeypatch):
141+
"""Test is_lazy_array() on an unknown Array API compliant object
142+
where calling bool() raises:
143+
- TypeError: e.g. like jitted JAX. This is the proper exception which
144+
lazy arrays should raise as per the Array API specification
145+
- something else: e.g. like Dask, where bool() triggers compute()
146+
which can result in any kind of exception to be raised
147+
"""
148+
xp = import_("array_api_strict")
149+
x = xp.asarray(1)
150+
assert not is_lazy_array(x)
151+
152+
def __bool__(self):
153+
raise exc("Hello world")
154+
155+
monkeypatch.setattr(type(x), "__bool__", __bool__)
156+
assert is_lazy_array(x)
157+
158+
118159
@pytest.mark.parametrize("library", all_libraries)
119160
def test_device(library):
120161
xp = import_(library, wrapper=True)
@@ -172,6 +213,7 @@ def test_asarray_cross_library(source_library, target_library, request):
172213

173214
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
174215

216+
175217
@pytest.mark.parametrize("library", wrapped_libraries)
176218
def test_asarray_copy(library):
177219
# Note, we have this test here because the test suite currently doesn't

0 commit comments

Comments
 (0)