@@ -56,9 +56,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
56
56
"""
57
57
Tag a function to be tested on lazy backends.
58
58
59
- Tag a function, which must be imported in the test module globals, so that when any
60
- tests defined in the same module are executed with ``xp=jax.numpy`` the function is
61
- replaced with a jitted version of itself, and when it is executed with
59
+ Tag a function so that when any tests are executed with ``xp=jax.numpy`` the
60
+ function is replaced with a jitted version of itself, and when it is executed with
62
61
``xp=dask.array`` the function will raise if it attempts to materialize the graph.
63
62
This will be later expanded to provide test coverage for other lazy backends.
64
63
@@ -120,19 +119,59 @@ def test_myfunc(xp):
120
119
121
120
Notes
122
121
-----
123
- A test function can circumvent this monkey-patching system by calling `func` as an
124
- attribute of the original module. You need to sanitize your code to make sure this
125
- does not happen .
122
+ In order for this tag to be effective, the test function must be imported into the
123
+ test module globals without its namespace; alternatively its namespace must be
124
+ declared in a ``lazy_xp_modules`` list in the test module globals .
126
125
127
- Example::
126
+ Example 1 ::
128
127
129
- import mymodule from mymodule import myfunc
128
+ from mymodule import myfunc
130
129
131
130
lazy_xp_function(myfunc)
132
131
133
132
def test_myfunc(xp):
134
- a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
135
- mymodule.myfunc(a) # This is not
133
+ x = myfunc(xp.asarray([1, 2]))
134
+
135
+ Example 2::
136
+
137
+ import mymodule
138
+
139
+ lazy_xp_modules = [mymodule]
140
+ lazy_xp_function(mymodule.myfunc)
141
+
142
+ def test_myfunc(xp):
143
+ x = mymodule.myfunc(xp.asarray([1, 2]))
144
+
145
+ A test function can circumvent this monkey-patching system by using a namespace
146
+ outside of the two above patterns. You need to sanitize your code to make sure this
147
+ only happens intentionally.
148
+
149
+ Example 1::
150
+
151
+ import mymodule
152
+ from mymodule import myfunc
153
+
154
+ lazy_xp_function(myfunc)
155
+
156
+ def test_myfunc(xp):
157
+ a = xp.asarray([1, 2])
158
+ b = myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
159
+ c = mymodule.myfunc(a) # This is not
160
+
161
+ Example 2::
162
+
163
+ import mymodule
164
+
165
+ class naked:
166
+ myfunc = mymodule.myfunc
167
+
168
+ lazy_xp_modules = [mymodule]
169
+ lazy_xp_function(mymodule.myfunc)
170
+
171
+ def test_myfunc(xp):
172
+ a = xp.asarray([1, 2])
173
+ b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
174
+ c = naked.myfunc(a) # This is not
136
175
"""
137
176
tags = {
138
177
"allow_dask_compute" : allow_dask_compute ,
@@ -153,11 +192,13 @@ def patch_lazy_xp_functions(
153
192
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
154
193
155
194
If ``xp==jax.numpy``, search for all functions which have been tagged with
156
- :func:`lazy_xp_function` in the globals of the module that defines the current test
195
+ :func:`lazy_xp_function` in the globals of the module that defines the current test,
196
+ as well as in the ``lazy_xp_modules`` list in the globals of the same module,
157
197
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
158
198
159
199
If ``xp==dask.array``, wrap the functions with a decorator that disables
160
- ``compute()`` and ``persist()``.
200
+ ``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
201
+ eagerly.
161
202
162
203
This function should be typically called by your library's `xp` fixture that runs
163
204
tests on multiple backends::
@@ -183,29 +224,33 @@ def xp(request, monkeypatch):
183
224
lazy_xp_function : Tag a function to be tested on lazy backends.
184
225
pytest.FixtureRequest : `request` test function parameter.
185
226
"""
186
- globals_ = cast ("dict[str, Any]" , request .module .__dict__ ) # type: ignore[no-any-explicit]
187
-
188
- def iter_tagged () -> Iterator [tuple [str , Callable [..., Any ], dict [str , Any ]]]: # type: ignore[no-any-explicit]
189
- for name , func in globals_ .items ():
190
- tags : dict [str , Any ] | None = None # type: ignore[no-any-explicit]
191
- with contextlib .suppress (AttributeError ):
192
- tags = func ._lazy_xp_function # pylint: disable=protected-access
193
- if tags is None :
194
- with contextlib .suppress (KeyError , TypeError ):
195
- tags = _ufuncs_tags [func ]
196
- if tags is not None :
197
- yield name , func , tags
227
+ mod = cast (ModuleType , request .module )
228
+ mods = [mod , * cast (list [ModuleType ], getattr (mod , "lazy_xp_modules" , []))]
229
+
230
+ def iter_tagged () -> ( # type: ignore[no-any-explicit]
231
+ Iterator [tuple [ModuleType , str , Callable [..., Any ], dict [str , Any ]]]
232
+ ):
233
+ for mod in mods :
234
+ for name , func in mod .__dict__ .items ():
235
+ tags : dict [str , Any ] | None = None # type: ignore[no-any-explicit]
236
+ with contextlib .suppress (AttributeError ):
237
+ tags = func ._lazy_xp_function # pylint: disable=protected-access
238
+ if tags is None :
239
+ with contextlib .suppress (KeyError , TypeError ):
240
+ tags = _ufuncs_tags [func ]
241
+ if tags is not None :
242
+ yield mod , name , func , tags
198
243
199
244
if is_dask_namespace (xp ):
200
- for name , func , tags in iter_tagged ():
245
+ for mod , name , func , tags in iter_tagged ():
201
246
n = tags ["allow_dask_compute" ]
202
247
wrapped = _dask_wrap (func , n )
203
- monkeypatch .setitem ( globals_ , name , wrapped )
248
+ monkeypatch .setattr ( mod , name , wrapped )
204
249
205
250
elif is_jax_namespace (xp ):
206
251
import jax
207
252
208
- for name , func , tags in iter_tagged ():
253
+ for mod , name , func , tags in iter_tagged ():
209
254
if tags ["jax_jit" ]:
210
255
# suppress unused-ignore to run mypy in -e lint as well as -e dev
211
256
wrapped = cast ( # type: ignore[no-any-explicit]
@@ -216,7 +261,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
216
261
static_argnames = tags ["static_argnames" ],
217
262
),
218
263
)
219
- monkeypatch .setitem ( globals_ , name , wrapped )
264
+ monkeypatch .setattr ( mod , name , wrapped )
220
265
221
266
222
267
class CountingDaskScheduler (SchedulerGetCallable ):
0 commit comments