
Commit ae425b8

Update
[ghstack-poisoned]
2 parents: 955e656 + dc089a5

13 files changed: +133, -122 lines

.gitignore (+1)

@@ -5,6 +5,7 @@ __pycache__
 build
 outputs
 dist/*
+.vscode
 
 # data
 data

README.md (+1, -1)

@@ -52,7 +52,7 @@ Our guiding principles when building `torchtitan`:
 - estimate FSDP/HSDP memory usage without materializing the model
 - run distributed inference with Tensor Parallel
 
-We report our [Performance](docs/performance.md) verified on 64/128 GPUs.
+We report [performance](docs/performance.md) on up to 512 GPUs, and verify [loss converging](docs/converging.md) correctness of various techniques.
 
 ### Dive into the code
 

(binary image file, name not shown in this view; -110 KB)

assets/images/llama3_loss_curves.png
(binary image file; -247 KB, contents not shown)

assets/images/loss_curves.png
(binary image file; 189 KB)

docs/converging.md (+13, -15)

@@ -35,27 +35,25 @@ If the technique is not a parallelism
 
 ## Example
 
-Setup
+This is a series of loss-converging tests covering both parallelisms and training optimizations.
+Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `torchtitan`.
+
+### Setup
 - Base config: [train_configs/llama3_8b.toml](../train_configs/llama3_8b.toml)
-- `training.batch_size = 4`
-- a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
+- `training.batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
+- `training.data_parallel_shard_degree = 8`, resulting in global batch size 32
 - `training.steps = 3000`, `training.warmup_steps = 600`
 
-Remarks
-- This is an example series of loss-converging tests on parallelisms, not including training optimizations.
-- The default global batch size in the toml config is 64 (DP 64, local batch size 1). Given hardware resource availability, one can consider using a smaller (or larger) N. This will result in smaller (or larger, respectively) global batch size, which would possibly necessitate a smaller (or larger, respectively) learning rate (`optimizer.lr`) to keep training stability.
-
-
-| Parallelisms Dimension <br> (N = 16 by default) | Setup | Remarks |
+| Parallelism | Techniques | Remarks |
 | ----- | ----- | ----- |
-| 1D (N GPUs) | FSDP N | the 1D control set |
-| 3D (4N GPUs) | FSDP N, TP2, PP 2 | 3D test set |
-| 4D (8N GPUs) | FSDP N, TP 2, CP 2, PP 2 | 4D test set |
-| 2D (MN GPUs) <br> e.g. M=8 | FSDP N, CP M | to verify CP with a larger degree |
-
+| FSDP 8 | default | 1D control set |
+| FSDP 8, TP 2, PP 2 | torch.compile, Float8, async TP | 3D test set |
+| FSDP 8, TP 2, CP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 4D test set |
+| FSDP 8, CP 8 | default | to verify CP with a larger degree |
 
 ### Test results
-(TBA)
+![image](../assets/images/loss_curves.png)
+
 
 [^1]: Model initialization in a sharded setting can hardly match that in a single-device setting (or a differently sharded setting), because each time a random operator is called, the underlying RNG state offset is advanced by a quantized amount, often not aligned with the amount of randomness needed, thus “wasting” different amount of randomness on differently sharded settings.
 
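
As an editorial aside on the footnote above (not part of this diff): the mismatch can be reproduced even on CPU, because drawing one large tensor versus several smaller shards from the same seed consumes the generator state differently. A minimal sketch:

```python
import torch

# Single-device-style init: one RNG call for the full 4x4 weight.
torch.manual_seed(0)
full = torch.randn(4, 4)

# Sharded-style init: the same weight built from two 2x4 shards, same seed.
torch.manual_seed(0)
sharded = torch.cat([torch.randn(2, 4), torch.randn(2, 4)])

# The two results generally do not match bit for bit, because the generator
# state is consumed in different chunks; on CUDA the Philox offset is
# additionally advanced in quantized increments, as the footnote describes.
print(torch.equal(full, sharded))
```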

docs/performance.md (+57, -24)

@@ -1,39 +1,72 @@
-To demonstrate the effectiveness of PyTorch distributed training techniques used in torchtitan, we report both the infra metrics and loss curves of Llama 3 (8B and 70B) training on 64 A100 (80GB memory) GPUs and Llama 3.1 (405B) on 128 H100 (94GB memory).
-We report infra metrics achieved by [FSDP2](fsdp.md) (1D parallelism) under various configurations, and loss curves for both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel) training. (We only report 2D for 405B)
+We demonstrate the effectiveness of elastic distributed training using torchtitan, via experiments on Llama 3.1 8B, 70B, and 405B models, from 1D parallelism to 4D parallelism, at the scale from 8 GPUs to 512 GPUs.
 
+We ran our performance benchmarks on the [Grand Teton platform](https://engineering.fb.com/2022/10/18/open-source/ocp-summit-2022-grand-teton/), where
+- Each host has 8 NVIDIA H100 GPUs fully connected with NVLink.
+- Each H100 GPU is equipped with 96GB HBM2e with 2.4 TB/sec peak memory bandwidth.
+- Hosts are inter-connected with backend RDMA network with 400 Gb/s per GPU.
+- We used the default 500W power limit, although tuning it up to 700W TDP can potentially provide further speedups.
 
-## Llama 3.1 performance numbers
+We note that, throughout our experimentation, memory readings are stable across the whole training process[^1], whereas throughput numbers (TPS/GPU) are calculated and logged every 10 iterations, and always read at the (arbitrarily determined) 90th iteration.
 
-Below are the WPS (word per second, or more accurately, token per second) and MFU (model FLOPS utilization) results which torchtitan achieves on the 405B model released in [Llama 3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1). The way we compute WPS and MFU can be found in `train.py`. Because the model now is larger, we run on 128 H100 GPUs to test both performance and loss curves. Below is the performance result of 405B model with optimizations we have developed. We do see OOM for 1D parallelism (FSDP2), so we only tested 2D parallelism (FSDP2 + Tensor Parallel).
+We do not report Model FLOPS Utilization (MFU) because when Float8 is enabled (on `nn.Linear` modules), both BFLOAT16 Tensor Core and FP8 Tensor Core are involved in model training, but they have different peak FLOPS and the definition of MFU under such scenario is not well-defined. We note that the 1D Llama 3.1 8B model training on 8 or 128 H100 GPUs without Float8 achieves 33% to 39% MFU[^2] (with or without torch.compile, respectively).
 
-| Model size | Batch size | Activation checkpointing | WPS | MFU | Optimizations |
-| ----- | ----- | ----- | ----- | ----- | ----- |
-| 405B | 2 | full | 109 | 29.0%[^1] | None
-| 405B | 2 | full | 177 | 23.46%[^2] | Float8
-| 405B | 2 | full | 185 | 24.52% | Float8 + Async TP
+**Table 1** 1D Parallelism (FSDP). Llama 3.1 8B model. 8 GPUs. Local batch size 2, global batch size 16. Selective activation checkpointing.
 
-Here, we use local batch size 2 (global batch size = local batch size 2 * number of FSDP ranks 16 = 32).
+| Techniques | TPS/GPU | Memory(GiB) |
+| ----- | ----: | ----: |
+| FSDP | 5,762 | 82.4 |
+| FSDP + torch.compile | 6,667 | 77.0 |
+| FSDP + torch.compile + Float8 | 8,532 | 76.8 |
 
-Next, we show the loss curves, all models are trained 3000 steps on the [C4 dataset](https://huggingface.co/datasets/allenai/c4), with global batch size 32. We have to use full AC to save memory usage. The results are shown in the picture (a TensorBoard screenshot) below.
+**Table 2** FSDP + CP + torch.compile + Float8. Llama 3.1 8B model. 8 GPUs. Local batch size 1. Full activation checkpointing.
 
-![image](../assets/images/llama3_1_405B_loss_curves.png)
+| Parallelism | Sequence Length | TPS/GPU | Memory(GiB) |
+| ----- | ----: | ----: | ----: |
+| FSDP 8, CP 1 | 32768 | 3,890 | 83.9 |
+| FSDP 4, CP 2 | 65536 | 2,540 | 84.2 |
+| FSDP 2, CP 4 | 131072 | 1,071 | 84.0 |
+| FSDP 1, CP 8 | 262144 | 548 | 84.5 |
 
-## Llama 3 performance numbers
+**Table 3** 1D Parallelism (FSDP). Llama 3.1 8B model. 128 GPUs. Local batch size 2, global batch size 256. Selective activation checkpointing.
 
-Below are the WPS and MFU results which torchtitan achieves on Llama 3 models with FSDP2 on 64 A100 (80GB) GPUs.
+| Techniques | TPS/GPU | Memory(GiB) |
+| ----- | ----: | ----: |
+| FSDP | 5,605 | 67.0 |
+| FSDP + torch.compile | 6,514 | 62.0 |
+| FSDP + torch.compile + Float8 | 8,380 | 61.8 |
 
-| Model size | Batch size | Activation checkpointing | WPS | MFU |
-| ----- | ----- | ----- | ----- | ----- |
-| 8B | 1 | selective layer | 2904 | 56.8% |
-| 8B | 1 | selective op | 2973 | 58.2% |
-| 70B | 1 | full | 331 | 51.7% |
+**Table 4** 2D parallelism (FSDP + TP) + torch.compile + Float8. Llama 3.1 70B model. 256 GPUs (FSDP 32, TP 8). Local batch size 16, global batch size 512. Full activation checkpointing.
 
-We use local batch size 1 (global batch size = local batch size 1 * number of FSDP ranks 64 = 64), because it mimics the small local batch size in large scaled training, and moreoever allows us to compare 1D (FSDP) and 2D (FSDP + TP) training under the same global batch size on both 8B and 70B Llama 3 models, without the out-of-memory (OOM) issue.
+| Techniques | TPS/GPU | Memory(GiB) |
+| ----- | ----: | ----: |
+| 2D | 829 | 71.9 |
+| 2D + AsyncTP | 876 | 67.6 |
 
-Next we show the loss curves for Llama 3 8B and Llama 3 70B training with both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel). All four models are trained the same way as mentioned above with global batch size 64. In terms of activation checkpointing (AC) configs, the Llama 3 8B training jobs use selective op AC, whereas the Llama 3 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below.
+**Table 5** 3D parallelism (FSDP + TP + PP) + torch.compile + Float8 + AsyncTP. Llama 3.1 405B model. 512 GPUs (FSDP 8, TP 8, PP8). Local batch size 32, global batch size 256. Full activation checkpointing.
 
-![image](../assets/images/llama3_loss_curves.png)
+| Schedule | TPS/GPU | Memory(GiB) |
+| ----- | ----: | ----: |
+| 1F1B | 100 | 82.5 |
+| Interleaved 1F1B | 128 | 72.7 |
 
-[^1]: We used HBM2e based lower TDP SXM H100(95GB) for our test, the actual peak TFLOPs number is between SXM and NVL, and we don't know its exact value. So this MFU number is lower than actual MFU because we use the peak number of SXM directly.
+**Table 6** 4D parallelism (FSDP + TP + PP + CP) + torch.compile + Float8 + AsyncTP + 1F1B. Llama 3.1 405B model. 512 GPUs (TP 8, PP8). Local batch size 8. Full activation checkpointing.
 
-[^2]: Since for Float8, we are not converting all the matmuls to Float8 because our fused attention implementation is not done in Float8, so this number is lower than expected.
+| Parallelism | Sequence Length | TPS/GPU | Memory(GiB) |
+| ----- | ----: | ----: | ----: |
+| FSDP 8, CP 1 | 32768 | 76 | 75.3 |
+| FSDP 4, CP 2 | 65536 | 47 | 75.9 |
+| FSDP 2, CP 4 | 131072 | 31 | 77.1 |
+| FSDP 1, CP 8 | 262144 | 16 | 84.9 |
+
+
+#### Versions used for performance testing
+| repo | commit | date |
+| --- | --- | --- |
+| torch | [1963fc8](https://github.com/pytorch/pytorch/commit/1963fc83a1c32e162162e2414f78b043f0674bae) | 2024/12/23 |
+| torchao | [eab345c](https://github.com/pytorch/ao/commit/eab345c2268a7506355d506ebfc27b5d28e5e7d0) | 2024/12/23 |
+| torchtitan | [9dec370](https://github.com/pytorch/torchtitan/commit/9dec370ad26b5f8e9a7333a0e36165018262644b) | 2024/12/26 |
+
+
+[^1]: Different PP ranks can have different peak memory usages. We take the maximum across all GPUs.
+
+[^2]: In our test we used HBM2e-based SXM H100 with lower TDP, the actual peak TFLOPs number is between SXM and NVL, and we don't know its exact value. So this MFU number is lower than actual MFU because we use the peak number of SXM directly.
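
As an editorial aside (not part of this diff): the MFU figures quoted above follow the usual convention of achieved model FLOPS divided by hardware peak. A rough sketch of that arithmetic, assuming the common 6*params + 12*layers*heads*head_dim*seq_len per-token FLOPs approximation and an H100 SXM BF16 dense peak of 989 TFLOPS; this is not the exact accounting in torchtitan's `train.py`.

```python
# Rough MFU estimate from token throughput; illustrative assumptions only.
def estimate_mfu(
    tps_per_gpu: float,          # tokens/sec/GPU, e.g. a row from the tables above
    num_params: int,             # total parameter count
    num_layers: int,
    num_heads: int,
    head_dim: int,
    seq_len: int,
    peak_flops: float = 989e12,  # assumed H100 SXM BF16 dense peak
) -> float:
    # Dense matmuls contribute ~6*params FLOPs per trained token; attention adds
    # ~12*layers*heads*head_dim*seq_len on top.
    flops_per_token = 6 * num_params + 12 * num_layers * num_heads * head_dim * seq_len
    return tps_per_gpu * flops_per_token / peak_flops


# Example with Llama 3.1 8B-like shape assumptions (32 layers, 32 heads,
# head_dim 128, seq_len 8192) at the plain-FSDP throughput from Table 1:
print(f"{estimate_mfu(5762, 8_000_000_000, 32, 32, 128, 8192):.1%}")  # roughly 35%
```

With these assumptions the plain-FSDP row lands inside the 33% to 39% range quoted above, which is the only sanity check this sketch is meant to provide.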

scripts/estimate/estimation.py (+5, -3)

@@ -116,7 +116,11 @@ def loss_fn(pred, labels):
     model_config.vocab_size = tokenizer.n_words
     model_config.max_seq_len = job_config.training.seq_len
 
-    with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():
+    with (
+        FakeTensorMode()
+        if not job_config.memory_estimation.disable_fake_mode
+        else contextlib.nullcontext()
+    ):
 
         logger.info(
             f"Building {model_name} {job_config.model.flavor} with {model_config}"

@@ -174,8 +178,6 @@ def loss_fn(pred, labels):
             torch.nn.utils.clip_grad_norm_(
                 model.parameters(), job_config.training.max_norm, foreach=True
             )
-            # sync float8 amaxes and scales
-            float8_handler.sync_float8_amax_and_scale_history(model)
             # optimizer step
             optimizers.step()
             lr_schedulers.step()
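
For readers unfamiliar with the `FakeTensorMode` context being reformatted in the first hunk: it is what lets the memory estimator build and run the model without allocating real storage. A small standalone sketch (illustrative only; `disable_fake_mode` here is a stand-in for `job_config.memory_estimation.disable_fake_mode`):

```python
import contextlib

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

disable_fake_mode = False  # stand-in for job_config.memory_estimation.disable_fake_mode

# Same conditional context-manager pattern as the hunk above: fake tensors by
# default, real tensors only when fake mode is explicitly disabled.
with (
    FakeTensorMode()
    if not disable_fake_mode
    else contextlib.nullcontext()
):
    model = torch.nn.Linear(4096, 4096)   # parameters are fake: no real storage
    out = model(torch.randn(8, 4096))     # shapes and dtypes still propagate
    print(type(out).__name__, out.shape)  # FakeTensor torch.Size([8, 4096])
```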

torchtitan/checkpoint.py (+39, -1)

@@ -156,6 +156,41 @@ def __init__(
         if not self.enable_checkpoint and self.ft_manager is None:
             return
 
+<<<<<<< HEAD
+        1. even for simple PP schedules, there is a separate optimizer each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
+        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one stage can
+        restore its optimizer states, others will error.
+
+        The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
+        by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
+        requiring us to reason about multiple 'optim' objects locally.
+
+        We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
+        into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
+        which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
+        support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
+        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
+        optimizers do, so it's hard to write a generic 'flattener' utility.
+
+        TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
+            }
+        )
+=======
         self._initialize_states(
             states, dataloader, model_parts, optimizers, lr_schedulers
         )

@@ -166,6 +201,7 @@ def __init__(
         self.staging_id = None
         self.cpu_offload_state_dict = None
         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+>>>>>>> 3430d99 ([WIP][RFC] TorchFT integration)
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (

@@ -177,7 +213,9 @@ def __init__(
         self.begin_time = 0
         self.time_sync_work = None
         self.time_sync_result = None
-        self.pg = dist.new_group(backend="gloo")
+        async_mode = ckpt_config.async_mode.lower()
+        if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
+            self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
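
The docstring added in the first hunk describes why naively merged optimizer state dicts collide across pipeline stages. A small standalone illustration of that collision, using two hypothetical stages (not torchtitan code):

```python
import torch

# Two hypothetical pipeline stages, each owning one layer of the original model
# and each with its own optimizer, as in the docstring's PP example.
stage0 = torch.nn.Linear(16, 16)   # stands in for layers.0 on PP rank 0
stage1 = torch.nn.Linear(16, 16)   # stands in for layers.1 on PP rank 1
opt0 = torch.optim.AdamW(stage0.parameters())
opt1 = torch.optim.AdamW(stage1.parameters())

# Both optimizers index their parameters positionally from 0, so a naive merge
# of the two state dicts would write to the same keys and lose one of them.
print(opt0.state_dict()["param_groups"][0]["params"])  # [0, 1]
print(opt1.state_dict()["param_groups"][0]["params"])  # also [0, 1]
```

Keying optimizer state by parameter FQN instead, which is what the "flattening" in the docstring refers to (exposed in `torch.distributed.checkpoint.state_dict` via `StateDictOptions(flatten_optimizer_state_dict=True)`), removes this positional collision.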

torchtitan/config_manager.py (-19)

@@ -548,25 +548,6 @@ def __init__(self):
             action="store_true",
             help="Whether precompute float8 scales dynamically for FSDP",
         )
-        self.parser.add_argument(
-            "--float8.scaling_type_input",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-            choices=["dynamic", "delayed"],
-        )
-        self.parser.add_argument(
-            "--float8.scaling_type_weight",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-        )
-        self.parser.add_argument(
-            "--float8.scaling_type_grad_output",
-            type=str,
-            default="dynamic",
-            help="float8 scaling for input, dynamic (default) or delayed",
-        )
 
         # communications library settings
         self.parser.add_argument(

torchtitan/float8.py (+1, -44)

@@ -41,7 +41,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             )
             return
         try:
-            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
+            from torchao.float8 import Float8LinearConfig
         except ImportError as e:
             raise ImportError(
                 "torchao is not installed. Please install it to use float8 linear layers."

@@ -52,14 +52,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             parallel_dims.dp_shard_enabled
             and float8_config.enable_fsdp_float8_all_gather
         )
-        scaling_type_input = ScalingType(float8_config.scaling_type_input)
-        scaling_type_weight = ScalingType(float8_config.scaling_type_weight)
-        scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output)
         self.config = Float8LinearConfig(
             enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-            cast_config_input=CastConfig(scaling_type=scaling_type_input),
-            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
-            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
         )
 
         self.enabled = True

@@ -70,15 +64,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             and float8_config.precompute_float8_dynamic_scale_for_fsdp
         )
 
-        # for sync_float8_amax_and_scale_history
-        self.delayed_scaling = (
-            scaling_type_input is ScalingType.DELAYED
-            or scaling_type_weight is ScalingType.DELAYED
-            or scaling_type_grad_output is ScalingType.DELAYED
-        )
-        self._sync_float8_amax_and_scale_history = None
-        self.compile = job_config.training.compile
-
         logger.info("Float8 training active")
 
     def convert_to_float8_training(self, model: nn.Module):

@@ -117,31 +102,3 @@ def precompute_float8_dynamic_scale_for_fsdp(
         models = [model] if isinstance(model, nn.Module) else model
         for m in models:
             precompute_float8_dynamic_scale_for_fsdp(m)
-
-    def sync_float8_amax_and_scale_history(
-        self, model: Union[nn.Module, List[nn.Module]]
-    ):
-        if not self.enabled:
-            return
-
-        if not self.delayed_scaling:
-            return
-
-        from torchao.float8 import sync_float8_amax_and_scale_history
-
-        # TODO(vkuzo): see if precalculating the modules to sync over is going to
-        # meaningfully help performance
-
-        if self._sync_float8_amax_and_scale_history is None:
-            if self.compile:
-                self._sync_float8_amax_and_scale_history = torch.compile(
-                    sync_float8_amax_and_scale_history
-                )
-            else:
-                self._sync_float8_amax_and_scale_history = (
-                    sync_float8_amax_and_scale_history
-                )
-
-        models = [model] if isinstance(model, nn.Module) else model
-        for m in models:
-            self._sync_float8_amax_and_scale_history(m)
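
After this cleanup the handler supports only dynamic scaling, so the conversion reduces to building a `Float8LinearConfig` and swapping eligible `nn.Linear` modules. A hedged sketch of the underlying torchao call (assuming the `convert_to_float8_training` and `module_filter_fn` API of recent torchao releases; this is not the handler's exact code):

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 256))

# Dynamic scaling is the torchao default, so no CastConfig/ScalingType plumbing
# is needed anymore; only the FSDP all-gather knob is passed through, mirroring
# the simplified Float8LinearConfig construction in the hunks above.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=False)

# Swap nn.Linear modules in place; the filter skips the final projection purely
# as an example of excluding particular layers from float8 conversion.
convert_to_float8_training(
    model,
    config=config,
    module_filter_fn=lambda mod, fqn: fqn != "2",
)
print(model)  # the selected Linear layers are now float8 training modules
```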
