Skip to content

Commit 75e5166

Browse files
authored
Merge pull request #155 from crusaderky/lazy_raise
2 parents b24c218 + a925bf9 commit 75e5166

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

Diff for: src/array_api_extra/testing.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
200200
if is_dask_namespace(xp):
201201
for name, func, tags in iter_tagged():
202202
n = tags["allow_dask_compute"]
203-
wrapped = _allow_dask_compute(func, n)
203+
wrapped = _dask_wrap(func, n)
204204
monkeypatch.setitem(globals_, name, wrapped)
205205

206206
elif is_jax_namespace(xp):
@@ -256,13 +256,15 @@ def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any:
256256
return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
257257

258258

259-
def _allow_dask_compute(
259+
def _dask_wrap(
260260
func: Callable[P, T], n: int
261261
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
262262
"""
263263
Wrap `func` to raise if it attempts to call `dask.compute` more than `n` times.
264+
265+
After the function returns, materialize the graph in order to re-raise exceptions.
264266
"""
265-
import dask.config
267+
import dask
266268

267269
func_name = getattr(func, "__name__", str(func))
268270
n_str = f"only up to {n}" if n else "no"
@@ -276,7 +278,12 @@ def _allow_dask_compute(
276278
@wraps(func)
277279
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
278280
scheduler = CountingDaskScheduler(n, msg)
279-
with dask.config.set({"scheduler": scheduler}):
280-
return func(*args, **kwargs)
281+
with dask.config.set({"scheduler": scheduler}): # pyright: ignore[reportPrivateImportUsage]
282+
out = func(*args, **kwargs)
283+
284+
# Block until the graph materializes and reraise exceptions. This allows
285+
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
286+
# not work on scheduler='distributed', as it would not block.
287+
return dask.persist(out, scheduler="threads")[0] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
281288

282289
return wrapper

Diff for: tests/test_testing.py

+29
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,32 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
232232
# note that when sparse reduces to scalar it returns a np.generic, which
233233
# would make xp_assert_equal fail.
234234
xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))
235+
236+
237+
def dask_raises(x: Array) -> Array:
238+
def _raises(x: Array) -> Array:
239+
# Test that map_blocks doesn't eagerly call the function;
240+
# dtype and meta should be sufficient to skip the trial run.
241+
assert x.shape == (3,)
242+
msg = "Hello world"
243+
raise ValueError(msg)
244+
245+
return x.map_blocks(_raises, dtype=x.dtype, meta=x._meta)
246+
247+
248+
lazy_xp_function(dask_raises)
249+
250+
251+
def test_lazy_xp_function_eagerly_raises(da: ModuleType):
252+
"""Test that the pattern::
253+
254+
with pytest.raises(Exception):
255+
func(x)
256+
257+
works with Dask, even though it normally wouldn't as we're disregarding the func
258+
output so the graph would not be ordinarily materialized.
259+
lazy_xp_function contains ad-hoc code to materialize and reraise exceptions.
260+
"""
261+
x = da.arange(3)
262+
with pytest.raises(ValueError, match="Hello world"):
263+
dask_raises(x)

0 commit comments

Comments
 (0)