Commit aeab691

stheertha authored and fedjax authors committed
Adafactor optimizer
PiperOrigin-RevId: 486362494
1 parent 05e4849 commit aeab691

File tree

1 file changed: +69 -1 lines changed

fedjax/core/optimizers.py

Lines changed: 69 additions & 1 deletion
@@ -13,14 +13,15 @@
 # limitations under the License.
 """Lightweight library for working with optimizers."""
 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 from fedjax.core import dataclasses
 from fedjax.core.typing import OptState
 from fedjax.core.typing import Params
 
 import haiku as hk
 import jax
+import jax.numpy as jnp
 import optax
 
 Grads = Params
@@ -278,3 +279,70 @@ def yogi(
   """
   return create_optimizer_from_optax(
       optax.yogi(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps))
+
+
+def adafactor(
+    learning_rate: ScalarOrSchedule,
+    min_dim_size_to_factor: int = 128,
+    decay_rate: float = 0.8,
+    decay_offset: int = 0,
+    multiply_by_parameter_scale: float = True,
+    clipping_threshold: Optional[float] = 1.0,
+    momentum: Optional[float] = None,
+    dtype_momentum: Any = jnp.float32,
+    weight_decay_rate: Optional[float] = None,
+    eps: float = 1e-30,
+    factored: bool = True,
+    weight_decay_mask: Optional[Any] = None,
+) -> Optimizer:
+  """The Adafactor optimizer.
+
+  Adafactor is an adaptive learning rate optimizer that focuses on fast
+  training of large scale neural networks. It saves memory by using a factored
+  estimate of the second order moments used to scale gradients.
+
+  References:
+    [Shazeer and Stern, 2018](https://arxiv.org/abs/1804.04235)
+
+  Args:
+    learning_rate: A fixed global scaling factor. Note: the natural scale for
+      Adafactor's LR is markedly different from Adam; one doesn't use the
+      1/sqrt(hidden) correction for this optimizer with attention-based models.
+    min_dim_size_to_factor: Only factor the statistics if two array dimensions
+      have at least this size.
+    decay_rate: Controls second-moment exponential decay schedule.
+    decay_offset: For fine-tuning, one may set this to the starting step
+      number of the fine-tuning phase.
+    multiply_by_parameter_scale: If True, then scale learning_rate by
+      parameter norm. If False, provided learning_rate is absolute step size.
+    clipping_threshold: Optional clipping threshold. Must be >= 1. If None,
+      clipping is disabled.
+    momentum: Optional value between 0 and 1, enables momentum and uses extra
+      memory if non-None! None by default.
+    dtype_momentum: Data type of momentum buffers.
+    weight_decay_rate: Optional rate at which to decay weights.
+    eps: Regularization constant for root mean squared gradient.
+    factored: Whether to use factored second-moment estimates.
+    weight_decay_mask: A tree with same structure as (or a prefix of) the
+      params PyTree, or a Callable that returns such a pytree given the
+      params/updates. The leaves should be booleans: `True` for
+      leaves/subtrees you want to apply the transformation to, and `False` for
+      those you want to skip.
+
+  Returns:
+    The corresponding `Optimizer`.
+  """
+  return create_optimizer_from_optax(
+      optax.adafactor(
+          learning_rate=learning_rate,
+          min_dim_size_to_factor=min_dim_size_to_factor,
+          decay_rate=decay_rate,
+          decay_offset=decay_offset,
+          multiply_by_parameter_scale=multiply_by_parameter_scale,
+          clipping_threshold=clipping_threshold,
+          momentum=momentum,
+          dtype_momentum=dtype_momentum,
+          weight_decay_rate=weight_decay_rate,
+          eps=eps,
+          factored=factored,
+          weight_decay_mask=weight_decay_mask))
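
For context, here is a minimal usage sketch of the wrapper added above; it is not part of the commit. It assumes the fedjax Optimizer interface used by the other optimizers in this file, namely init(params) and apply(grads, opt_state, params), and the hyperparameter values and pytree shapes are illustrative only.

import jax
import jax.numpy as jnp

from fedjax.core import optimizers

# Build the optimizer added in this commit; keyword arguments are forwarded
# unchanged to optax.adafactor.
optimizer = optimizers.adafactor(learning_rate=1e-3)

# Toy parameters and gradients: any pytree of arrays works.
params = {'w': jnp.ones((4, 4)), 'b': 0.1 * jnp.ones((4,))}
grads = jax.tree_util.tree_map(lambda p: 0.1 * jnp.ones_like(p), params)

# One update step: create the optimizer state, then apply the gradients.
opt_state = optimizer.init(params)
opt_state, params = optimizer.apply(grads, opt_state, params)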
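
The memory saving described in the docstring can be made concrete with a small numerical sketch, also not part of the commit. Following Shazeer and Stern (2018), the full (n, m) matrix of second-moment statistics is replaced by per-row and per-column sums and reconstructed as a rank-1 outer product; the shapes below are illustrative, and Adafactor itself accumulates these sums as exponential moving averages across steps rather than from a single gradient.

import jax
import jax.numpy as jnp

# Squared gradients for a single (256, 512) weight matrix.
grad = jax.random.normal(jax.random.PRNGKey(0), (256, 512))
sq = jnp.square(grad) + 1e-30

row_sums = sq.sum(axis=1)  # shape (256,): one statistic per row
col_sums = sq.sum(axis=0)  # shape (512,): one statistic per column

# Rank-1 reconstruction of the full second-moment matrix: R C^T / sum(R).
approx = jnp.outer(row_sums, col_sums) / row_sums.sum()

# n + m numbers stored instead of n * m.
print(row_sums.size + col_sums.size, 'vs', sq.size)  # 768 vs 131072
print(float(jnp.mean(jnp.abs(approx - sq))))  # one-step reconstruction error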
