
Commit d4c86e3

update performance and loss converging results
ghstack-source-id: c2f3041
Pull Request resolved: #800
1 parent 3278a52 commit d4c86e3

6 files changed (+71 lines, -40 lines)


README.md

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ Our guiding principles when building `torchtitan`:
 - estimate FSDP/HSDP memory usage without materializing the model
 - run distributed inference with Tensor Parallel

-We report our [Performance](docs/performance.md) verified on 64/128 GPUs.
+We report [performance](docs/performance.md) on up to 512 GPUs, and verify [loss converging](docs/converging.md) correctness of various techniques.

 ### Dive into the code

Binary file not shown. (-110 KB)

Binary file not shown. (-247 KB)

assets/images/loss_curves.png (189 KB, new image)

docs/converging.md

Lines changed: 13 additions & 15 deletions

@@ -35,27 +35,25 @@ If the technique is not a parallelism

 ## Example

-Setup
+This is a series of loss-converging tests covering both parallelisms and training optimizations.
+Results were obtained on 2025/01/21, with the latest `torch`, `torchao`, and `torchtitan`.
+
+### Setup
 - Base config: [train_configs/llama3_8b.toml](../train_configs/llama3_8b.toml)
-- `training.batch_size = 4`
-- a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
+- `training.batch_size = 4`, which is the minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
+- `training.data_parallel_shard_degree = 8`, resulting in a global batch size of 32
 - `training.steps = 3000`, `training.warmup_steps = 600`

-Remarks
-- This is an example series of loss-converging tests on parallelisms, not including training optimizations.
-- The default global batch size in the toml config is 64 (DP 64, local batch size 1). Given hardware resource availability, one can consider using a smaller (or larger) N. This will result in smaller (or larger, respectively) global batch size, which would possibly necessitate a smaller (or larger, respectively) learning rate (`optimizer.lr`) to keep training stability.
-
-
-| Parallelisms Dimension <br> (N = 16 by default) | Setup | Remarks |
+| Parallelism | Techniques | Remarks |
 | ----- | ----- | ----- |
-| 1D (N GPUs) | FSDP N | the 1D control set |
-| 3D (4N GPUs) | FSDP N, TP2, PP 2 | 3D test set |
-| 4D (8N GPUs) | FSDP N, TP 2, CP 2, PP 2 | 4D test set |
-| 2D (MN GPUs) <br> e.g. M=8 | FSDP N, CP M | to verify CP with a larger degree |
-
+| FSDP 8 | default | 1D control set |
+| FSDP 8, TP 2, PP 2 | torch.compile, Float8, async TP | 3D test set |
+| FSDP 8, TP 2, CP 2, PP 2 | torch.compile, Float8, async TP, Interleaved 1F1B | 4D test set |
+| FSDP 8, CP 8 | default | to verify CP with a larger degree |

 ### Test results
-(TBA)
+![image](../assets/images/loss_curves.png)
+

 [^1]: Model initialization in a sharded setting can hardly match that in a single-device setting (or a differently sharded setting), because each time a random operator is called, the underlying RNG state offset is advanced by a quantized amount, often not aligned with the amount of randomness needed, thus "wasting" a different amount of randomness in differently sharded settings.

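For reference, the `### Setup` options above map onto the TOML config roughly as follows. This is a minimal sketch, assuming the dotted option names (`training.*`) correspond to keys in the `[training]` table of the config file, that every other option keeps its value from `train_configs/llama3_8b.toml`, and that the per-test parallelism degrees (TP/CP/PP) from the table are configured separately.

```toml
# Sketch of the [training] overrides used for the loss-converging runs described above.
# Assumes all other options keep their defaults from train_configs/llama3_8b.toml;
# the TP/CP/PP degrees vary per test row and are not shown here.
[training]
batch_size = 4                  # local batch size; per the setup above, the minimum for PP 2 with Interleaved1F1B
data_parallel_shard_degree = 8  # FSDP degree 8 -> global batch size 4 * 8 = 32
steps = 3000
warmup_steps = 600
```
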
docs/performance.md

Lines changed: 57 additions & 24 deletions
@@ -1,39 +1,72 @@
-To demonstrate the effectiveness of PyTorch distributed training techniques used in torchtitan, we report both the infra metrics and loss curves of Llama 3 (8B and 70B) training on 64 A100 (80GB memory) GPUs and Llama 3.1 (405B) on 128 H100 (94GB memory).
-We report infra metrics achieved by [FSDP2](fsdp.md) (1D parallelism) under various configurations, and loss curves for both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel) training. (We only report 2D for 405B)
+We demonstrate the effectiveness of elastic distributed training using torchtitan, via experiments on the Llama 3.1 8B, 70B, and 405B models, from 1D parallelism to 4D parallelism, at scales from 8 GPUs to 512 GPUs.

+We ran our performance benchmarks on the [Grand Teton platform](https://engineering.fb.com/2022/10/18/open-source/ocp-summit-2022-grand-teton/), where
+- Each host has 8 NVIDIA H100 GPUs fully connected with NVLink.
+- Each H100 GPU is equipped with 96GB HBM2e with 2.4 TB/sec peak memory bandwidth.
+- Hosts are interconnected via a backend RDMA network with 400 Gb/s per GPU.
+- We used the default 500W power limit, although tuning it up to 700W TDP can potentially provide further speedups.

-## Llama 3.1 performance numbers
+We note that, throughout our experimentation, memory readings are stable across the whole training process[^1], whereas throughput numbers (TPS/GPU) are calculated and logged every 10 iterations, and always read at the (arbitrarily determined) 90th iteration.

-Below are the WPS (word per second, or more accurately, token per second) and MFU (model FLOPS utilization) results which torchtitan achieves on the 405B model released in [Llama 3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1). The way we compute WPS and MFU can be found in `train.py`. Because the model now is larger, we run on 128 H100 GPUs to test both performance and loss curves. Below is the performance result of 405B model with optimizations we have developed. We do see OOM for 1D parallelism (FSDP2), so we only tested 2D parallelism (FSDP2 + Tensor Parallel).
+We do not report Model FLOPS Utilization (MFU) because when Float8 is enabled (on `nn.Linear` modules), both BFLOAT16 and FP8 Tensor Cores are involved in model training, but they have different peak FLOPS, so MFU is not well-defined in such a scenario. We note that 1D Llama 3.1 8B model training on 8 or 128 H100 GPUs without Float8 achieves 33% to 39% MFU[^2] (without and with torch.compile, respectively).

-| Model size | Batch size | Activation checkpointing | WPS | MFU | Optimizations |
-| ----- | ----- | ----- | ----- | ----- | ----- |
-| 405B | 2 | full | 109 | 29.0%[^1] | None
-| 405B | 2 | full | 177 | 23.46%[^2] | Float8
-| 405B | 2 | full | 185 | 24.52% | Float8 + Async TP
+**Table 1** 1D parallelism (FSDP). Llama 3.1 8B model. 8 GPUs. Local batch size 2, global batch size 16. Selective activation checkpointing.

-Here, we use local batch size 2 (global batch size = local batch size 2 * number of FSDP ranks 16 = 32).
+| Techniques | TPS/GPU | Memory (GiB) |
+| ----- | ----: | ----: |
+| FSDP | 5,762 | 82.4 |
+| FSDP + torch.compile | 6,667 | 77.0 |
+| FSDP + torch.compile + Float8 | 8,532 | 76.8 |

-Next, we show the loss curves, all models are trained 3000 steps on the [C4 dataset](https://huggingface.co/datasets/allenai/c4), with global batch size 32. We have to use full AC to save memory usage. The results are shown in the picture (a TensorBoard screenshot) below.
+**Table 2** FSDP + CP + torch.compile + Float8. Llama 3.1 8B model. 8 GPUs. Local batch size 1. Full activation checkpointing.

-![image](../assets/images/llama3_1_405B_loss_curves.png)
+| Parallelism | Sequence Length | TPS/GPU | Memory (GiB) |
+| ----- | ----: | ----: | ----: |
+| FSDP 8, CP 1 | 32768 | 3,890 | 83.9 |
+| FSDP 4, CP 2 | 65536 | 2,540 | 84.2 |
+| FSDP 2, CP 4 | 131072 | 1,071 | 84.0 |
+| FSDP 1, CP 8 | 262144 | 548 | 84.5 |

-## Llama 3 performance numbers
+**Table 3** 1D parallelism (FSDP). Llama 3.1 8B model. 128 GPUs. Local batch size 2, global batch size 256. Selective activation checkpointing.

-Below are the WPS and MFU results which torchtitan achieves on Llama 3 models with FSDP2 on 64 A100 (80GB) GPUs.
+| Techniques | TPS/GPU | Memory (GiB) |
+| ----- | ----: | ----: |
+| FSDP | 5,605 | 67.0 |
+| FSDP + torch.compile | 6,514 | 62.0 |
+| FSDP + torch.compile + Float8 | 8,380 | 61.8 |

-| Model size | Batch size | Activation checkpointing | WPS | MFU |
-| ----- | ----- | ----- | ----- | ----- |
-| 8B | 1 | selective layer | 2904 | 56.8% |
-| 8B | 1 | selective op | 2973 | 58.2% |
-| 70B | 1 | full | 331 | 51.7% |
+**Table 4** 2D parallelism (FSDP + TP) + torch.compile + Float8. Llama 3.1 70B model. 256 GPUs (FSDP 32, TP 8). Local batch size 16, global batch size 512. Full activation checkpointing.

-We use local batch size 1 (global batch size = local batch size 1 * number of FSDP ranks 64 = 64), because it mimics the small local batch size in large scaled training, and moreoever allows us to compare 1D (FSDP) and 2D (FSDP + TP) training under the same global batch size on both 8B and 70B Llama 3 models, without the out-of-memory (OOM) issue.
+| Techniques | TPS/GPU | Memory (GiB) |
+| ----- | ----: | ----: |
+| 2D | 829 | 71.9 |
+| 2D + AsyncTP | 876 | 67.6 |

-Next we show the loss curves for Llama 3 8B and Llama 3 70B training with both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel). All four models are trained the same way as mentioned above with global batch size 64. In terms of activation checkpointing (AC) configs, the Llama 3 8B training jobs use selective op AC, whereas the Llama 3 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below.
+**Table 5** 3D parallelism (FSDP + TP + PP) + torch.compile + Float8 + AsyncTP. Llama 3.1 405B model. 512 GPUs (FSDP 8, TP 8, PP 8). Local batch size 32, global batch size 256. Full activation checkpointing.

-![image](../assets/images/llama3_loss_curves.png)
+| Schedule | TPS/GPU | Memory (GiB) |
+| ----- | ----: | ----: |
+| 1F1B | 100 | 82.5 |
+| Interleaved 1F1B | 128 | 72.7 |

-[^1]: We used HBM2e based lower TDP SXM H100(95GB) for our test, the actual peak TFLOPs number is between SXM and NVL, and we don't know its exact value. So this MFU number is lower than actual MFU because we use the peak number of SXM directly.
+**Table 6** 4D parallelism (FSDP + TP + PP + CP) + torch.compile + Float8 + AsyncTP + 1F1B. Llama 3.1 405B model. 512 GPUs (TP 8, PP 8). Local batch size 8. Full activation checkpointing.

-[^2]: Since for Float8, we are not converting all the matmuls to Float8 because our fused attention implementation is not done in Float8, so this number is lower than expected.
+| Parallelism | Sequence Length | TPS/GPU | Memory (GiB) |
+| ----- | ----: | ----: | ----: |
+| FSDP 8, CP 1 | 32768 | 76 | 75.3 |
+| FSDP 4, CP 2 | 65536 | 47 | 75.9 |
+| FSDP 2, CP 4 | 131072 | 31 | 77.1 |
+| FSDP 1, CP 8 | 262144 | 16 | 84.9 |
+
+
+#### Versions used for performance testing
+| repo | commit | date |
+| --- | --- | --- |
+| torch | [1963fc8](https://github.com/pytorch/pytorch/commit/1963fc83a1c32e162162e2414f78b043f0674bae) | 2024/12/23 |
+| torchao | [eab345c](https://github.com/pytorch/ao/commit/eab345c2268a7506355d506ebfc27b5d28e5e7d0) | 2024/12/23 |
+| torchtitan | [9dec370](https://github.com/pytorch/torchtitan/commit/9dec370ad26b5f8e9a7333a0e36165018262644b) | 2024/12/26 |
+
+
+[^1]: Different PP ranks can have different peak memory usage. We take the maximum across all GPUs.
+
+[^2]: In our tests we used HBM2e-based SXM H100 GPUs with a lower TDP; the actual peak TFLOPS is between the SXM and NVL specs, and we don't know its exact value. So this MFU number is lower than the actual MFU, because we use the SXM peak number directly.
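For context on the MFU numbers above (33% to 39%) and footnote [^2]: MFU is commonly estimated, PaLM-style, as the achieved model FLOPS rate divided by an assumed per-GPU hardware peak. The sketch below is that common approximation, not necessarily the exact accounting implemented in torchtitan.

```math
\mathrm{MFU} \;=\; \frac{\text{achieved model FLOPS per second}}{\text{peak FLOPS per GPU}}
\;\approx\; \frac{\left(6N + 12\,L\,H\,Q\,T\right)\cdot \text{TPS/GPU}}{\text{peak FLOPS per GPU}}
```

Here N is the number of model parameters, L the number of transformer layers, H the number of attention heads, Q the head dimension, and T the sequence length. Because the denominator is an assumed peak, dividing by the higher SXM spec when the actual part sits between the SXM and NVL specs underestimates MFU, which is the caveat footnote [^2] describes; and once FP8 matmuls run alongside BFLOAT16 ones there is no single peak to divide by, which is why the tables above report TPS/GPU instead of MFU.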
