19
19
20
20
from jax._src import linear_util as lu
21
21
22
- wf = lu.wrap_init(f) # Produce a WrappedFun for applying transformations on `f`
22
+ # Produce a WrappedFun for applying transformations on `f`
23
+ wf = lu.wrap_init(f, debug_info=api_util.debug_info("test", f, (), {}))
23
24
24
25
A `WrappedFun` object represents a function `f`, together with a sequence of
25
26
nested transformations that are to be applied to the positional and keyword
@@ -66,6 +67,7 @@ def trans1(static_arg, *dynamic_args, **kwargs):
66
67
from collections .abc import Callable , Sequence
67
68
from functools import partial
68
69
from typing import Any , NamedTuple
70
+ import warnings
69
71
import weakref
70
72
71
73
from jax ._src import config
@@ -150,14 +152,15 @@ class WrappedFun:
150
152
stores: a list of out_store for the auxiliary output of the `transforms`.
151
153
params: extra parameters to pass as keyword arguments to `f`, along with the
152
154
transformed keyword arguments.
155
+ debug_info: debugging info about the function being wrapped.
153
156
"""
154
157
__slots__ = ("f" , "f_transformed" , "transforms" , "stores" , "params" , "in_type" , "debug_info" )
155
158
156
159
def __init__ (self , f : Callable ,
157
160
f_transformed : Callable ,
158
161
transforms ,
159
162
stores : tuple [Store | EqualStore | None , ...], params , in_type ,
160
- debug_info : DebugInfo | None ):
163
+ debug_info : DebugInfo ):
161
164
self .f = f
162
165
self .f_transformed = f_transformed
163
166
self .transforms = transforms
@@ -168,7 +171,7 @@ def __init__(self, f: Callable,
168
171
169
172
@property
170
173
def __name__ (self ):
171
- return getattr ( self .f , '__name__' , '<unnamed wrapped function>' )
174
+ return self .debug_info . func_name
172
175
173
176
def wrap (self , gen , gen_static_args ,
174
177
out_store : Store | EqualStore | None ) -> WrappedFun :
@@ -254,6 +257,7 @@ def fun_name(f):
254
257
except :
255
258
return str (f )
256
259
260
+
257
261
class DebugInfo (NamedTuple ):
258
262
"""Debugging info about a func, its arguments, and results."""
259
263
traced_for : str # e.g. 'jit', 'scan', etc
@@ -313,19 +317,27 @@ def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]:
313
317
return tuple (v for v , b in zip (self .safe_result_paths (len (keep )), keep ) if b )
314
318
315
319
320
+ def _missing_debug_info_msg ():
321
+ warnings .warn (
322
+ "linear_util.wrap_init() or core.Jaxpr are called without a DebugInfo "
323
+ "object. This behavior is deprecated, use api_util.debug_info() to "
324
+ "construct a proper DebugInfo object." ,
325
+ DeprecationWarning , stacklevel = 2 )
326
+ return DebugInfo ("missing_debug_info" , fun_name (f ), (), {})
327
+
328
+
316
329
def wrap_init (f : Callable , params = None , * ,
317
- debug_info : DebugInfo | None = None ) -> WrappedFun :
330
+ debug_info : DebugInfo ) -> WrappedFun :
318
331
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
319
332
params_dict = {} if params is None else params
320
333
params = () if params is None else tuple (sorted (params .items ()))
321
334
fun = WrappedFun (f , partial (f , ** params_dict ), (), (), params , None , debug_info )
322
- if debug_info :
323
- if debug_info .result_paths is None :
324
- fun , result_paths_thunk = _get_result_paths_thunk (fun )
325
- debug_info = debug_info ._replace (
326
- result_paths = HashableFunction (result_paths_thunk , closure = ()))
327
- fun = WrappedFun (fun .f , fun .f_transformed , fun .transforms , fun .stores ,
328
- fun .params , fun .in_type , debug_info )
335
+ if debug_info .result_paths is None :
336
+ fun , result_paths_thunk = _get_result_paths_thunk (fun )
337
+ debug_info = debug_info ._replace (
338
+ result_paths = HashableFunction (result_paths_thunk , closure = ()))
339
+ fun = WrappedFun (fun .f , fun .f_transformed , fun .transforms , fun .stores ,
340
+ fun .params , fun .in_type , debug_info )
329
341
return fun
330
342
331
343
0 commit comments