Commit bcb4b5b

Allow users to use customized models
**What does this PR do?**

1. Introduces `ModelSpec` to describe a model and how to parallelize it.
2. Requires all models to define `build_model_spec()` or `model_spec`, which will be imported by the `model` module.
3. Calls `build_model_specs()` in the trainer to obtain `model_specs`, which are then used to retrieve the corresponding model spec.
4. Allows users to dynamically import a model not implemented by TorchTitan using `--experimental.model_module_path`.

**Why do we need this PR?**

This PR enables users to integrate new models with TorchTitan without making intrusive changes to the TorchTitan codebase.

**Next steps**

1. This PR includes only the model definitions, configurations, tokenizer, parallelize_fn, and pipelining_fn. We may want to extend `ModelSpec` to include the optimizer and learning rate scheduler.
2. The current TorchTitan parallelize_fn and pipelining_fn import ModelArgs, which can lead to circular imports. This issue needs to be addressed.

ghstack-source-id: 9ed1b54aa945af27ce0881ea02150c9e2f0022e8
Pull Request resolved: #814
1 parent 690f299 commit bcb4b5b

15 files changed · +458 −172 lines changed
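
For illustration, a user-supplied model package could register itself with TorchTitan roughly as in the sketch below. This is hypothetical, modeled on the `ModelSpec` API and the llama3 registration added in this diff; the names `my_models/model_x`, `MyModelArgs`, `MyTransformer`, and the two user-defined parallelize/pipeline functions are assumptions, not part of this commit.

```python
# my_models/model_x/__init__.py -- hypothetical custom model package
from torchtitan.model_spec import ModelSpec, register_model_spec
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

from .model import MyModelArgs, MyTransformer      # user-defined model + args
from .parallelize import parallelize_model_x       # user-defined parallelize_fn
from .pipeline import pipeline_model_x             # user-defined pipelining_fn

# Importing this package registers the spec, so the stock training loop can
# look the model up by name.
register_model_spec(
    ModelSpec(
        name="model_x",
        cls=MyTransformer,
        config={"base": MyModelArgs(dim=1024, n_layers=16)},
        tokenizer="tiktoken",
        parallelize_fn=parallelize_model_x,
        pipelining_fn=pipeline_model_x,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
```

The package would then be pointed at via the new flag added in `config_manager.py` below, e.g. `--experimental.custom_model_path my_models.model_x`, and selected with the existing `--model.name model_x` option.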

torchtitan/checkpoint.py

+2 −2
@@ -28,7 +28,7 @@
 from torch.utils.data import DataLoader
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import init_logger, logger
-from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
+from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer
 
 
 class IntervalType(enum.Enum):
@@ -140,7 +140,7 @@ def __init__(
         dataloader: DataLoader,
         model_parts: List[nn.Module],
         optimizers: OptimizersContainer,
-        lr_schedulers: SchedulersContainer,
+        lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
     ) -> None:

torchtitan/config_manager.py

+20
@@ -375,6 +375,26 @@ def __init__(self):
             The default value is 'allgather'.
             """,
         )
+        # I'm not particularly fond of this. Users can choose to write their own wrapper
+        # module, import the TorchTitan training loop, and execute it, which looks cleaner.
+        # One reason to provide this option is to allow users to use the existing run script.
+        # While the script is pretty trivial now, we may add more logic when integrating
+        # with TorchFT.
+        # This option is subject to change and may be deleted in the future.
+        self.parser.add_argument(
+            "--experimental.custom_model_path",
+            type=str,
+            default="",
+            help="""
+                The --custom_model_path option allows users to specify a custom path to a model
+                module that is not natively implemented within TorchTitan.
+                Acceptable values are a file system path to the module (e.g., my_models/model_x)
+                or a dotted import module (e.g., some_package.model_x).
+            """,
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
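
The option accepts either a file system path or a dotted module name. A minimal sketch of how such a module could be imported is below; the helper name `load_custom_model_module` is hypothetical and this is not the loader TorchTitan ships, but importing the module is what triggers its `register_model_spec()` call.

```python
import importlib
import importlib.util
import os
import sys


def load_custom_model_module(custom_model_path: str):
    """Hypothetical helper: import a module from a file system path or a dotted name."""
    if os.path.exists(custom_model_path):
        # File system path, e.g. "my_models/model_x": load the package from its __init__.py.
        name = os.path.basename(os.path.normpath(custom_model_path))
        init_file = os.path.join(custom_model_path, "__init__.py")
        spec = importlib.util.spec_from_file_location(
            name, init_file, submodule_search_locations=[custom_model_path]
        )
        module = importlib.util.module_from_spec(spec)
        sys.modules[name] = module  # so relative imports inside the package resolve
        spec.loader.exec_module(module)
        return module
    # Dotted module name, e.g. "some_package.model_x".
    return importlib.import_module(custom_model_path)
```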

torchtitan/model_spec.py

+77
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Protocol, Tuple, Type
+
+import torch.nn as nn
+from torch.distributed.pipelining.schedules import _PipelineSchedule
+from torchtitan.config_manager import JobConfig
+from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer
+
+
+@dataclass
+class BaseModelArgs:
+    """All ModelArgs should inherit from this class.
+
+    The only usage of this class is type checking, but it allows us to extend
+    common arguments to all models in the future.
+    """
+
+    _enforced: str = "This field is used to enforce all fields have defaults."
+
+
+class ModelProtocol(Protocol):
+    """Defines the interface for a model class.
+
+    This is used to enforce that all model classes have some methods that are
+    required by the TorchTitan trainer.
+    """
+
+    @staticmethod
+    def from_model_args(args: BaseModelArgs) -> nn.Module: ...
+
+
+@dataclass
+class ModelSpec:
+    name: str
+    cls: Type[nn.Module]
+    config: Dict[str, BaseModelArgs]
+    # TODO: Add a ``build_dataloader_fn``.
+    # For now, this is a string, so it will have to be built into the
+    # TorchTitan library. A better way would be to have a dataloader class
+    # and a ``build_dataloader`` function that take job_config to consume
+    # the different dataloader and tokenizer configs.
+    tokenizer: str
+    parallelize_fn: Callable[[nn.Module], None]
+    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
+    build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer]
+    build_lr_schedulers_fn: Callable[
+        [List[nn.Module], JobConfig], LRSchedulersContainer
+    ]
+
+    # TODO: Add an FQN conversion fn to allow users to load checkpoints from
+    # HuggingFace or other sources that have different FQN conventions.
+
+
+_model_specs = {}
+
+
+def register_model_spec(model_spec: ModelSpec) -> None:
+    global _model_specs
+    if model_spec.name in _model_specs:
+        raise ValueError(f"Model {model_spec.name} is already registered.")
+    _model_specs[model_spec.name] = model_spec
+
+
+def get_model_spec(name: str) -> ModelSpec:
+    global _model_specs
+    if name not in _model_specs:
+        raise ValueError(f"Model {name} is not registered.")
+    return _model_specs[name]
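
Once a model is registered, the trainer can resolve everything it needs from the registry. The following is a hypothetical trainer-side sketch; the exact call sites in TorchTitan's train.py are not part of this diff, and `job_config.model.name` / `job_config.model.flavor` are the existing config fields assumed here.

```python
from torchtitan.model_spec import get_model_spec


def build_from_spec(job_config):
    # Hypothetical trainer-side sketch: resolve the model and its builders from the spec.
    model_spec = get_model_spec(job_config.model.name)        # e.g. "llama3"
    model_args = model_spec.config[job_config.model.flavor]   # e.g. "debugmodel"

    model = model_spec.cls.from_model_args(model_args)
    optimizers = model_spec.build_optimizers_fn([model], job_config)
    lr_schedulers = model_spec.build_lr_schedulers_fn([model], job_config)
    return model, optimizers, lr_schedulers
```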

torchtitan/models/__init__.py

+3 −10
@@ -4,14 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.models.llama import llama3_configs, Transformer
 
-models_config = {
-    "llama3": llama3_configs,
-}
-
-model_name_to_cls = {"llama3": Transformer}
-
-model_name_to_tokenizer = {
-    "llama3": "tiktoken",
-}
+# Import the built-in models here so that the corresponding register_model_spec()
+# will be called.
+import torchtitan.models.llama  # noqa

torchtitan/models/llama/__init__.py

+21 −1
@@ -6,9 +6,15 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
+from torchtitan.model_spec import ModelSpec, register_model_spec
 from torchtitan.models.llama.model import ModelArgs, Transformer
+from torchtitan.optimizer import build_lr_schedulers, build_optimizers
+
+from .parallelize_llama import parallelize_llama
+from .pipeline_llama import pipeline_llama
+
+__all__ = ["parallelize_llama", "pipeline_llama", "ModelArgs", "Transformer"]
 
-__all__ = ["Transformer"]
 
 llama3_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +46,17 @@
         rope_theta=500000,
     ),
 }
+
+
+register_model_spec(
+    ModelSpec(
+        name="llama3",
+        cls=Transformer,
+        config=llama3_configs,
+        tokenizer="tiktoken",
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+    )
+)

torchtitan/models/llama/model.py

+3 −2
@@ -13,11 +13,12 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torchtitan.model_spec import BaseModelArgs, ModelProtocol
 from torchtitan.models.norms import build_norm
 
 
 @dataclass
-class ModelArgs:
+class ModelArgs(BaseModelArgs):
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
@@ -258,7 +259,7 @@ def init_weights(self, init_std: float):
         nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
-class TransformerBlock(nn.Module):
+class TransformerBlock(nn.Module, ModelProtocol):
     """
     TransformerBlock Module

torchtitan/parallelisms/parallelize_llama.py → torchtitan/models/llama/parallelize_llama.py

+1 −1
@@ -33,7 +33,7 @@
 
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
-from torchtitan.parallelisms.parallel_dims import ParallelDims
+from torchtitan.parallelisms import ParallelDims
 
 
 def parallelize_llama(

torchtitan/parallelisms/pipeline_llama.py → torchtitan/models/llama/pipeline_llama.py

+16 −7
@@ -14,16 +14,19 @@
 from torch.distributed import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
 
+from torch.distributed.pipelining.schedules import _PipelineSchedule
+
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
-from torchtitan.models.llama.model import ModelArgs
-from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.pipelining_utils import (
+from torchtitan.parallelisms import (
     build_pipeline_schedule,
     generate_split_points,
+    ParallelDims,
     stage_ids_this_rank,
 )
 
+from .model import ModelArgs
+
 
 DeviceType = Union[int, str, torch.device]
 
@@ -36,7 +39,7 @@ def pipeline_llama(
     device: DeviceType,
     model_config: ModelArgs,
     loss_fn: Callable[..., torch.Tensor],
-):
+) -> tuple[_PipelineSchedule, list[nn.Module]]:
     stages, models = pipeline_llama_manual_split(
         model, pp_mesh, parallel_dims, job_config, device, model_config
     )
@@ -53,7 +56,7 @@ def pipeline_llama_manual_split(
     job_config: JobConfig,
     device: DeviceType,
     model_config: ModelArgs,
-):
+) -> tuple[list[PipelineStage], list[nn.Module]]:
     """
     This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
 
@@ -67,10 +70,16 @@ def pipeline_llama_manual_split(
 
     splits = (
         job_config.experimental.pipeline_parallel_split_points
-        or generate_split_points(job_config, parallel_dims.pp, model_config)
+        or generate_split_points(job_config, parallel_dims.pp, model_config.n_layers)
     )
 
-    def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
+    def _build_stage(
+        stage_idx: int,
+        start_layer: Optional[str],
+        stop_layer: Optional[str],
+        is_first: bool = False,
+        is_last: bool = False,
+    ) -> tuple[PipelineStage, nn.Module]:
         model = copy.deepcopy(whole_model)
         if not is_first:
             model.tok_embeddings = None
