Skip to content

Commit b24c218

Browse files
BUG: lazy_xp_function crashes with Cython ufuncs (#153)
* BUG: lazy_xp_function crashes with ufuncs * elaborate on comment --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent 7b8e4f1 commit b24c218

File tree

2 files changed

+69
-19
lines changed

2 files changed

+69
-19
lines changed

src/array_api_extra/testing.py

+39-19
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
88
from __future__ import annotations
99

10-
from collections.abc import Callable, Iterable, Sequence
10+
import contextlib
11+
from collections.abc import Callable, Iterable, Iterator, Sequence
1112
from functools import wraps
1213
from types import ModuleType
1314
from typing import TYPE_CHECKING, Any, TypeVar, cast
@@ -42,6 +43,8 @@ def override(func: Callable[P, T]) -> Callable[P, T]:
4243

4344
T = TypeVar("T")
4445

46+
_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[no-any-explicit]
47+
4548

4649
def lazy_xp_function( # type: ignore[no-any-explicit]
4750
func: Callable[..., Any],
@@ -132,12 +135,16 @@ def test_myfunc(xp):
132135
a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
133136
mymodule.myfunc(a) # This is not
134137
"""
135-
func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
136-
if jax_jit:
137-
func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess]
138-
"static_argnums": static_argnums,
139-
"static_argnames": static_argnames,
140-
}
138+
tags = {
139+
"allow_dask_compute": allow_dask_compute,
140+
"jax_jit": jax_jit,
141+
"static_argnums": static_argnums,
142+
"static_argnames": static_argnames,
143+
}
144+
try:
145+
func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
146+
except AttributeError: # @cython.vectorize
147+
_ufuncs_tags[func] = tags
141148

142149

143150
def patch_lazy_xp_functions(
@@ -179,24 +186,37 @@ def xp(request, monkeypatch):
179186
"""
180187
globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit]
181188

182-
if is_dask_namespace(xp):
189+
def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit]
183190
for name, func in globals_.items():
184-
n = getattr(func, "allow_dask_compute", None)
185-
if n is not None:
186-
assert isinstance(n, int)
187-
wrapped = _allow_dask_compute(func, n)
188-
monkeypatch.setitem(globals_, name, wrapped)
191+
tags: dict[str, Any] | None = None # type: ignore[no-any-explicit]
192+
with contextlib.suppress(AttributeError):
193+
tags = func._lazy_xp_function # pylint: disable=protected-access
194+
if tags is None:
195+
with contextlib.suppress(KeyError, TypeError):
196+
tags = _ufuncs_tags[func]
197+
if tags is not None:
198+
yield name, func, tags
199+
200+
if is_dask_namespace(xp):
201+
for name, func, tags in iter_tagged():
202+
n = tags["allow_dask_compute"]
203+
wrapped = _allow_dask_compute(func, n)
204+
monkeypatch.setitem(globals_, name, wrapped)
189205

190206
elif is_jax_namespace(xp):
191207
import jax
192208

193-
for name, func in globals_.items():
194-
kwargs = cast( # type: ignore[no-any-explicit]
195-
"dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None)
196-
)
197-
if kwargs is not None:
209+
for name, func, tags in iter_tagged():
210+
if tags["jax_jit"]:
198211
# suppress unused-ignore to run mypy in -e lint as well as -e dev
199-
wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore]
212+
wrapped = cast( # type: ignore[no-any-explicit]
213+
Callable[..., Any],
214+
jax.jit(
215+
func,
216+
static_argnums=tags["static_argnums"],
217+
static_argnames=tags["static_argnames"],
218+
),
219+
)
200220
monkeypatch.setitem(globals_, name, wrapped)
201221

202222

tests/test_testing.py

+30
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,33 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
202202
xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0]))
203203
xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0]))
204204
xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0]))
205+
206+
207+
try:
208+
# Test an arbitrary Cython ufunc (@cython.vectorize).
209+
# When SCIPY_ARRAY_API is not set, this is the same as
210+
# scipy.special.erf.
211+
from scipy.special._ufuncs import erf # type: ignore[import-not-found]
212+
213+
lazy_xp_function(erf) # pyright: ignore[reportUnknownArgumentType]
214+
except ImportError:
215+
erf = None
216+
217+
218+
@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning") # torch
219+
def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
220+
pytest.importorskip("scipy")
221+
assert erf is not None
222+
x = xp.asarray([6.0, 7.0])
223+
if library in (Backend.ARRAY_API_STRICT, Backend.JAX):
224+
# array-api-strict arrays are auto-converted to numpy
225+
# which results in an assertion error for mismatched namespaces
226+
# eager jax arrays are auto-converted to numpy in eager jax
227+
# and fail in jax.jit (which lazy_xp_function tests here)
228+
with pytest.raises((TypeError, AssertionError)):
229+
xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))
230+
else:
231+
# cupy, dask and sparse define __array_ufunc__ and dispatch accordingly
232+
# note that when sparse reduces to scalar it returns a np.generic, which
233+
# would make xp_assert_equal fail.
234+
xp_assert_equal(erf(x), xp.asarray([1.0, 1.0]))

0 commit comments

Comments
 (0)