# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

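# These tests exercise the TrainSpec registry: registering a spec, retrieving it
# by name, and wrapping build_optimizers_fn across all registered specs.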
from functools import partial

import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
    build_lr_schedulers,
    build_optimizers,
    OptimizersContainer,
)
from torchtitan.train_spec import (
    apply_to_train_specs,
    BaseModelArgs,
    get_train_spec,
    ModelProtocol,
    register_train_spec,
    TrainSpec,
)

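# Minimal ModelProtocol implementation; a single nn.Linear stands in for a real
# model so the specs below can be constructed without heavyweight dependencies.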
class FakeModel(ModelProtocol):
    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return nn.Linear(8, 8)

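# Stand-in for TrainSpec.build_optimizers_fn: wraps the model parts in an
# OptimizersContainer configured with fixed Adam hyperparameters.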
def fake_build_optimizers(
    model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
    optimizer_kwargs = {
        "lr": 0.1,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": True,
        "foreach": False,
    }
    return OptimizersContainer(
        model_parts=model_parts,
        optimizer_kwargs=optimizer_kwargs,
        name="Adam",
    )

class TestTrainSpec:
    def test_register_train_spec(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake",
            cls=FakeModel,
            config=fake_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake")
        assert new_spec == spec

        with pytest.raises(ValueError):
            new_spec = get_train_spec("fake2")

    def test_optim_hook(self):
        fake_config = {"fake": None}
        spec = TrainSpec(
            name="fake2",
            cls=FakeModel,
            config=fake_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=fake_build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(spec)
        new_spec = get_train_spec("fake2")

        # Demonstrate how to register an optimizer hook for all model specs
        hook_called = False

        def my_hook(
            optimizer: torch.optim.Optimizer,
            args,
            kwargs,
            model_parts: list[nn.Module],
        ) -> None:
            nonlocal hook_called
            hook_called = True

        def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
            # Create a closure to capture the original spec.build_optimizers_fn
            original_build_optimizers_fn = spec.build_optimizers_fn

            def my_build_optimizer_fn(
                model_parts: list[nn.Module], job_config: JobConfig
            ) -> OptimizersContainer:
                optimizers = original_build_optimizers_fn(model_parts, job_config)
                optimizers.register_step_post_hook(
                    partial(my_hook, model_parts=model_parts)
                )
                return optimizers

            spec.build_optimizers_fn = my_build_optimizer_fn
            return spec

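        # apply_to_train_specs applies the function to every registered spec, so
        # the "fake2" spec fetched above now builds optimizers with the hook.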
        apply_to_train_specs(register_optimizer_hook_to_spec)

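        # Build the model and optimizers through the patched spec, then verify
        # the post-step hook fires only after optimizers.step() is called.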
        model = new_spec.cls.from_model_args(BaseModelArgs())
        model_parts = [model]
        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
        assert optimizers.optimizers[0].__class__.__name__ == "Adam"
        batch = torch.randn(8, 8)
        model(batch).sum().backward()
        assert not hook_called
        optimizers.step()
        assert hook_called