Skip to content

Commit 1e46081

Browse files
author
NTT123
committed
Add flatten mode to improve tree_flatten and tree_unflatten speed when needed.
1 parent 2ca2ee3 commit 1e46081

File tree

3 files changed: +44 −6 lines changed

Diff for: opax/transform.py

+26-5
@@ -18,7 +18,7 @@ def __init__(self, params=None):
     def __call__(self, updates, params=None):
         raise NotImplementedError("A subclass must implement this method")

-    def step(self, grads, params, all_finite: Optional[bool] = None):
+    def step(self, grads, params, all_finite: Optional[jnp.ndarray] = None):
         """An optimizing step.

         First, transform gradients
@@ -291,15 +291,36 @@ def __call__(self, updates, params=None):
 def chain(*fs: Callable[[Any], GradientTransformation]):
     class Chain(GradientTransformation):
         transforms: Sequence[GradientTransformation]
+        flatten: bool

-        def __init__(self, params):
+        def __init__(self, params, flatten: bool = False):
+            """Create a chain of gradient transformations.
+
+            Arguments:
+                params: trainable parameters.
+                flatten: flatten trainable parameters to a list for faster speed in jit mode.
+            """
             super().__init__()
-            transforms = [f(params) for f in fs]
+            self.flatten = flatten
+            if flatten:
+                leaves = jax.tree_leaves(params)
+                transforms = [f(leaves) for f in fs]
+            else:
+                transforms = [f(params) for f in fs]
             self.register_module_subtree("transforms", transforms)

         def __call__(self, updates, params=None):
-            for f in self.transforms:
-                updates = f(updates=updates, params=params)
+            if self.flatten:
+                updates_leaves, updates_treedef = jax.tree_flatten(updates)
+                params_leaves = jax.tree_leaves(params)
+
+                for f in self.transforms:
+                    updates_leaves = f(updates=updates_leaves, params=params_leaves)
+
+                updates = jax.tree_unflatten(updates_treedef, updates_leaves)
+            else:
+                for f in self.transforms:
+                    updates = f(updates=updates, params=params)

             return updates


Diff for: setup.py

+1-1
@@ -1,6 +1,6 @@
 from setuptools import find_packages, setup

-__version__ = "0.1.6"
+__version__ = "0.1.7"
 url = "https://github.com/ntt123/opax"

 install_requires = ["pax"]

Diff for: tests/test_opax.py

+17
@@ -112,3 +112,20 @@ def loss_fn(params, model, inputs) -> pax.utils.LossFnOutput:
     with pytest.raises(ValueError):
         for i in range(10):
             loss, net, opt = update_fn(net, opt, (x, x))
+
+
+def test_train_flatten():
+    net = pax.nn.Sequential(
+        pax.nn.Linear(1, 2),
+        pax.nn.Linear(2, 1),
+    )
+
+    def loss_fn(params, model, inputs) -> pax.utils.LossFnOutput:
+        loss = jnp.mean(jnp.square(model.update(params)(inputs[0]) - inputs[1]))
+        return loss, (loss, model)
+
+    update_fn = pax.utils.build_update_fn(loss_fn)
+    x = jnp.zeros((1, 1))
+    opt = opax.adam()(net.parameters(), flatten=True)
+    for i in range(10):
+        loss, net, opt = update_fn(net, opt, (x, x))

Comments: 0 commit comments