Skip to content

Commit ce7342e

Browse files
authored
Merge pull request #158 from crusaderky/lazy_xp_modules
ENH: `lazy_xp_function` namespaces support
2 parents 2f7b4d9 + 8cac0e4 commit ce7342e

File tree

2 files changed

+110
-28
lines changed

2 files changed

+110
-28
lines changed

src/array_api_extra/testing.py

+73-28
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
5656
"""
5757
Tag a function to be tested on lazy backends.
5858
59-
Tag a function, which must be imported in the test module globals, so that when any
60-
tests defined in the same module are executed with ``xp=jax.numpy`` the function is
61-
replaced with a jitted version of itself, and when it is executed with
59+
Tag a function so that when any tests are executed with ``xp=jax.numpy`` the
60+
function is replaced with a jitted version of itself, and when it is executed with
6261
``xp=dask.array`` the function will raise if it attempts to materialize the graph.
6362
This will be later expanded to provide test coverage for other lazy backends.
6463
@@ -120,19 +119,59 @@ def test_myfunc(xp):
120119
121120
Notes
122121
-----
123-
A test function can circumvent this monkey-patching system by calling `func` as an
124-
attribute of the original module. You need to sanitize your code to make sure this
125-
does not happen.
122+
In order for this tag to be effective, the test function must be imported into the
123+
test module globals without its namespace; alternatively its namespace must be
124+
declared in a ``lazy_xp_modules`` list in the test module globals.
126125
127-
Example::
126+
Example 1::
128127
129-
import mymodule from mymodule import myfunc
128+
from mymodule import myfunc
130129
131130
lazy_xp_function(myfunc)
132131
133132
def test_myfunc(xp):
134-
a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
135-
mymodule.myfunc(a) # This is not
133+
x = myfunc(xp.asarray([1, 2]))
134+
135+
Example 2::
136+
137+
import mymodule
138+
139+
lazy_xp_modules = [mymodule]
140+
lazy_xp_function(mymodule.myfunc)
141+
142+
def test_myfunc(xp):
143+
x = mymodule.myfunc(xp.asarray([1, 2]))
144+
145+
A test function can circumvent this monkey-patching system by using a namespace
146+
outside of the two above patterns. You need to sanitize your code to make sure this
147+
only happens intentionally.
148+
149+
Example 1::
150+
151+
import mymodule
152+
from mymodule import myfunc
153+
154+
lazy_xp_function(myfunc)
155+
156+
def test_myfunc(xp):
157+
a = xp.asarray([1, 2])
158+
b = myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
159+
c = mymodule.myfunc(a) # This is not
160+
161+
Example 2::
162+
163+
import mymodule
164+
165+
class naked:
166+
myfunc = mymodule.myfunc
167+
168+
lazy_xp_modules = [mymodule]
169+
lazy_xp_function(mymodule.myfunc)
170+
171+
def test_myfunc(xp):
172+
a = xp.asarray([1, 2])
173+
b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
174+
c = naked.myfunc(a) # This is not
136175
"""
137176
tags = {
138177
"allow_dask_compute": allow_dask_compute,
@@ -153,11 +192,13 @@ def patch_lazy_xp_functions(
153192
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
154193
155194
If ``xp==jax.numpy``, search for all functions which have been tagged with
156-
:func:`lazy_xp_function` in the globals of the module that defines the current test
195+
:func:`lazy_xp_function` in the globals of the module that defines the current test,
196+
as well as in the ``lazy_xp_modules`` list in the globals of the same module,
157197
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
158198
159199
If ``xp==dask.array``, wrap the functions with a decorator that disables
160-
``compute()`` and ``persist()``.
200+
``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
201+
eagerly.
161202
162203
This function should be typically called by your library's `xp` fixture that runs
163204
tests on multiple backends::
@@ -183,29 +224,33 @@ def xp(request, monkeypatch):
183224
lazy_xp_function : Tag a function to be tested on lazy backends.
184225
pytest.FixtureRequest : `request` test function parameter.
185226
"""
186-
globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit]
187-
188-
def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit]
189-
for name, func in globals_.items():
190-
tags: dict[str, Any] | None = None # type: ignore[no-any-explicit]
191-
with contextlib.suppress(AttributeError):
192-
tags = func._lazy_xp_function # pylint: disable=protected-access
193-
if tags is None:
194-
with contextlib.suppress(KeyError, TypeError):
195-
tags = _ufuncs_tags[func]
196-
if tags is not None:
197-
yield name, func, tags
227+
mod = cast(ModuleType, request.module)
228+
mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))]
229+
230+
def iter_tagged() -> ( # type: ignore[no-any-explicit]
231+
Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]]
232+
):
233+
for mod in mods:
234+
for name, func in mod.__dict__.items():
235+
tags: dict[str, Any] | None = None # type: ignore[no-any-explicit]
236+
with contextlib.suppress(AttributeError):
237+
tags = func._lazy_xp_function # pylint: disable=protected-access
238+
if tags is None:
239+
with contextlib.suppress(KeyError, TypeError):
240+
tags = _ufuncs_tags[func]
241+
if tags is not None:
242+
yield mod, name, func, tags
198243

199244
if is_dask_namespace(xp):
200-
for name, func, tags in iter_tagged():
245+
for mod, name, func, tags in iter_tagged():
201246
n = tags["allow_dask_compute"]
202247
wrapped = _dask_wrap(func, n)
203-
monkeypatch.setitem(globals_, name, wrapped)
248+
monkeypatch.setattr(mod, name, wrapped)
204249

205250
elif is_jax_namespace(xp):
206251
import jax
207252

208-
for name, func, tags in iter_tagged():
253+
for mod, name, func, tags in iter_tagged():
209254
if tags["jax_jit"]:
210255
# suppress unused-ignore to run mypy in -e lint as well as -e dev
211256
wrapped = cast( # type: ignore[no-any-explicit]
@@ -216,7 +261,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
216261
static_argnames=tags["static_argnames"],
217262
),
218263
)
219-
monkeypatch.setitem(globals_, name, wrapped)
264+
monkeypatch.setattr(mod, name, wrapped)
220265

221266

222267
class CountingDaskScheduler(SchedulerGetCallable):

tests/test_testing.py

+37
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def non_materializable(x: Array) -> Array:
108108
and it will trigger an expensive computation in dask.
109109
"""
110110
xp = array_namespace(x)
111+
# Crashes inside jax.jit
111112
# On dask, this triggers two computations of the whole graph
112113
if xp.any(x < 0.0) or xp.any(x > 10.0):
113114
msg = "Values must be in the [0, 10] range"
@@ -261,3 +262,39 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType):
261262
x = da.arange(3)
262263
with pytest.raises(ValueError, match="Hello world"):
263264
dask_raises(x)
265+
266+
267+
class Wrapped:
268+
def f(x: Array) -> Array: # noqa: N805 # pyright: ignore[reportSelfClsParameterName]
269+
xp = array_namespace(x)
270+
# Crash in jax.jit and trigger compute() on dask
271+
if not xp.all(x):
272+
msg = "Values must be non-zero"
273+
raise ValueError(msg)
274+
return x
275+
276+
277+
class Naked:
278+
f = Wrapped.f # pyright: ignore[reportUnannotatedClassAttribute]
279+
280+
281+
lazy_xp_function(Wrapped.f)
282+
lazy_xp_modules = [Wrapped]
283+
284+
285+
def test_lazy_xp_modules(xp: ModuleType, library: Backend):
286+
x = xp.asarray([1.0, 2.0])
287+
y = Naked.f(x)
288+
xp_assert_equal(y, x)
289+
290+
if library is Backend.JAX:
291+
with pytest.raises(
292+
TypeError, match="Attempted boolean conversion of traced array"
293+
):
294+
Wrapped.f(x)
295+
elif library is Backend.DASK:
296+
with pytest.raises(AssertionError, match=r"dask\.compute"):
297+
Wrapped.f(x)
298+
else:
299+
y = Wrapped.f(x)
300+
xp_assert_equal(y, x)

0 commit comments

Comments
 (0)