-
Notifications
You must be signed in to change notification settings - Fork 338
/
Copy pathtrain_spec.py
115 lines (88 loc) · 3.56 KB
/
train_spec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
from abc import abstractmethod
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import Protocol, TypeAlias
import torch
import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.ft import FTManager
from torchtitan.components.loss import LossFunction
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.metrics import MetricsProcessor
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
DeviceType = int | str | torch.device
@dataclass
class BaseModelArgs:
    """Common base class that every model-args dataclass should extend.

    Right now it is used only for type checking, but it also gives a single
    place where arguments shared by all models can be added in the future.

    Note:
        This class does not derive from ``abc.ABC``, so the
        ``@abstractmethod`` markers below document which methods subclasses
        are expected to override but do not, on their own, block
        instantiation.
    """

    # Placeholder field carrying a default. Since dataclass fields without a
    # default may not follow one that has a default, subclasses end up having
    # to give defaults to the (positional) fields they add.
    _enforced: str = "This field is used to enforce all fields have defaults."

    @abstractmethod
    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """Hook for subclasses to update their fields; the base version is a no-op.

        Args:
            job_config: The job configuration handed to the hook.
            tokenizer: The tokenizer handed to the hook.
        """

    @abstractmethod
    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        """Hook for subclasses to report model size and compute cost.

        Going by the name and return annotation, overrides are expected to
        return ``(number of parameters, FLOPs)`` for ``model`` at ``seq_len``.
        The base version does nothing and returns ``None``.

        Args:
            model: The module to inspect.
            seq_len: The sequence length to evaluate at.
        """
class ModelProtocol(Protocol):
    """Structural interface expected of a model class.

    It exists to make sure every model class provides the methods that the
    TorchTitan trainer requires.
    """

    @classmethod
    def from_model_args(cls, args: BaseModelArgs) -> nn.Module:
        """Alternate constructor: produce an ``nn.Module`` from ``args``."""
        ...
# --- Model-level callables (open signatures) --------------------------------

# Accepts any arguments and returns an ``nn.Module``.
ParallelizeFunction: TypeAlias = Callable[..., nn.Module]

# Accepts any arguments and returns a 4-tuple: the pipeline schedule, a list
# of modules, and two booleans.
# NOTE(review): the meaning of the two booleans is not visible in this file --
# confirm against the pipelining implementations.
PipeliningFunction: TypeAlias = Callable[
    ..., tuple[_PipelineSchedule, list[nn.Module], bool, bool]
]

# --- Component builders with open signatures --------------------------------

DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
TokenizerBuilder: TypeAlias = Callable[..., Tokenizer]
MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor]
LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]

# --- Component builders with fixed signatures -------------------------------

# (list of modules, job config, fault-tolerance manager) -> optimizers container
OptimizersBuilder: TypeAlias = Callable[
    [list[nn.Module], JobConfig, FTManager], OptimizersContainer
]

# (optimizers container, job config) -> LR schedulers container
LRSchedulersBuilder: TypeAlias = Callable[
    [OptimizersContainer, JobConfig], LRSchedulersContainer
]
@dataclass
class TrainSpec:
    """Bundle of the model class, configs, and builder callables for one model.

    Instances are stored in this module's registry under ``name`` by
    ``register_train_spec`` and fetched back with ``get_train_spec``.
    """

    # Identifier of the spec; used as the registry key.
    name: str
    # The model class itself (a subclass of ``nn.Module``).
    cls: type[nn.Module]
    # Model-args objects addressed by a string key.
    # NOTE(review): presumably keyed by model flavor/size name -- confirm.
    config: Mapping[str, BaseModelArgs]
    # Callable producing the parallelized ``nn.Module``.
    parallelize_fn: ParallelizeFunction
    # Callable producing the pipeline schedule and modules; may be None.
    pipelining_fn: PipeliningFunction | None
    # Builders for the individual training components.
    build_optimizers_fn: OptimizersBuilder
    build_lr_schedulers_fn: LRSchedulersBuilder
    build_dataloader_fn: DataLoaderBuilder
    # May be None.
    build_tokenizer_fn: TokenizerBuilder | None
    build_loss_fn: LossFunctionBuilder
    # Optional; the only field with a default (None).
    build_metrics_processor_fn: MetricsProcessorBuilder | None = None
# Module-level registry of train specs, keyed by the name each spec was
# registered under. Read and written by the helper functions below.
_train_specs: dict[str, TrainSpec] = {}
def register_train_spec(train_spec: TrainSpec) -> None:
    """Store ``train_spec`` in the module-level registry under its ``name``.

    Args:
        train_spec: The spec to register.

    Raises:
        ValueError: If a spec with the same name is already registered.
    """
    # The registry dict is only mutated, never rebound, so no ``global``
    # declaration is needed.
    key = train_spec.name
    if key in _train_specs:
        raise ValueError(f"Model {train_spec.name} is already registered.")
    _train_specs[key] = train_spec
def get_train_spec(name: str) -> TrainSpec:
    """Return the spec registered under ``name``.

    Args:
        name: Registry key to look up.

    Returns:
        The ``TrainSpec`` stored under ``name``.

    Raises:
        ValueError: If nothing is registered under ``name``.
    """
    if name in _train_specs:
        return _train_specs[name]
    raise ValueError(f"Model {name} is not registered.")
def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None:
    """Replace every registered spec with the result of ``func`` applied to it.

    Entries are rewritten one at a time under their existing keys, so if
    ``func`` raises part-way through, the specs already processed stay
    replaced.

    Args:
        func: Transformation mapping a registered spec to its replacement.
    """
    # Assigning to keys that already exist does not resize the dict, so
    # updating it while iterating over it is safe.
    for key in _train_specs:
        _train_specs[key] = func(_train_specs[key])