Skip to content

Commit bf30d7a

Browse files
gneculaGoogle-ML-Automation
authored andcommitted
[better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual debug info. This enables us to clean up the type declarations and to remove the checks whether debug_info is present. For usage outside of the JAX internals, we change `jax.extend.linear_util.wrap_init` to be usable without debug_info, for temporary backwards compatibility. We emit a deprecation warning and fill-in some fake debugging info. PiperOrigin-RevId: 725512692
1 parent 4b1400d commit bf30d7a

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

jax/_src/linear_util.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
2020
from jax._src import linear_util as lu
2121
22-
wf = lu.wrap_init(f) # Produce a WrappedFun for applying transformations on `f`
22+
# Produce a WrappedFun for applying transformations on `f`
23+
wf = lu.wrap_init(f, debug_info=api_util.debug_info("test", f, (), {}))
2324
2425
A `WrappedFun` object represents a function `f`, together with a sequence of
2526
nested transformations that are to be applied to the positional and keyword
@@ -66,6 +67,7 @@ def trans1(static_arg, *dynamic_args, **kwargs):
6667
from collections.abc import Callable, Sequence
6768
from functools import partial
6869
from typing import Any, NamedTuple
70+
import warnings
6971
import weakref
7072

7173
from jax._src import config
@@ -150,14 +152,15 @@ class WrappedFun:
150152
stores: a list of out_store for the auxiliary output of the `transforms`.
151153
params: extra parameters to pass as keyword arguments to `f`, along with the
152154
transformed keyword arguments.
155+
debug_info: debugging info about the function being wrapped.
153156
"""
154157
__slots__ = ("f", "f_transformed", "transforms", "stores", "params", "in_type", "debug_info")
155158

156159
def __init__(self, f: Callable,
157160
f_transformed: Callable,
158161
transforms,
159162
stores: tuple[Store | EqualStore | None, ...], params, in_type,
160-
debug_info: DebugInfo | None):
163+
debug_info: DebugInfo):
161164
self.f = f
162165
self.f_transformed = f_transformed
163166
self.transforms = transforms
@@ -168,7 +171,7 @@ def __init__(self, f: Callable,
168171

169172
@property
170173
def __name__(self):
171-
return getattr(self.f, '__name__', '<unnamed wrapped function>')
174+
return self.debug_info.func_name
172175

173176
def wrap(self, gen, gen_static_args,
174177
out_store: Store | EqualStore | None) -> WrappedFun:
@@ -254,6 +257,7 @@ def fun_name(f):
254257
except:
255258
return str(f)
256259

260+
257261
class DebugInfo(NamedTuple):
258262
"""Debugging info about a func, its arguments, and results."""
259263
traced_for: str # e.g. 'jit', 'scan', etc
@@ -313,19 +317,27 @@ def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]:
313317
return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b)
314318

315319

320+
def _missing_debug_info_msg():
321+
warnings.warn(
322+
"linear_util.wrap_init() or core.Jaxpr are called without a DebugInfo "
323+
"object. This behavior is deprecated, use api_util.debug_info() to "
324+
"construct a proper DebugInfo object.",
325+
DeprecationWarning, stacklevel=2)
326+
return DebugInfo("missing_debug_info", fun_name(f), (), {})
327+
328+
316329
def wrap_init(f: Callable, params=None, *,
317-
debug_info: DebugInfo | None = None) -> WrappedFun:
330+
debug_info: DebugInfo) -> WrappedFun:
318331
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
319332
params_dict = {} if params is None else params
320333
params = () if params is None else tuple(sorted(params.items()))
321334
fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
322-
if debug_info:
323-
if debug_info.result_paths is None:
324-
fun, result_paths_thunk = _get_result_paths_thunk(fun)
325-
debug_info = debug_info._replace(
326-
result_paths=HashableFunction(result_paths_thunk, closure=()))
327-
fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
328-
fun.params, fun.in_type, debug_info)
335+
if debug_info.result_paths is None:
336+
fun, result_paths_thunk = _get_result_paths_thunk(fun)
337+
debug_info = debug_info._replace(
338+
result_paths=HashableFunction(result_paths_thunk, closure=()))
339+
fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
340+
fun.params, fun.in_type, debug_info)
329341
return fun
330342

331343

jax/extend/linear_util.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# Note: import <name> as <name> is required for names to be exported.
1616
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
1717

18+
from typing import Callable
19+
1820
from jax._src.linear_util import (
1921
StoreException as StoreException,
2022
WrappedFun as WrappedFun,
@@ -24,5 +26,17 @@
2426
transformation_with_aux as transformation_with_aux,
2527
transformation2 as transformation2,
2628
transformation_with_aux2 as transformation_with_aux2,
27-
wrap_init as wrap_init,
29+
wrap_init as _wrap_init,
30+
_missing_debug_info_msg as _missing_debug_info_msg,
2831
)
32+
33+
# Version of wrap_init that does not require a DebugInfo object.
34+
# This usage is deprecated, use api_util.debug_info() to construct a proper
35+
# DebugInfo object.
36+
def wrap_init(f: Callable, params=None, *, debug_info=None) -> WrappedFun:
37+
return _wrap_init(
38+
f, params,
39+
debug_info=_missing_debug_info_msg() if debug_info is None else debug_info)
40+
41+
del _wrap_init
42+
del _missing_debug_info_msg

0 commit comments

Comments
 (0)