@@ -200,7 +200,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
200
200
if is_dask_namespace (xp ):
201
201
for name , func , tags in iter_tagged ():
202
202
n = tags ["allow_dask_compute" ]
203
- wrapped = _allow_dask_compute (func , n )
203
+ wrapped = _dask_wrap (func , n )
204
204
monkeypatch .setitem (globals_ , name , wrapped )
205
205
206
206
elif is_jax_namespace (xp ):
@@ -256,13 +256,15 @@ def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any:
256
256
return dask .get (dsk , keys , ** kwargs ) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
257
257
258
258
259
- def _allow_dask_compute (
259
+ def _dask_wrap (
260
260
func : Callable [P , T ], n : int
261
261
) -> Callable [P , T ]: # numpydoc ignore=PR01,RT01
262
262
"""
263
263
Wrap `func` to raise if it attempts to call `dask.compute` more than `n` times.
264
+
265
+ After the function returns, materialize the graph in order to re-raise exceptions.
264
266
"""
265
- import dask . config
267
+ import dask
266
268
267
269
func_name = getattr (func , "__name__" , str (func ))
268
270
n_str = f"only up to { n } " if n else "no"
@@ -276,7 +278,12 @@ def _allow_dask_compute(
276
278
@wraps (func )
277
279
def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
278
280
scheduler = CountingDaskScheduler (n , msg )
279
- with dask .config .set ({"scheduler" : scheduler }):
280
- return func (* args , ** kwargs )
281
+ with dask .config .set ({"scheduler" : scheduler }): # pyright: ignore[reportPrivateImportUsage]
282
+ out = func (* args , ** kwargs )
283
+
284
+ # Block until the graph materializes and reraise exceptions. This allows
285
+ # `pytest.raises` and `pytest.warns` to work as expected. Note that this would
286
+ # not work on scheduler='distributed', as it would not block.
287
+ return dask .persist (out , scheduler = "threads" )[0 ] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
281
288
282
289
return wrapper
0 commit comments