|
7 | 7 | # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
|
8 | 8 | from __future__ import annotations
|
9 | 9 |
|
10 |
| -from collections.abc import Callable, Iterable, Sequence |
| 10 | +import contextlib |
| 11 | +from collections.abc import Callable, Iterable, Iterator, Sequence |
11 | 12 | from functools import wraps
|
12 | 13 | from types import ModuleType
|
13 | 14 | from typing import TYPE_CHECKING, Any, TypeVar, cast
|
@@ -42,6 +43,8 @@ def override(func: Callable[P, T]) -> Callable[P, T]:
|
42 | 43 |
|
43 | 44 | T = TypeVar("T")
|
44 | 45 |
|
| 46 | +_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[no-any-explicit] |
| 47 | + |
45 | 48 |
|
46 | 49 | def lazy_xp_function( # type: ignore[no-any-explicit]
|
47 | 50 | func: Callable[..., Any],
|
@@ -132,12 +135,16 @@ def test_myfunc(xp):
|
132 | 135 | a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
|
133 | 136 | mymodule.myfunc(a) # This is not
|
134 | 137 | """
|
135 |
| - func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] |
136 |
| - if jax_jit: |
137 |
| - func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] |
138 |
| - "static_argnums": static_argnums, |
139 |
| - "static_argnames": static_argnames, |
140 |
| - } |
| 138 | + tags = { |
| 139 | + "allow_dask_compute": allow_dask_compute, |
| 140 | + "jax_jit": jax_jit, |
| 141 | + "static_argnums": static_argnums, |
| 142 | + "static_argnames": static_argnames, |
| 143 | + } |
| 144 | + try: |
| 145 | + func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess] |
| 146 | + except AttributeError: # @cython.vectorize |
| 147 | + _ufuncs_tags[func] = tags |
141 | 148 |
|
142 | 149 |
|
143 | 150 | def patch_lazy_xp_functions(
|
@@ -179,24 +186,37 @@ def xp(request, monkeypatch):
|
179 | 186 | """
|
180 | 187 | globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit]
|
181 | 188 |
|
182 |
| - if is_dask_namespace(xp): |
| 189 | + def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit] |
183 | 190 | for name, func in globals_.items():
|
184 |
| - n = getattr(func, "allow_dask_compute", None) |
185 |
| - if n is not None: |
186 |
| - assert isinstance(n, int) |
187 |
| - wrapped = _allow_dask_compute(func, n) |
188 |
| - monkeypatch.setitem(globals_, name, wrapped) |
| 191 | + tags: dict[str, Any] | None = None # type: ignore[no-any-explicit] |
| 192 | + with contextlib.suppress(AttributeError): |
| 193 | + tags = func._lazy_xp_function # pylint: disable=protected-access |
| 194 | + if tags is None: |
| 195 | + with contextlib.suppress(KeyError, TypeError): |
| 196 | + tags = _ufuncs_tags[func] |
| 197 | + if tags is not None: |
| 198 | + yield name, func, tags |
| 199 | + |
| 200 | + if is_dask_namespace(xp): |
| 201 | + for name, func, tags in iter_tagged(): |
| 202 | + n = tags["allow_dask_compute"] |
| 203 | + wrapped = _allow_dask_compute(func, n) |
| 204 | + monkeypatch.setitem(globals_, name, wrapped) |
189 | 205 |
|
190 | 206 | elif is_jax_namespace(xp):
|
191 | 207 | import jax
|
192 | 208 |
|
193 |
| - for name, func in globals_.items(): |
194 |
| - kwargs = cast( # type: ignore[no-any-explicit] |
195 |
| - "dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None) |
196 |
| - ) |
197 |
| - if kwargs is not None: |
| 209 | + for name, func, tags in iter_tagged(): |
| 210 | + if tags["jax_jit"]: |
198 | 211 | # suppress unused-ignore to run mypy in -e lint as well as -e dev
|
199 |
| - wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore] |
| 212 | + wrapped = cast( # type: ignore[no-any-explicit] |
| 213 | + Callable[..., Any], |
| 214 | + jax.jit( |
| 215 | + func, |
| 216 | + static_argnums=tags["static_argnums"], |
| 217 | + static_argnames=tags["static_argnames"], |
| 218 | + ), |
| 219 | + ) |
200 | 220 | monkeypatch.setitem(globals_, name, wrapped)
|
201 | 221 |
|
202 | 222 |
|
|
0 commit comments