Commit 8e1cbbd

Allow users to use customized models
**What does this PR do?**

1. Introduces `ModelSpec` to describe a model and how to parallelize it.
2. Requires all models to define `build_model_spec()` or `model_spec`, which will be imported by the `model` module.
3. Calls `build_model_specs()` in the trainer to obtain `model_specs`, which are then used to retrieve the corresponding model spec.
4. Allows users to dynamically import a model not implemented by TorchTitan using `--experimental.model_module_path`.

**Why do we need this PR?**

This PR enables users to integrate new models with TorchTitan without making intrusive changes to the TorchTitan codebase.

**Next steps**

1. This PR includes only the model definitions, configurations, tokenizer, parallelize_fn, and pipelining_fn. We may want to extend `ModelSpec` to include the optimizer and learning rate scheduler.
2. The current TorchTitan parallelize and pipelining_fn import ModelArgs, which can lead to circular imports. This issue needs to be addressed.

ghstack-source-id: 28259eb74975eeb7ad790a774b6e719f3aa19a31
Pull Request resolved: #814
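As a rough illustration of the workflow described above, an out-of-tree model package (for example the hypothetical `my_models/model_x` mentioned in the new flag's help text) could expose a `build_model_spec()` for TorchTitan to pick up once its path is registered. Everything named `My*` / `my_*` below is a placeholder and not part of this PR; the `ModelSpec` fields and the parallelize/pipelining signatures mirror the diffs in this commit.

# my_models/model_x/__init__.py -- hypothetical user package, not part of this PR
from dataclasses import dataclass

import torch.nn as nn

from torchtitan.model_spec import BaseModelArgs, ModelSpec


@dataclass
class MyModelArgs(BaseModelArgs):
    dim: int = 512
    n_layers: int = 4


class MyTransformer(nn.Module):
    def __init__(self, args: MyModelArgs):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(args.dim, args.dim) for _ in range(args.n_layers)]
        )

    @classmethod
    def from_model_args(cls, args: MyModelArgs) -> "MyTransformer":
        return cls(args)


def my_parallelize(model, world_mesh, parallel_dims, job_config):
    # Placeholder: apply SPMD-style parallelism here; the signature mirrors
    # how train.py calls model_spec.parallelize_fn in this commit.
    return model


def my_pipeline(model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn):
    # Placeholder: split the model and return (pp_schedule, model_parts),
    # mirroring how train.py calls model_spec.pipelining_fn.
    raise NotImplementedError


def build_model_spec() -> ModelSpec:
    return ModelSpec(
        name="model_x",
        cls=MyTransformer,
        config={"debugmodel": MyModelArgs()},
        tokenizer="tiktoken",
        parallelize_fn=my_parallelize,
        pipelining_fn=my_pipeline,
    )

Per the description, the intent is that pointing the experimental model-path option at such a package and setting `--model.name` to the spec's name routes the trainer through this spec instead of a built-in one.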
1 parent d4c86e3 commit 8e1cbbd

File tree

6 files changed: +157 -23 lines


torchtitan/config_manager.py (+23)
@@ -375,6 +375,20 @@ def __init__(self):
                 The default value is 'allgather'.
                 """,
         )
+        self.parser.add_argument(
+            "--experimental.custom_model_path",
+            type=str,
+            default="",
+            help="""
+                The --custom_model_path option allows to specify a custom path to a model module
+
+                that is not natively implemented within TorchTitan.
+
+                Acceptable values are the file system path to the module (e.g., my_models/model_x)
+
+                dotted import module (e.g., some_package.model_x).
+                """
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
@@ -638,6 +652,15 @@ def parse_args(self, args_list: list = sys.argv[1:]):
                 exp["pipeline_parallel_split_points"]
             )
 
+        if (
+            "experimental" in args_dict
+            and "model_module_path" in args_dict["experimental"]
+            and args_dict["experimental"]["model_module_path"]
+        ):
+            from torchtitan.models import add_model_spec_path
+
+            add_model_spec_path(args_dict["experimental"]["model_module_path"])
+
         # override args dict with cmd_args
         cmd_args_dict = self._args_to_two_level_dict(cmd_args)
         for section, section_args in cmd_args_dict.items():

torchtitan/model_spec.py (+27)
@@ -0,0 +1,27 @@
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Protocol, Tuple, Type
+
+import torch.nn as nn
+from torch.distributed.pipelining.schedules import _PipelineSchedule
+
+@dataclass
+class BaseModelArgs:
+    _enforced: str = "This field is used to enforce all fields have defaults."
+
+
+class ModelProtocol(Protocol):
+    def from_model_args(self, args: BaseModelArgs) -> nn.Module:
+        ...
+
+
+@dataclass
+class ModelSpec:
+    name: str
+    cls: Type[nn.Module]
+    config: Dict[str, BaseModelArgs]
+    # As for now, this is a string. So it will have to be built-in to the
+    # TorchTitan library. In the future, we can make this a defined class
+    # that can be extended like ModelSpec.
+    tokenizer: str
+    parallelize_fn: Callable[[nn.Module], None]
+    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
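One non-obvious detail above is the `_enforced` field on `BaseModelArgs`: because dataclass inheritance places the base class's defaulted field before any subclass fields, a subclass field declared without a default fails at class-definition time, which is how every model-args class is forced to give all of its fields defaults. A minimal sketch of that behavior (the class names are illustrative only):

from dataclasses import dataclass

from torchtitan.model_spec import BaseModelArgs


@dataclass
class GoodArgs(BaseModelArgs):
    dim: int = 4096  # accepted: the field has a default


try:

    @dataclass
    class BadArgs(BaseModelArgs):
        dim: int  # rejected: no default after the inherited defaulted field

except TypeError as err:
    # dataclasses raises "non-default argument 'dim' follows default argument"
    print(f"BadArgs rejected: {err}")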

torchtitan/models/__init__.py (+79 -8)
@@ -4,14 +4,85 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.models.llama import llama3_configs, Transformer
+import importlib
 
-models_config = {
-    "llama3": llama3_configs,
-}
+import os
+import pkgutil
+from typing import Dict, Set
 
-model_name_to_cls = {"llama3": Transformer}
+import torchtitan.models as models
+from torchtitan.model_spec import ModelSpec
 
-model_name_to_tokenizer = {
-    "llama3": "tiktoken",
-}
+
+_model_specs_path: Set[str] = set()
+
+
+def _load_module(path: str):
+    path = os.path.expanduser(path)
+
+    # 1. Check if path is an existing file or directory path.
+    if os.path.exists(path):
+        if os.path.isdir(path):
+            init_file = os.path.join(path, "__init__.py")
+            if os.path.isfile(init_file):
+                return _load_module_from_init(path)
+
+            raise ImportError(
+                f"Directory '{path}' is not a Python package because it does not "
+                "contain an __init__.py file."
+            )
+        else:
+            raise ImportError(f"Path '{path}' is not a directory.")
+
+    # 2. If not a valid path, assume it's a dotted module name.
+    return importlib.import_module(path)
+
+
+def _load_module_from_init(path: str):
+    module_name = os.path.basename(os.path.normpath(path))
+    init_file = os.path.join(path, "__init__.py")
+
+    spec = importlib.util.spec_from_file_location(module_name, init_file)
+    if spec is None:
+        raise ImportError(f"Could not create spec from '{init_file}'")
+
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+for _, name, _ in pkgutil.iter_modules(models.__path__):
+    full_module_name = f"{models.__name__}.{name}"
+    _model_specs_path.add(full_module_name)
+    # model_module = importlib.import_module(full_module_name)
+    # load_spec_from_module(model_module)
+
+
+def add_model_spec_path(path: str):
+    global _model_specs_path
+    _model_specs_path.add(path)
+
+
+def build_model_specs() -> Dict[str, ModelSpec]:
+    """
+    Load all model specs from the `models` package.
+    """
+    global _model_specs_path
+    model_specs = {}
+    for path in _model_specs_path:
+        module = _load_module(path)
+        model_spec = getattr(module, "model_spec", None)
+        if model_spec is not None:
+            model_specs[model_spec.name] = model_spec
+        # We would like to just use `model_spec` but current torchtitan parallelize
+        # functions depend on ModelArgs and can cause circular imports.
+        # As a result, we have to use `build_model_spec` as a workaround.
+        build_model_spec = getattr(module, "build_model_spec", None)
+        if build_model_spec:
+            model_spec = build_model_spec()
+            model_specs[model_spec.name] = model_spec
+
+    return model_specs
+
+
+__all__ = [add_model_spec_path, build_model_specs]
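Taken together, registration and lookup reduce to the two helpers exported above. A minimal sketch of the intended call pattern, assuming the hypothetical `my_models/model_x` package from the flag's help text; the built-in `llama3` entry comes from the llama package changes shown next:

from torchtitan.models import add_model_spec_path, build_model_specs

# Built-in specs are discovered from the torchtitan.models.* subpackages;
# out-of-tree packages are registered by filesystem path or dotted module name.
add_model_spec_path("my_models/model_x")  # hypothetical user package

model_specs = build_model_specs()
llama3 = model_specs["llama3"]
print(llama3.tokenizer)        # "tiktoken", per the llama spec below
print(sorted(model_specs))     # "llama3" plus any registered custom spec names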

torchtitan/models/llama/__init__.py (+16 -1)
@@ -6,9 +6,9 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
+from torchtitan.model_spec import ModelSpec
 from torchtitan.models.llama.model import ModelArgs, Transformer
 
-__all__ = ["Transformer"]
 
 llama3_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +40,18 @@
         rope_theta=500000,
     ),
 }
+
+
+def build_model_spec() -> ModelSpec:
+    # Avoid circular import
+    from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+    from torchtitan.parallelisms.pipeline_llama import pipeline_llama
+
+    return ModelSpec(
+        name="llama3",
+        cls=Transformer,
+        config=llama3_configs,
+        tokenizer="tiktoken",
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+    )
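Outside the trainer, the same spec can be exercised directly. A small sketch, assuming the existing `Transformer.from_model_args` constructor helper (the method `ModelProtocol` refers to) and the meta-device construction pattern used in train.py:

import torch

from torchtitan.models.llama import build_model_spec

spec = build_model_spec()
model_args = spec.config["debugmodel"]

# Mirror the trainer: build on the meta device, materialize weights later.
with torch.device("meta"):
    model = spec.cls.from_model_args(model_args)

print(spec.name, spec.tokenizer, model_args.dim)  # llama3 tiktoken 256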

torchtitan/models/llama/model.py (+3 -2)
@@ -13,11 +13,12 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torchtitan.model_spec import BaseModelArgs, ModelProtocol
 from torchtitan.models.norms import build_norm
 
 
 @dataclass
-class ModelArgs:
+class ModelArgs(BaseModelArgs):
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
@@ -258,7 +259,7 @@ def init_weights(self, init_std: float):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
-class TransformerBlock(nn.Module):
+class TransformerBlock(nn.Module, ModelProtocol):
     """
     TransformerBlock Module
 
train.py (+9 -12)
@@ -19,13 +19,9 @@
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
 from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
-from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtitan.models import build_model_specs
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
-from torchtitan.parallelisms import (
-    models_parallelize_fns,
-    models_pipelining_fns,
-    ParallelDims,
-)
+from torchtitan.parallelisms import ParallelDims
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 from torchtitan.utils import device_module, device_type
 
@@ -80,9 +76,10 @@ def main(job_config: JobConfig):
         world_mesh, device, job_config.training.seed, job_config.training.deterministic
     )
     model_name = job_config.model.name
+    model_spec = build_model_specs()[model_name]
 
     # build tokenizer
-    tokenizer_type = model_name_to_tokenizer[model_name]
+    tokenizer_type = model_spec.tokenizer
     tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
     # build dataloader
     data_loader = build_hf_data_loader(
@@ -96,8 +93,8 @@ def main(job_config: JobConfig):
     )
 
     # build model (using meta init)
-    model_cls = model_name_to_cls[model_name]
-    model_config = models_config[model_name][job_config.model.flavor]
+    model_cls = model_spec.cls
+    model_config = model_spec.config[job_config.model.flavor]
     # set the model configs from training inputs:
     # 1. norm type to decide which norm layer to use
     # 2. vocab size from tokenizer
@@ -151,7 +148,7 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        pp_schedule, model_parts = models_pipelining_fns[model_name](
+        pp_schedule, model_parts = model_spec.pipelining_fn(
            model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
        )
        # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -162,14 +159,14 @@ def loss_fn(pred, labels):
        # optimizer, and checkpointing
        for m in model_parts:
            # apply SPMD-style PT-D techniques
-            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
+            model_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
            m.to_empty(device=init_device)
            with torch.no_grad():
                m.init_weights(buffer_device=buffer_device)
            m.train()
    else:
        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
-        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+        model_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
        model.to_empty(device=init_device)
        with torch.no_grad():
            model.init_weights(buffer_device=buffer_device)
