@@ -18,7 +18,7 @@ def __init__(self, params=None):
     def __call__(self, updates, params=None):
         raise NotImplementedError("A subclass must implement this method")
 
-    def step(self, grads, params, all_finite: Optional[bool] = None):
+    def step(self, grads, params, all_finite: Optional[jnp.ndarray] = None):
         """An optimizing step.
 
         First, transform gradients
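
In the first hunk, all_finite is re-annotated from Optional[bool] to Optional[jnp.ndarray]: under jax.jit, a finiteness check over the gradients produces a traced scalar array, never a concrete Python bool, so the old annotation did not match the value actually passed in. A minimal sketch of how such a flag is typically computed (check_finite is illustrative only, not part of this commit):

import jax
import jax.numpy as jnp

@jax.jit
def check_finite(grads):
    # Reduce every gradient leaf to a scalar bool array, then combine them.
    leaves = jax.tree_leaves(grads)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(g)) for g in leaves]))

grads = {"w": jnp.ones((2, 2)), "b": jnp.array([1.0, jnp.inf])}
all_finite = check_finite(grads)  # a scalar jnp.ndarray holding False, not a Python bool
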
@@ -291,15 +291,36 @@ def __call__(self, updates, params=None):
 def chain(*fs: Callable[[Any], GradientTransformation]):
     class Chain(GradientTransformation):
         transforms: Sequence[GradientTransformation]
+        flatten: bool
 
-        def __init__(self, params):
+        def __init__(self, params, flatten: bool = False):
+            """Create a chain of gradient transformations.
+
+            Arguments:
+                params: trainable parameters.
+                flatten: flatten trainable parameters to a list for faster speed in jit mode.
+            """
             super().__init__()
-            transforms = [f(params) for f in fs]
+            self.flatten = flatten
+            if flatten:
+                leaves = jax.tree_leaves(params)
+                transforms = [f(leaves) for f in fs]
+            else:
+                transforms = [f(params) for f in fs]
             self.register_module_subtree("transforms", transforms)
 
         def __call__(self, updates, params=None):
-            for f in self.transforms:
-                updates = f(updates=updates, params=params)
+            if self.flatten:
+                updates_leaves, updates_treedef = jax.tree_flatten(updates)
+                params_leaves = jax.tree_leaves(params)
+
+                for f in self.transforms:
+                    updates_leaves = f(updates=updates_leaves, params=params_leaves)
+
+                updates = jax.tree_unflatten(updates_treedef, updates_leaves)
+            else:
+                for f in self.transforms:
+                    updates = f(updates=updates, params=params)
 
             return updates
 
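The second hunk adds a flatten option to chain: when enabled, the trainable parameters and updates are flattened to a list of leaves once, every transformation operates on that flat list, and the final updates are unflattened back to the original tree structure; per the new docstring, this is meant to speed things up in jit mode. A hedged usage sketch, assuming chain returns the inner Chain class (its signature suggests so) and using a made-up scale transformation:

import jax
import jax.numpy as jnp

def scale(factor):
    # Made-up GradientTransformation factory for illustration; only chain
    # and its new flatten flag come from the diff above.
    class Scale(GradientTransformation):
        def __call__(self, updates, params=None):
            # Works on either a pytree or a flat list of leaves.
            return jax.tree_map(lambda u: factor * u, updates)
    return Scale

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
grads = {"w": jnp.full((3,), 2.0), "b": jnp.ones(())}

opt = chain(scale(0.5), scale(-1.0))(params, flatten=True)
updates = opt(updates=grads, params=params)  # same pytree structure as grads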