-
Notifications
You must be signed in to change notification settings - Fork 328
/
Copy pathtest_train_spec.py
122 lines (105 loc) · 3.68 KB
/
test_train_spec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import pytest
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.models.llama import parallelize_llama, pipeline_llama
from torchtitan.optimizer import (
build_lr_schedulers,
build_optimizers,
OptimizersContainer,
)
from torchtitan.train_spec import (
apply_to_train_specs,
BaseModelArgs,
get_train_spec,
ModelProtocol,
register_train_spec,
TrainSpec,
)
class FakeModel(ModelProtocol):
    """Minimal stand-in model used to exercise the train-spec registry."""

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        """Ignore ``args`` and return a tiny linear layer as the 'model'."""
        layer = nn.Linear(8, 8)
        return layer
def fake_build_optimizers(
    model_parts: list[nn.Module], job_config: JobConfig
) -> OptimizersContainer:
    """Build an OptimizersContainer with fixed Adam settings.

    ``job_config`` is accepted for signature compatibility but ignored;
    the hyperparameters below are hard-coded on purpose for testing.
    """
    hyperparams = dict(
        lr=0.1,
        betas=(0.9, 0.95),
        weight_decay=0.1,
        fused=True,
        foreach=False,
    )
    container = OptimizersContainer(
        model_parts=model_parts,
        optimizer_kwargs=hyperparams,
        name="Adam",
    )
    return container
class TestTrainSpec:
    """Tests for train-spec registration and optimizer-hook injection."""

    def test_register_train_spec(self):
        """A registered spec is retrievable by name; unknown names raise."""
        dummy_config = {"fake": None}
        registered = TrainSpec(
            name="fake",
            cls=FakeModel,
            config=dummy_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(registered)
        fetched = get_train_spec("fake")
        assert fetched == registered

        # Looking up a name that was never registered must fail loudly.
        with pytest.raises(ValueError):
            fetched = get_train_spec("fake2")

    def test_optim_hook(self):
        """apply_to_train_specs can wrap build_optimizers_fn with a step hook."""
        dummy_config = {"fake": None}
        fresh_spec = TrainSpec(
            name="fake2",
            cls=FakeModel,
            config=dummy_config,
            parallelize_fn=parallelize_llama,
            pipelining_fn=pipeline_llama,
            build_optimizers_fn=fake_build_optimizers,
            build_lr_schedulers_fn=build_lr_schedulers,
        )
        register_train_spec(fresh_spec)
        fetched_spec = get_train_spec("fake2")

        # Demonstrate how to register a optimizer hook for all model specs
        was_called = False

        def step_hook(
            optimizer: torch.optim.Optimizer,
            args,
            kwargs,
            model_parts: list[nn.Module],
        ) -> None:
            nonlocal was_called
            was_called = True

        def attach_step_hook(spec: TrainSpec) -> TrainSpec:
            # Close over the spec's original builder so the wrapper delegates
            # to it before installing the post-step hook.
            delegate = spec.build_optimizers_fn

            def build_and_hook(
                model_parts: list[nn.Module], job_config: JobConfig
            ) -> OptimizersContainer:
                optimizers = delegate(model_parts, job_config)
                optimizers.register_step_post_hook(
                    partial(step_hook, model_parts=model_parts)
                )
                return optimizers

            spec.build_optimizers_fn = build_and_hook

        apply_to_train_specs(attach_step_hook)

        model = fetched_spec.cls.from_model_args(BaseModelArgs())
        model_parts = [model]
        optimizers = fetched_spec.build_optimizers_fn(model_parts, JobConfig())
        # The fake builder always names its optimizer "Adam".
        assert optimizers.optimizers[0].__class__.__name__ == "Adam"

        batch = torch.randn(8, 8)
        model(batch).sum().backward()
        # The hook must fire only on step(), not on backward().
        assert not was_called
        optimizers.step()
        assert was_called