Commit 476f93e

Allow users to use customized models
**What does this PR do?**
1. This PR introduces ModelSpec to describe a model and how to parallelize it.
2. All models should define `build_model_spec()` or `model_spec` so it can be imported by the `model` module.
3. `build_model_specs()` is called in the trainer to get the `model_specs`, and the result is used to look up the corresponding model spec.
4. Users can also use `--experimental.model_module_path` to dynamically import a model that is not implemented by TorchTitan.

**Why do we need this PR?**
This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code.

**Next steps**
1. This PR only includes the model definitions, configurations, tokenizer, parallelize_fn, and pipelining_fn. We may also want to extend ModelSpec to include the optimizer and lr_scheduler.
2. The current TorchTitan parallelize and pipelining functions import ModelArgs, which can cause circular imports. We should fix this issue.

ghstack-source-id: 9c1d1eb
Pull Request resolved: #814
1 parent 5940dde commit 476f93e
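
To make the intended workflow concrete, here is a hedged sketch of what a user-provided model package could look like under this scheme. It is not part of the PR: the package path `my_models/model_x` and the names `MyTransformer`, `my_model_configs`, `parallelize_my_model`, and `pipeline_my_model` are hypothetical, and the registration simply mirrors the `register_train_spec(TrainSpec(...))` call this commit adds for Llama 3 in `torchtitan/models/llama/__init__.py`.

```python
# my_models/model_x/__init__.py -- hypothetical user package, not part of this PR.
# It registers a TrainSpec so the trainer can look the model up by name, mirroring
# the llama3 registration added in torchtitan/models/llama/__init__.py below.
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.train_spec import register_train_spec, TrainSpec

from .model import MyTransformer, my_model_configs    # user-defined model and flavors
from .parallelize import parallelize_my_model         # user-defined parallelize_fn
from .pipeline import pipeline_my_model               # user-defined pipelining_fn

register_train_spec(
    TrainSpec(
        name="model_x",
        cls=MyTransformer,
        config=my_model_configs,
        parallelize_fn=parallelize_my_model,
        pipelining_fn=pipeline_my_model,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
```

The run script would presumably then be pointed at the package via the experimental option added in `torchtitan/config_manager.py` (e.g. `--experimental.custom_model_path my_models/model_x`) and the model selected with `--model.name model_x`.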

19 files changed: +641 -204 lines

scripts/estimate/estimation.py

+9-6
@@ -19,9 +19,10 @@
 from torchtitan.datasets import build_tokenizer
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
-from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtitan.models import model_name_to_tokenizer
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
-from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
+from torchtitan.parallelisms import ParallelDims
+from torchtitan.train_spec import get_train_spec
 
 
 def estimate_memory(job_config: JobConfig):
@@ -74,6 +75,8 @@ def estimate_memory(job_config: JobConfig):
         "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
     )
 
+    train_spec = get_train_spec(job_config.model.name)
+
     # build meshes
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
 
@@ -95,8 +98,8 @@ def loss_fn(pred, labels):
         )
 
     # build model (using meta init)
-    model_cls = model_name_to_cls[model_name]
-    model_config = models_config[model_name][job_config.model.flavor]
+    model_cls = train_spec.cls
+    model_config = train_spec.config[job_config.model.flavor]
     # set the model configs from training inputs:
     # 1. norm type to decide which norm layer to use
     # 2. vocab size from tokenizer
@@ -112,7 +115,7 @@ def loss_fn(pred, labels):
     ):
 
         logger.info(
-            f"Building {model_name} {job_config.model.flavor} with {model_config}"
+            f"Building {train_spec.name} {job_config.model.flavor} with {model_config}"
         )
         with torch.device("meta"):
             model = model_cls.from_model_args(model_config)
@@ -123,7 +126,7 @@ def loss_fn(pred, labels):
         float8_handler.convert_to_float8_training(model)
 
         # apply PT-D DP/TP parallelisms and activation checkpointing
-        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+        train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
 
         model.to_empty(device="cuda")
         if not active_fake_mode():

scripts/generate/test_generate.py

+7-6
@@ -26,13 +26,14 @@
 )
 
 from torchtitan import utils
-
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_tokenizer
 from torchtitan.logging import init_logger, logger
 from torchtitan.metrics import build_device_memory_monitor
-from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtitan.models import model_name_to_tokenizer
 from torchtitan.parallelisms import ParallelDims
+
+from torchtitan.train_spec import get_train_spec
 from torchtitan.utils import device_module, device_type
 
 # support running w/o installing as package
@@ -102,21 +103,21 @@ def test_generate(
     device_module.set_device(device)
     device_memory_monitor = build_device_memory_monitor()
 
-    model_name = config.model.name
+    train_spec = get_train_spec(config.model.name)
 
     logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")
 
     # Tokenizer setup
     tokenizer = build_tokenizer(
-        model_name_to_tokenizer[model_name], config.model.tokenizer_path
+        model_name_to_tokenizer[train_spec.name], config.model.tokenizer_path
     )
 
-    model_config = models_config[model_name][config.model.flavor]
+    model_config = train_spec.config[config.model.flavor]
     model_config.norm_type = config.model.norm_type
    model_config.max_seq_len = config.training.seq_len
     model_config.vocab_size = tokenizer.n_words
 
-    model_cls = model_name_to_cls[model_name]
+    model_cls = train_spec.cls
     init_device = "meta" if world_size > 1 else device
     with torch.device(init_device):
         logger.info(f"Init model on init_device: {init_device}")

tests/unit_tests/test_train_spec.py

+122
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+
+import pytest
+import torch
+import torch.nn as nn
+from torchtitan.config_manager import JobConfig
+from torchtitan.models.llama import parallelize_llama, pipeline_llama
+from torchtitan.optimizer import (
+    build_lr_schedulers,
+    build_optimizers,
+    OptimizersContainer,
+)
+from torchtitan.train_spec import (
+    apply_to_train_specs,
+    BaseModelArgs,
+    get_train_spec,
+    ModelProtocol,
+    register_train_spec,
+    TrainSpec,
+)
+
+
+class FakeModel(ModelProtocol):
+    @staticmethod
+    def from_model_args(args: BaseModelArgs) -> nn.Module:
+        return nn.Linear(8, 8)
+
+
+def fake_build_optimizers(
+    model_parts: list[nn.Module], job_config: JobConfig
+) -> OptimizersContainer:
+    optimizer_kwargs = {
+        "lr": 0.1,
+        "betas": (0.9, 0.95),
+        "weight_decay": 0.1,
+        "fused": True,
+        "foreach": False,
+    }
+    return OptimizersContainer(
+        model_parts=model_parts,
+        optimizer_kwargs=optimizer_kwargs,
+        name="Adam",
+    )
+
+
+class TestTrainSpec:
+    def test_register_train_spec(self):
+        fake_config = {"fake": None}
+        spec = TrainSpec(
+            name="fake",
+            cls=FakeModel,
+            config=fake_config,
+            parallelize_fn=parallelize_llama,
+            pipelining_fn=pipeline_llama,
+            build_optimizers_fn=build_optimizers,
+            build_lr_schedulers_fn=build_lr_schedulers,
+        )
+        register_train_spec(spec)
+        new_spec = get_train_spec("fake")
+        assert new_spec == spec
+
+        with pytest.raises(ValueError):
+            new_spec = get_train_spec("fake2")
+
+    def test_optim_hook(self):
+        fake_config = {"fake": None}
+        spec = TrainSpec(
+            name="fake2",
+            cls=FakeModel,
+            config=fake_config,
+            parallelize_fn=parallelize_llama,
+            pipelining_fn=pipeline_llama,
+            build_optimizers_fn=fake_build_optimizers,
+            build_lr_schedulers_fn=build_lr_schedulers,
+        )
+        register_train_spec(spec)
+        new_spec = get_train_spec("fake2")
+
+        # Demonstrate how to register a optimizer hook for all model specs
+        hook_called = False
+
+        def my_hook(
+            optimizer: torch.optim.Optimizer,
+            args,
+            kwargs,
+            model_parts: list[nn.Module],
+        ) -> None:
+            nonlocal hook_called
+            hook_called = True
+
+        def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
+            # Create a closure to capture the original spec.build_optimizers_fn
+            original_build_optimizers_fn = spec.build_optimizers_fn
+
+            def my_build_optimizer_fn(
+                model_parts: list[nn.Module], job_config: JobConfig
+            ) -> OptimizersContainer:
+                optimizers = original_build_optimizers_fn(model_parts, job_config)
+                optimizers.register_step_post_hook(
+                    partial(my_hook, model_parts=model_parts)
+                )
+                return optimizers
+
+            spec.build_optimizers_fn = my_build_optimizer_fn
+
+        apply_to_train_specs(register_optimizer_hook_to_spec)
+
+        model = new_spec.cls.from_model_args(BaseModelArgs())
+        model_parts = [model]
+        optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig())
+        assert optimizers.optimizers[0].__class__.__name__ == "Adam"
+        batch = torch.randn(8, 8)
+        model(batch).sum().backward()
+        assert not hook_called
+        optimizers.step()
+        assert hook_called
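
The second test above demonstrates the intended extension point: wrapping every registered spec's `build_optimizers_fn` so a post-step hook is attached to the resulting `OptimizersContainer`. As a follow-up illustration, a minimal sketch of reusing that same pattern outside the test, assuming only the APIs exercised above (`apply_to_train_specs`, `TrainSpec.build_optimizers_fn`, `OptimizersContainer.register_step_post_hook`); `attach_step_logging` and the `print()` call are placeholders, not part of the PR.

```python
# A sketch (not part of this PR) reusing the hook pattern the test exercises:
# wrap every registered spec's build_optimizers_fn so a callback runs after each
# optimizer.step(). `attach_step_logging` and print() are placeholders.
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import OptimizersContainer
from torchtitan.train_spec import apply_to_train_specs, TrainSpec


def attach_step_logging(spec: TrainSpec) -> TrainSpec:
    original_build_optimizers_fn = spec.build_optimizers_fn

    def build_with_hook(
        model_parts: list[nn.Module], job_config: JobConfig
    ) -> OptimizersContainer:
        optimizers = original_build_optimizers_fn(model_parts, job_config)
        # Hook signature mirrors the one in the test: (optimizer, args, kwargs).
        optimizers.register_step_post_hook(
            lambda optimizer, args, kwargs: print("optimizer step finished")
        )
        return optimizers

    spec.build_optimizers_fn = build_with_hook
    return spec


# Applied once at startup, before the trainer builds optimizers from the spec.
apply_to_train_specs(attach_step_logging)
```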

torchtitan/__init__.py

+11
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+# Import the built-in models here so that the corresponding register_model_spec()
+# will be called.
+import torchtitan.models  # noqa: F401
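
`torchtitan/train_spec.py` itself is not shown in this excerpt. For orientation, here is a minimal sketch, under the assumption of a plain name-keyed dict registry, of how `TrainSpec`, `register_train_spec`, `get_train_spec`, and `apply_to_train_specs` could fit together; the field types and error handling are guesses based on how the specs are used elsewhere in this diff, not the PR's actual implementation.

```python
# Hypothetical sketch of the TrainSpec dataclass and registry; the real
# torchtitan/train_spec.py added by this PR is not shown here and may differ.
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping


@dataclass
class TrainSpec:
    name: str
    cls: type
    config: Mapping[str, Any]  # flavor name -> model args (e.g. "8B" -> TransformerModelArgs)
    parallelize_fn: Callable
    pipelining_fn: Callable
    build_optimizers_fn: Callable
    build_lr_schedulers_fn: Callable


_train_specs: Dict[str, TrainSpec] = {}


def register_train_spec(spec: TrainSpec) -> None:
    # Built-in models call this as a side effect of `import torchtitan.models`.
    if spec.name in _train_specs:
        raise ValueError(f"TrainSpec {spec.name} is already registered.")
    _train_specs[spec.name] = spec


def get_train_spec(name: str) -> TrainSpec:
    if name not in _train_specs:
        raise ValueError(f"TrainSpec {name} is not registered.")
    return _train_specs[name]


def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None:
    # Lets callers wrap or replace fields on every registered spec.
    for name, spec in _train_specs.items():
        _train_specs[name] = func(spec)
```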

torchtitan/checkpoint.py

+3-2
@@ -26,9 +26,10 @@
 )
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import DataLoader
+
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import init_logger, logger
-from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
+from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer
 
 
 class IntervalType(enum.Enum):
@@ -140,7 +141,7 @@ def __init__(
         dataloader: DataLoader,
         model_parts: List[nn.Module],
         optimizers: OptimizersContainer,
-        lr_schedulers: SchedulersContainer,
+        lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
     ) -> None:

torchtitan/config_manager.py

+17
@@ -393,6 +393,23 @@ def __init__(self):
             The default value is 'allgather'.
             """,
         )
+        # I'm not particularly fond of this. Users can choose to write their own wrapper
+        # module and import TorchTitan training loop and execute it, which look cleaner.
+        # One reason to provide this option is to allow users to use the existing run script.
+        # While the script is pretty trivial now, we may add more logic when integrating
+        # with TorchFT.
+        # This option is subject to change and may be deleted in the future.
+        self.parser.add_argument(
+            "--experimental.custom_model_path",
+            type=str,
+            default="",
+            help="""
+            The --custom_model_path option allows to specify a custom path to a model module
+            that is not natively implemented within TorchTitan.
+            Acceptable values are the file system path to the module (e.g., my_models/model_x)
+            dotted import module (e.g., some_package.model_x).
+            """,
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
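
The help text accepts either a file system path or a dotted module name, but the code that actually imports the custom module lives in the trainer and is not shown in this excerpt. Below is a hedged sketch of how such a value could be resolved with the standard library; `maybe_import_custom_model` is a hypothetical name, not an API added by this PR.

```python
# Hypothetical helper illustrating how --experimental.custom_model_path could be
# imported, covering both accepted forms: a filesystem path like my_models/model_x
# and a dotted module name like some_package.model_x. Not part of this PR.
import importlib
import importlib.util
import os
import sys


def maybe_import_custom_model(custom_model_path: str) -> None:
    if not custom_model_path:
        return  # option defaults to "", so there is nothing to import

    if os.path.exists(custom_model_path):
        # Treat the value as a path to a package directory and execute its
        # __init__.py, which is expected to call register_train_spec() on import.
        init_file = os.path.join(custom_model_path, "__init__.py")
        module_name = os.path.basename(os.path.normpath(custom_model_path))
        spec = importlib.util.spec_from_file_location(module_name, init_file)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    else:
        # Otherwise treat it as a dotted module name already importable on sys.path.
        importlib.import_module(custom_model_path)
```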

torchtitan/models/__init__.py

+4-8
@@ -4,14 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.models.llama import llama3_configs, Transformer
 
-models_config = {
-    "llama3": llama3_configs,
-}
+# Import the built-in models here so that the corresponding register_model_spec()
+# will be called.
+import torchtitan.models.llama  # noqa: F401
 
-model_name_to_cls = {"llama3": Transformer}
 
-model_name_to_tokenizer = {
-    "llama3": "tiktoken",
-}
+model_name_to_tokenizer = {"llama3": "tiktoken"}

torchtitan/models/llama/__init__.py

+33-6
@@ -6,13 +6,27 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
-from torchtitan.models.llama.model import ModelArgs, Transformer
+from torchtitan.models.llama.model import Transformer, TransformerModelArgs
+from torchtitan.optimizer import build_lr_schedulers, build_optimizers
+from torchtitan.train_spec import register_train_spec, TrainSpec
+
+from .parallelize_llama import parallelize_llama
+from .pipeline_llama import pipeline_llama
+
+__all__ = [
+    "parallelize_llama",
+    "pipeline_llama",
+    "TransformerModelArgs",
+    "Transformer",
+    "llama3_configs",
+]
 
-__all__ = ["Transformer"]
 
 llama3_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
-    "8B": ModelArgs(
+    "debugmodel": TransformerModelArgs(
+        dim=256, n_layers=8, n_heads=16, rope_theta=500000
+    ),
+    "8B": TransformerModelArgs(
         dim=4096,
         n_layers=32,
         n_heads=32,
@@ -21,7 +35,7 @@
         multiple_of=1024,
         rope_theta=500000,
     ),
-    "70B": ModelArgs(
+    "70B": TransformerModelArgs(
         dim=8192,
         n_layers=80,
         n_heads=64,
@@ -30,7 +44,7 @@
         multiple_of=4096,
         rope_theta=500000,
     ),
-    "405B": ModelArgs(
+    "405B": TransformerModelArgs(
         dim=16384,
         n_layers=126,
         n_heads=128,
@@ -40,3 +54,16 @@
         rope_theta=500000,
     ),
 }
+
+
+register_train_spec(
+    TrainSpec(
+        name="llama3",
+        cls=Transformer,
+        config=llama3_configs,
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+    )
+)
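
For reference, a short sketch of how the trainer-side code in this PR consumes the registration above, mirroring the `get_train_spec` call sites in `scripts/estimate/estimation.py` and `scripts/generate/test_generate.py` earlier in the diff; the flavor `"debugmodel"` stands in for whatever `--model.flavor` selects.

```python
# Sketch of how the trainer consumes the llama3 registration above; mirrors the
# call sites changed in scripts/estimate/estimation.py and scripts/generate/test_generate.py.
from torchtitan.train_spec import get_train_spec

train_spec = get_train_spec("llama3")           # selected via --model.name
model_config = train_spec.config["debugmodel"]  # selected via --model.flavor
model_cls = train_spec.cls                      # Transformer

# The remaining fields are used roughly as follows in the real scripts (shown as
# comments because they need a tokenizer, device meshes, and a parsed JobConfig):
#   model_config.vocab_size = tokenizer.n_words
#   model = model_cls.from_model_args(model_config)
#   train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
#   optimizers = train_spec.build_optimizers_fn([model], job_config)
```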
