
Commit 159f03b

Introducing a generic ModelHandler interface.
This interface should cover most use cases, such as quantization, fused layer optimization, ...
1 parent 690f299 commit 159f03b

File tree

10 files changed: +127 -26 lines


docs/float8.md

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ Launch training job with the following command (or alternatively set configs in
 ```
 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
 ```
-* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
+<!-- * `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. -->
 * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
 * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.

torchtitan/config_manager.py

Lines changed: 15 additions & 9 deletions

@@ -182,6 +182,12 @@ def __init__(self):
             default="./torchtitan/datasets/tokenizer/tokenizer.model",
             help="Tokenizer path",
         )
+        self.parser.add_argument(
+            "--model.handlers",
+            type=str,
+            default="",
+            help="Comma separated list of handlers to apply to the model (e.g. 'float8')",
+        )
 
         # optimizer configs
         self.parser.add_argument(
@@ -529,15 +535,15 @@ def __init__(self):
         )
 
         # float8 configs
-        self.parser.add_argument(
-            "--float8.enable_float8_linear",
-            action="store_true",
-            help="""
-            If true, swaps `torch.nn.Linear` with `Float8Linear`.
-            This feature requires you to install 'torchao' which can be found
-            here: https://github.com/pytorch/ao
-            """,
-        )
+        # self.parser.add_argument(
+        #     "--float8.enable_float8_linear",
+        #     action="store_true",
+        #     help="""
+        #     If true, swaps `torch.nn.Linear` with `Float8Linear`.
+        #     This feature requires you to install 'torchao' which can be found
+        #     here: https://github.com/pytorch/ao
+        #     """,
+        # )
         self.parser.add_argument(
             "--float8.enable_fsdp_float8_all_gather",
             action="store_true",

torchtitan/float8.py

Lines changed: 11 additions & 3 deletions

@@ -20,6 +20,7 @@
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
+from torchtitan.model_handler import ModelHandler, register_model_handler
 from torchtitan.parallelisms import ParallelDims
 
 
@@ -28,13 +29,11 @@ def _is_sm89_or_later():
     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-class Float8Handler:
+class Float8Handler(ModelHandler):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False
 
         float8_config = job_config.float8
-        if not float8_config.enable_float8_linear:
-            return
         if not _is_sm89_or_later():
             logger.warning(
                 "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
@@ -66,6 +65,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         logger.info("Float8 training active")
 
+    def convert(self, model: nn.Module):
+        return self.convert_to_float8_training(model)
+
+    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
+        return self.precompute_float8_dynamic_scale_for_fsdp(model)
+
     def convert_to_float8_training(self, model: nn.Module):
         """
         This function converts the linear layers of `model` to `Float8Linear`.
@@ -102,3 +107,6 @@ def precompute_float8_dynamic_scale_for_fsdp(
         models = [model] if isinstance(model, nn.Module) else model
         for m in models:
             precompute_float8_dynamic_scale_for_fsdp(m)
+
+
+register_model_handler(Float8Handler, "float8")

torchtitan/model_handler.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict, List, Protocol, Union
+
+import torch.nn as nn
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.parallelisms import ParallelDims
+
+
+class ModelHandler(Protocol):
+    """General model handler interface.
+
+    A model handler applies a modification to a PyTorch model.
+    Typical use cases are:
+    - Quantization: using QAT, FP8, ... specialized linear layers;
+    - Fused optimized layers (e.g. flash-attention, norms, ...)
+    """
+
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        ...
+
+    def convert(self, model: nn.Module):
+        """In-place conversion of the model."""
+        ...
+
+    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
+        """Optional post-optimizer hook (e.g. compute weight statistics)."""
+        ...
+
+
+_registry_model_handler_cls: Dict[str, type[ModelHandler]] = {}
+"""Registry of model handler classes.
+"""
+
+
+def register_model_handler(handler_cls: type[ModelHandler], name: str):
+    """Register a model handler class.
+
+    A registered model handler can be applied to any TorchTitan model
+    using the `model.handlers` config parameter.
+    """
+    assert (
+        name not in _registry_model_handler_cls
+    ), f"A TorchTitan model handler '{name}' is already registered."
+    _registry_model_handler_cls[name] = handler_cls
+
+
+class ModelHandlersContainer(ModelHandler):
+    """Sequential container of model handlers.
+
+    This class builds the sequence of model handlers defined in the
+    `model.handlers` job config and applies them to the model sequentially.
+    """
+
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        handler_names = parse_model_handlers(job_config)
+        handler_classes = [_registry_model_handler_cls[name] for name in handler_names]
+        self.handlers = [
+            mh_cls(job_config, parallel_dims) for mh_cls in handler_classes
+        ]
+
+    def convert(self, model: nn.Module):
+        for mh in self.handlers:
+            mh.convert(model)
+
+    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
+        for mh in self.handlers:
+            mh.post_optimizer_hook(model)
+
+
+def parse_model_handlers(job_config: JobConfig) -> List[str]:
+    """Parse the list of model handlers to apply."""
+    handler_names = [v.strip() for v in job_config.model.handlers.split(",")]
+    handler_names = [v for v in handler_names if len(v) > 0]
+    return handler_names
+
+
+def build_model_handlers_container(
+    job_config: JobConfig, parallel_dims: ParallelDims
+) -> ModelHandlersContainer:
+    """Build the collection of model handlers to apply to the model."""
+    return ModelHandlersContainer(job_config, parallel_dims)
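
To illustrate the extension point, here is a hedged sketch of a hypothetical custom handler wired through the new registry and container. `LayerFreezeHandler` and the `layer_freeze` name are invented for this example, and a `SimpleNamespace` stands in for the real `JobConfig`/`ParallelDims`; only the interfaces added in this commit are assumed.

```
from types import SimpleNamespace
from typing import List, Union

import torch.nn as nn

from torchtitan.model_handler import (
    ModelHandler,
    ModelHandlersContainer,
    register_model_handler,
)


class LayerFreezeHandler(ModelHandler):
    """Hypothetical handler: freezes embedding parameters after model creation."""

    def __init__(self, job_config, parallel_dims):
        self.enabled = True

    def convert(self, model: nn.Module):
        # In-place modification of the model, as the protocol expects.
        for name, param in model.named_parameters():
            if "embed" in name:
                param.requires_grad_(False)

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        # Nothing to do after the optimizer step for this handler.
        pass


# Registration happens at import time of the defining module, just as
# float8.py does with register_model_handler(Float8Handler, "float8").
register_model_handler(LayerFreezeHandler, "layer_freeze")

# Stand-in config: the container only reads `job_config.model.handlers`.
job_config = SimpleNamespace(model=SimpleNamespace(handlers="layer_freeze"))
handlers = ModelHandlersContainer(job_config, parallel_dims=None)

model = nn.Sequential()
model.add_module("tok_embeddings", nn.Embedding(16, 8))
handlers.convert(model)  # applies every listed handler in order
print(model.tok_embeddings.weight.requires_grad)  # False
```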

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 3 additions & 2 deletions

@@ -34,7 +34,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-
+from torchtitan.model_handler import parse_model_handlers
 
 def parallelize_llama(
     model: nn.Module,
@@ -56,11 +56,12 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
+        enable_float8 = "float8" in parse_model_handlers(job_config)
         apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=job_config.float8.enable_float8_linear,
+            enable_float8=enable_float8,
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )
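
With the `enable_float8_linear` flag gone, whether float8 is active for the TP plan is derived from the handler list. A hedged sketch of that gating, with a `SimpleNamespace` standing in for the real `JobConfig`:

```
from types import SimpleNamespace

from torchtitan.model_handler import parse_model_handlers

# Stand-in for JobConfig; parse_model_handlers only reads `model.handlers`.
job_config = SimpleNamespace(model=SimpleNamespace(handlers="float8"))

enable_float8 = "float8" in parse_model_handlers(job_config)
print(enable_float8)  # True; False whenever "float8" is not listed
```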

train.py

Lines changed: 7 additions & 7 deletions

@@ -16,9 +16,9 @@
 from torchtitan.checkpoint import CheckpointManager, TrainState
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_hf_data_loader, build_tokenizer
-from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
 from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
+from torchtitan.model_handler import build_model_handlers_container
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import (
@@ -110,10 +110,9 @@ def main(job_config: JobConfig):
     with torch.device("meta"):
         model = model_cls.from_model_args(model_config)
 
-    # a no-op handler if float8 is not enabled
-    float8_handler = Float8Handler(job_config, parallel_dims)
-    # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(model)
+    # Build the collection of model handlers. No-op if `model.handlers` is empty.
+    model_handlers = build_model_handlers_container(job_config, parallel_dims)
+    model_handlers.convert(model)
 
     # log model size
     model_param_count = utils.get_num_params(model)
@@ -326,9 +325,10 @@ def loss_fn(pred, labels):
            optimizers.step()
            lr_schedulers.step()

-           # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+           # Post-optimizer model handlers hook,
+           # e.g. calculate float8 dynamic amax/scale for all parameters for FSDP2;
            # it issues a single all-reduce for all parameters at once for better performance
-           float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
+           model_handlers.post_optimizer_hook(model_parts)

            # log metrics
            if (
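
The container is a cheap no-op when `model.handlers` is empty, which is what lets the new `convert` and `post_optimizer_hook` calls in train.py stay unconditional. A hedged sketch (stand-in `SimpleNamespace` config; `parallel_dims` is unused here):

```
from types import SimpleNamespace

import torch.nn as nn

from torchtitan.model_handler import build_model_handlers_container

# Empty handler list -> the container holds no handlers.
job_config = SimpleNamespace(model=SimpleNamespace(handlers=""))
model_handlers = build_model_handlers_container(job_config, parallel_dims=None)
print(len(model_handlers.handlers))  # 0

model = nn.Linear(8, 8)
model_handlers.convert(model)                # iterates over an empty list: no-op
model_handlers.post_optimizer_hook([model])  # accepts a module or a list of modules
```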

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion

@@ -26,6 +26,7 @@ flavor = "debugmodel"
 norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./tests/assets/test_tiktoken.model"
+handlers = ""
 
 [optimizer]
 name = "AdamW"
@@ -62,4 +63,3 @@ mode = 'selective' # ['none', 'selective', 'full']
 selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]
-enable_float8_linear = false

train_configs/llama3_405b.toml

Lines changed: 1 addition & 1 deletion

@@ -20,6 +20,7 @@ name = "llama3"
 flavor = "405B"
 norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
+handlers = "float8"
 
 [optimizer]
 name = "AdamW"
@@ -55,6 +56,5 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 mode = 'full' # ['none', 'selective', 'full']
 
 [float8]
-enable_float8_linear = true
 enable_fsdp_float8_all_gather = true
 precompute_float8_dynamic_scale_for_fsdp = true
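
In the TOML configs, the new `handlers` key sits under the existing `[model]` table. A small sketch of how that value reads back, using `tomllib` (Python 3.11+) purely for illustration; the values mirror the llama3_405b config above:

```
import tomllib  # Python 3.11+; used here only to illustrate the config shape

config_text = """
[model]
name = "llama3"
flavor = "405B"
handlers = "float8"
"""

config = tomllib.loads(config_text)
print(config["model"]["handlers"])  # "float8"
```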

train_configs/llama3_70b.toml

Lines changed: 1 addition & 1 deletion

@@ -20,6 +20,7 @@ name = "llama3"
 flavor = "70B"
 norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
+handlers = ""
 
 [optimizer]
 name = "AdamW"
@@ -54,4 +55,3 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 mode = 'full'
 
 [float8]
-enable_float8_linear = false

train_configs/llama3_8b.toml

Lines changed: 1 addition & 1 deletion

@@ -20,6 +20,7 @@ name = "llama3"
 flavor = "8B"
 norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
+handlers = ""
 
 [optimizer]
 name = "AdamW"
@@ -55,4 +56,3 @@ mode = 'selective' # ['none', 'selective', 'full']
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]
-enable_float8_linear = false
