@@ -13,14 +13,15 @@
 # limitations under the License.
 """Lightweight library for working with optimizers."""
 
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 from fedjax.core import dataclasses
 from fedjax.core.typing import OptState
 from fedjax.core.typing import Params
 
 import haiku as hk
 import jax
+import jax.numpy as jnp
 import optax
 
 Grads = Params
@@ -278,3 +279,70 @@ def yogi(
   """
   return create_optimizer_from_optax(
       optax.yogi(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps))
+
+
+def adafactor(
+    learning_rate: ScalarOrSchedule,
+    min_dim_size_to_factor: int = 128,
+    decay_rate: float = 0.8,
+    decay_offset: int = 0,
+    multiply_by_parameter_scale: float = True,
+    clipping_threshold: Optional[float] = 1.0,
+    momentum: Optional[float] = None,
+    dtype_momentum: Any = jnp.float32,
+    weight_decay_rate: Optional[float] = None,
+    eps: float = 1e-30,
+    factored: bool = True,
+    weight_decay_mask: Optional[Any] = None,
+) -> Optimizer:
+  """The Adafactor optimizer.
+
+  Adafactor is an adaptive learning rate optimizer that focuses on fast
+  training of large-scale neural networks. It saves memory by using a factored
+  estimate of the second-order moments used to scale gradients.
+
+  References:
+    [Shazeer and Stern, 2018](https://arxiv.org/abs/1804.04235)
+
+  Args:
+    learning_rate: A fixed global scaling factor. Note that the natural scale
+      for Adafactor's learning rate differs markedly from Adam's; one does not
+      use the 1/sqrt(hidden) correction with attention-based models.
+    min_dim_size_to_factor: Only factor the statistics if two array dimensions
+      have at least this size.
+    decay_rate: Controls second-moment exponential decay schedule.
+    decay_offset: For fine-tuning, one may set this to the starting step
+      number of the fine-tuning phase.
+    multiply_by_parameter_scale: If True, then scale learning_rate by
+      parameter norm. If False, provided learning_rate is absolute step size.
+    clipping_threshold: Optional clipping threshold. Must be >= 1. If None,
+      clipping is disabled.
+    momentum: Optional value between 0 and 1, enables momentum and uses extra
+      memory if non-None. Defaults to None.
+    dtype_momentum: Data type of momentum buffers.
+    weight_decay_rate: Optional rate at which to decay weights.
+    eps: Regularization constant for root mean squared gradient.
+    factored: Whether to use factored second-moment estimates.
+    weight_decay_mask: A tree with the same structure as (or a prefix of) the
+      params pytree, or a Callable that returns such a pytree given the
+      params/updates. The leaves should be booleans: `True` for leaves and
+      subtrees the transformation should be applied to, and `False` for those
+      to skip.
+
+  Returns:
+    The corresponding `Optimizer`.
+  """
+  return create_optimizer_from_optax(
+      optax.adafactor(
+          learning_rate=learning_rate,
+          min_dim_size_to_factor=min_dim_size_to_factor,
+          decay_rate=decay_rate,
+          decay_offset=decay_offset,
+          multiply_by_parameter_scale=multiply_by_parameter_scale,
+          clipping_threshold=clipping_threshold,
+          momentum=momentum,
+          dtype_momentum=dtype_momentum,
+          weight_decay_rate=weight_decay_rate,
+          eps=eps,
+          factored=factored,
+          weight_decay_mask=weight_decay_mask))
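A minimal usage sketch of the new entry point, assuming this module is exposed as `fedjax.optimizers` (as the existing wrappers are) and that the returned `Optimizer` follows the same `init`/`apply` pair used elsewhere in this file; the parameter tree and gradients are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp

from fedjax import optimizers  # assumed public path for this module

# Build the optimizer added by this diff.
optimizer = optimizers.adafactor(learning_rate=1e-3)

# Hypothetical parameter pytree.
params = {'w': jnp.zeros((256, 128)), 'b': jnp.zeros((128,))}
opt_state = optimizer.init(params)

# Gradients must share the structure of params; all-ones stand in for real ones.
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# Assuming the fedjax Optimizer.apply convention: returns updated state and params.
opt_state, params = optimizer.apply(grads, opt_state, params)
```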
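The docstring's memory claim comes from the factored second-moment estimate of Shazeer and Stern (2018): for an (m, n) weight matrix, per-row and per-column statistics (m + n numbers) replace the full (m, n) accumulator, which is then approximated by a rank-1 reconstruction. The snippet below is only an illustration of that idea, not fedjax or optax internals:

```python
import jax.numpy as jnp

# Hypothetical gradient for a single (m, n) weight matrix.
g = jnp.ones((1024, 4096))
sq = g ** 2

full = sq                                 # unfactored accumulator: m * n numbers
row = sq.mean(axis=1)                     # per-row statistics: m numbers
col = sq.mean(axis=0)                     # per-column statistics: n numbers
approx = jnp.outer(row, col) / sq.mean()  # rank-1 approximation of `full`

print(full.size, row.size + col.size)     # 4194304 vs. 5120 stored values
```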