
Commit a5388a4

Wong4j authored and nv-kkudrynski committed
[BERT/Paddle] Update base image and integrate cuDNN fused MHA
1 parent 296bb99 commit a5388a4

File tree

11 files changed: +267 −135 lines

PaddlePaddle/LanguageModeling/BERT/Dockerfile

+1 −1

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:22.12-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.06-py3
 FROM ${FROM_IMAGE_NAME}
 RUN apt-get update && apt-get install -y pbzip2 pv bzip2 cabextract

PaddlePaddle/LanguageModeling/BERT/README.md

+39 −25

@@ -437,6 +437,7 @@ Advanced Training:
 --use-dynamic-loss-scaling
 Enable dynamic loss scaling in AMP training, only applied when --amp is set. (default: False)
 --use-pure-fp16 Enable pure FP16 training, only applied when --amp is set. (default: False)
+--fuse-mha Enable multihead attention fusion. Requires cuDNN version >= 8.9.1.
 ```


@@ -463,6 +464,7 @@ Default arguments are listed below in the order `scripts/run_squad.sh` expects:
 - Enable benchmark - The default is `false`.
 - Benchmark steps - The default is `100`.
 - Benchmark warmup steps - The default is `100`.
+- Fuse MHA - The default is `true`.

 The script saves the final checkpoint to the `/results/bert-large-uncased/squad` folder.

@@ -593,7 +595,8 @@ bash run_pretraining.sh \
 <bert_config_file> \
 <enable_benchmark> \
 <benchmark_steps> \
-<benchmark_warmup_steps>
+<benchmark_warmup_steps> \
+<fuse_mha>
 ```

 Where:
@@ -627,6 +630,7 @@ Where:
 - `masking` LDDL supports both static and dynamic masking. Refer to [LDDL's README](https://github.com/NVIDIA/LDDL/blob/main/README.md) for more information.
 - `<bert_config_file>` is the path to the bert config file.
 - `<enable_benchmark>` a flag to enable benchmark. The train process will warmup for `<benchmark_warmup_steps>` and then measure the throughput of the following `<benchmark_steps>`.
+- `<fuse_mha>` a flag to enable cuDNN MHA fusion.

 Note that:
 - If users follow [Quick Start Guide](#quick-start-guide) to set up container and dataset, there is no need to set any parameters. For example:
@@ -670,6 +674,7 @@ python3 -m paddle.distributed.launch \
 --max-predictions-per-seq=20 \
 --gradient-merge-steps=32 \
 --amp \
+--fuse-mha \
 --use-dynamic-loss-scaling \
 --optimizer=Lamb \
 --phase1 \
@@ -769,7 +774,8 @@ bash scripts/run_squad.sh \
 <max_steps> \
 <enable_benchmark> \
 <benchmark_steps> \
-<benchmark_warmup_steps>
+<benchmark_warmup_steps> \
+<fuse_mha>
 ```

 By default, the `mode` argument is set to `train eval`. Refer to the [Quick Start Guide](#quick-start-guide) for explanations of each positional argument.
@@ -812,7 +818,7 @@ bash scripts/run_pretraining.sh \
 None \
 /path/to/wikipedia/source \
 32 128 4 0.9 64 static \
-None true 10 10
+None true 10 10 true
 ```

 To benchmark the training performance on a specific batch size for SQuAD, refer to [Fine-tuning](#fine-tuning) and turn on the `<benchmark>` flags. An example call to run training for 200 steps (100 steps for warmup and 100 steps to measure), and generate throughput numbers:
@@ -825,7 +831,7 @@ bash scripts/run_squad.sh \
 results/checkpoints \
 train \
 bert_configs/bert-large-uncased.json \
--1 true 100 100
+-1 true 100 100 true
 ```

 #### Inference performance benchmark
@@ -841,7 +847,8 @@ bash scripts/run_squad.sh \
 <results directory> \
 eval \
 <BERT config path> \
-<max steps> <benchmark> <benchmark_steps> <benchmark_warmup_steps>
+<max steps> <benchmark> <benchmark_steps> <benchmark_warmup_steps> \
+<fuse_mha>
 ```

 An example call to run inference and generate throughput numbers:
@@ -854,7 +861,7 @@ bash scripts/run_squad.sh \
 results/checkpoints \
 eval \
 bert_configs/bert-large-uncased.json \
--1 true 100 100
+-1 true 100 100 true
 ```


@@ -870,7 +877,7 @@ Our results were obtained by running the `scripts/run_squad.sh` and `scripts/run

 | DGX System | GPUs / Node | Precision | Accumulated Batch size / GPU (Phase 1 and Phase 2) | Accumulation steps (Phase 1 and Phase 2) | Final Loss | Time to train(hours) | Time to train speedup (TF32 to mixed precision) |
 |--------------------|-------------|-----------|----------------------------------------------------|------------------------------------------|-------------------|----------------------|-------------------------------------------------|
-| 32 x DGX A100 80GB | 8 | AMP | 256 and 128 | 1 and 4 | 1.409 | ~ 1.2 hours | 1.72 |
+| 32 x DGX A100 80GB | 8 | AMP | 256 and 128 | 1 and 4 | 1.409 | ~ 1.1 hours | 2.27 |
 | 32 x DGX A100 80GB | 8 | TF32 | 128 and 16b | 2 and 8 | 1.421 | ~ 2.5 hours | 1 |


@@ -914,28 +921,28 @@ Our results were obtained by running the script `run_pretraining.sh` in the Padd

 | GPUs | Batch size / GPU (TF32 and FP16) | Accumulation steps (TF32 and FP16) | Sequence length | Throughput - TF32(sequences/sec) | Throughput - mixed precision(sequences/sec) | Throughput speedup (TF32 - mixed precision) | Weak scaling - TF32 | Weak scaling - mixed precision |
 |------|----------------------------------|------------------------------------|-----------------|----------------------------------|---------------------------------------------|---------------------------------------------|---------------------|--------------------------------|
-| 1 | 8192 and 8192 | 64 and 32 | 128 | 307 | 633 | 2.06 | 1.00 | 1.00 |
-| 8 | 8192 and 8192 | 64 and 32 | 128 | 2428 | 4990 | 2.06 | 7.91 | 7.88 |
-| 1 | 4096 and 4096 | 256 and 128 | 512 | 107 | 219 | 2.05 | 1.00 | 1.00 |
-| 8 | 4096 and 4096 | 256 and 128 | 512 | 851 | 1724 | 2.26 | 7.95 | 7.87 |
+| 1 | 8192 and 8192 | 64 and 32 | 128 | 307 | 694 | 2.26 | 1.00 | 1.00 |
+| 8 | 8192 and 8192 | 64 and 32 | 128 | 2428 | 5541 | 2.28 | 7.91 | 7.98 |
+| 1 | 4096 and 4096 | 256 and 128 | 512 | 107 | 264 | 2.47 | 1.00 | 1.00 |
+| 8 | 4096 and 4096 | 256 and 128 | 512 | 851 | 2109 | 2.48 | 7.95 | 7.99 |


 ###### Pre-training NVIDIA DGX A100 (8x A100 80GB) Multi-node Scaling

 | Nodes | GPUs / node | Batch size / GPU (TF32 and FP16) | Accumulated Batch size / GPU (TF32 and FP16) | Accumulation steps (TF32 and FP16) | Sequence length | Mixed Precision Throughput | Mixed Precision Strong Scaling | TF32 Throughput | TF32 Strong Scaling | Speedup (Mixed Precision to TF32) |
 |-------|-------------|----------------------------------|------------------------------------|-----------------|----------------------------|--------------------------------|-----------------|---------------------|-----------------------------------|-----|
-| 1 | 8 | 126 and 256 | 8192 and 8192 | 64 and 32 | 128 | 4990 | 1 | 2428 | 1 | 2.06 |
-| 2 | 8 | 126 and 256 | 4096 and 4096 | 32 and 16 | 128 | 9581 | 1.92 | 4638 | 1.91 | 2.07 |
-| 4 | 8 | 126 and 256 | 2048 and 2048 | 16 and 8 | 128 | 19262 | 3.86 | 9445 | 3.89 | 2.04 |
-| 8 | 8 | 126 and 256 | 1024 and 1024 | 8 and 4 | 128 | 37526 | 7.52 | 18335 | 7.55 | 2.05 |
-| 16 | 8 | 126 and 256 | 512 and 512 | 4 and 2 | 128 | 71156 | 14.26 | 35526 | 14.63 | 2.00 |
-| 32 | 8 | 126 and 256 | 256 and 256 | 2 and 1 | 128 | 142087 | 28.47 | 69701 | 28.71 | 2.04 |
-| 1 | 8 | 16 and 32 | 4096 and 4096 | 256 and 128 | 512 | 1724 | 1 | 851 | 1 | 2.03 |
-| 2 | 8 | 16 and 32 | 2048 and 2048 | 128 and 64 | 512 | 3305 | 1.92 | 1601 | 1.88 | 2.06 |
-| 4 | 8 | 16 and 32 | 1024 and 1024 | 64 and 32 | 512 | 6492 | 3.77 | 3240 | 3.81 | 2.00 |
-| 8 | 8 | 16 and 32 | 512 and 512 | 32 and 16 | 512 | 12884 | 7.47 | 6329 | 7.44 | 2.04 |
-| 16 | 8 | 16 and 32 | 256 and 256 | 16 and 8 | 512 | 25493 | 14.79 | 12273 | 14.42 | 2.08 |
-| 32 | 8 | 16 and 32 | 128 and 128 | 8 and 4 | 512 | 49307 | 28.60 | 24047 | 28.26 | 2.05 |
+| 1 | 8 | 126 and 256 | 8192 and 8192 | 64 and 32 | 128 | 5541 | 1 | 2428 | 1 | 2.28 |
+| 2 | 8 | 126 and 256 | 4096 and 4096 | 32 and 16 | 128 | 10646 | 1.92 | 4638 | 1.91 | 2.29 |
+| 4 | 8 | 126 and 256 | 2048 and 2048 | 16 and 8 | 128 | 21389 | 3.86 | 9445 | 3.89 | 2.26 |
+| 8 | 8 | 126 and 256 | 1024 and 1024 | 8 and 4 | 128 | 41681 | 7.52 | 18335 | 7.55 | 2.27 |
+| 16 | 8 | 126 and 256 | 512 and 512 | 4 and 2 | 128 | 79023 | 14.26 | 35526 | 14.63 | 2.22 |
+| 32 | 8 | 126 and 256 | 256 and 256 | 2 and 1 | 128 | 157952 | 28.51 | 69701 | 28.71 | 2.27 |
+| 1 | 8 | 16 and 32 | 4096 and 4096 | 256 and 128 | 512 | 2109 | 1 | 851 | 1 | 2.48 |
+| 2 | 8 | 16 and 32 | 2048 and 2048 | 128 and 64 | 512 | 4051 | 1.92 | 1601 | 1.88 | 2.53 |
+| 4 | 8 | 16 and 32 | 1024 and 1024 | 64 and 32 | 512 | 7972 | 3.78 | 3240 | 3.81 | 2.46 |
+| 8 | 8 | 16 and 32 | 512 and 512 | 32 and 16 | 512 | 15760 | 7.47 | 6329 | 7.44 | 2.49 |
+| 16 | 8 | 16 and 32 | 256 and 256 | 16 and 8 | 512 | 31129 | 14.76 | 12273 | 14.42 | 2.54 |
+| 32 | 8 | 16 and 32 | 128 and 128 | 8 and 4 | 512 | 60206 | 28.55 | 24047 | 28.26 | 2.50 |


 ###### Fine-tuning NVIDIA DGX A100 (8x A100 80GB)
@@ -944,8 +951,8 @@ Our results were obtained by running the script `run_pretraining.sh` in the Padd

 | GPUs | Batch size / GPU (TF32 and FP16) | Throughput - TF32(sequences/sec) | Throughput - mixed precision(sequences/sec) | Throughput speedup (TF32 - mixed precision) | Weak scaling - TF32 | Weak scaling - mixed precision |
 |------|----------------------------------|----------------------------------|---------------------------------------------|---------------------------------------------|---------------------|--------------------------------|
-| 1 | 32 and 32 | 83 | 120 | 1.45 | 1.00 | 1.00 |
-| 8 | 32 and 32 | 629 | 876 | 1.39 | 7.59 | 7.30 |
+| 1 | 32 and 32 | 83 | 123 | 1.48 | 1.00 | 1.00 |
+| 8 | 32 and 32 | 629 | 929 | 1.48 | 7.59 | 7.55 |

 #### Inference performance results

@@ -983,6 +990,13 @@ August 2022
 - SQuAD finetune support with AdamW optimizer.
 - Updated accuracy and performance tables tested on A100.
 - Initial release.
+
+March 2023
+- Pre-training using [Language Datasets and Data Loaders (LDDL)](https://github.com/NVIDIA/LDDL).
+- Binned pretraining for phase2 with LDDL using a bin size of 64.
+
+July 2023
+- Optimized AMP training with the cuDNN fused dot product attention kernel.

 ### Known issues
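
The new `--fuse-mha` flag is documented above as requiring cuDNN >= 8.9.1, which is also why the commit bumps the base image to the 23.06 NGC Paddle container. As a quick sanity check before enabling the flag in some other environment, something like the following can be used. This is a minimal sketch, assuming `paddle.device.get_cudnn_version()` is available in the installed Paddle build and returns the cuDNN version encoded as major*1000 + minor*100 + patch (so 8.9.1 corresponds to 8901); the helper name is hypothetical.

```python
# Hypothetical helper (not part of the repo): decide whether --fuse-mha is safe
# to pass, based on the runtime cuDNN version reported by Paddle.
import paddle


def fuse_mha_supported(min_encoded_version: int = 8901) -> bool:
    """Return True if the reported cuDNN version is at least 8.9.1."""
    # Assumption: get_cudnn_version() encodes the version as
    # major*1000 + minor*100 + patch and may return None on CPU-only builds.
    version = paddle.device.get_cudnn_version()
    return version is not None and version >= min_encoded_version


if __name__ == "__main__":
    print("cuDNN fused MHA available:", fuse_mha_supported())
```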

PaddlePaddle/LanguageModeling/BERT/modeling.py

+2 −1

@@ -172,7 +172,8 @@ def __init__(self, bert_config):
             dropout=bert_config.hidden_dropout_prob,
             activation=bert_config.hidden_act,
             attn_dropout=bert_config.attention_probs_dropout_prob,
-            act_dropout=0)
+            act_dropout=0,
+            fuse_qkv=bert_config.fuse_mha)
         self.encoder = nn.TransformerEncoder(encoder_layer,
                                              bert_config.num_hidden_layers)
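
To trace the flag end to end: `program.py` (next file) copies `args.fuse_mha` onto the BERT config, and the hunk above forwards it to the encoder layer as `fuse_qkv`. Below is a minimal sketch of that path, assuming the NGC Paddle build in which `paddle.nn.TransformerEncoderLayer` accepts the `fuse_qkv` argument; the positional size arguments are the usual BertConfig fields and are filled in here for illustration, since they are outside the hunk.

```python
# Sketch only, not the repo's full BertModel: how --fuse-mha reaches attention.
import paddle.nn as nn


def build_encoder(bert_config, fuse_mha):
    bert_config.fuse_mha = fuse_mha  # mirrors bert_config.fuse_mha = args.fuse_mha in program.build()
    encoder_layer = nn.TransformerEncoderLayer(
        bert_config.hidden_size,              # d_model
        bert_config.num_attention_heads,      # nhead
        bert_config.intermediate_size,        # dim_feedforward
        dropout=bert_config.hidden_dropout_prob,
        activation=bert_config.hidden_act,
        attn_dropout=bert_config.attention_probs_dropout_prob,
        act_dropout=0,
        fuse_qkv=bert_config.fuse_mha)        # new in this commit
    return nn.TransformerEncoder(encoder_layer, bert_config.num_hidden_layers)
```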

PaddlePaddle/LanguageModeling/BERT/program.py

+11 −6

@@ -44,12 +44,12 @@ def create_pretraining_data_holder():
     ]


-def create_strategy(use_amp, use_distributed_fused_lamb=False):
+def create_strategy(args, use_distributed_fused_lamb=False):
     """
     Create paddle.static.BuildStrategy and paddle.static.ExecutionStrategy with arguments.

     Args:
-        use_amp(bool): Whether to use amp.
+        args(Namespace): Arguments obtained from ArgumentParser.
         use_distributed_fused_lamb(bool, optional): Whether to use distributed fused lamb.
     Returns:
         build_strategy(paddle.static.BuildStrategy): A instance of BuildStrategy.
@@ -59,8 +59,9 @@ def create_strategy(use_amp, use_distributed_fused_lamb=False):
     exec_strategy = paddle.static.ExecutionStrategy()

     build_strategy.enable_addto = True
-    if use_amp:
+    if args.amp:
         build_strategy.fuse_gemm_epilogue = True
+        build_strategy.fuse_dot_product_attention = args.fuse_mha

     if use_distributed_fused_lamb:
         build_strategy.fuse_all_reduce_ops = False
@@ -86,7 +87,7 @@ def dist_optimizer(args, optimizer):
         optimizer(fleet.distributed_optimizer): A distributed optimizer.
     """
     use_distributed_fused_lamb = True if args.optimizer == 'DistributedFusedLamb' else False
-    build_strategy, exec_strategy = create_strategy(args.amp,
+    build_strategy, exec_strategy = create_strategy(args,
                                                     use_distributed_fused_lamb)
     dist_strategy = fleet.DistributedStrategy()
@@ -160,6 +161,7 @@ def build(args, main_prog, startup_prog, is_train=True):
     bert_config = BertConfig.from_json_file(args.config_file)
     if bert_config.vocab_size % 8 != 0:
         bert_config.vocab_size += 8 - (bert_config.vocab_size % 8)
+    bert_config.fuse_mha = args.fuse_mha
     model = BertForPretraining(bert_config)
     criterion = BertPretrainingCriterion(bert_config.vocab_size)
     prediction_scores, seq_relationship_score = model(
@@ -224,19 +226,22 @@ def run(exe,
     logging.info(f"Training will start at the {last_step+1}th step")

     max_steps = args.max_steps
+    steps_this_run = max_steps
     if args.steps_this_run is not None:
         if args.steps_this_run + last_step > max_steps:
             logging.info(
                 f"Only {max_steps - last_step} steps will be performed in this run due to the limit of --max-steps."
             )
         else:
             steps_this_run = args.steps_this_run
-        if args.benchmark:
-            steps_this_run = min(steps_this_run, args.benchmark_warmup_steps + args.benchmark_steps)
         max_steps = steps_this_run + last_step
         logging.warning(
             f"{steps_this_run} steps will be performed in this run.")

+    if args.benchmark:
+        max_steps = args.benchmark_warmup_steps + args.benchmark_steps + last_step
+
+
     total_samples = 0
     raw_train_start = time.time()
     step_start = time.time()
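
Taken together, the `create_strategy` hunks above amount to the following shape. This is a consolidated sketch rather than the full function (anything outside the visible hunks is omitted), and it assumes the NGC Paddle build in which `paddle.static.BuildStrategy` exposes `fuse_dot_product_attention`.

```python
# Consolidated sketch of the updated create_strategy from this commit.
import paddle


def create_strategy(args, use_distributed_fused_lamb=False):
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()

    build_strategy.enable_addto = True
    if args.amp:
        build_strategy.fuse_gemm_epilogue = True
        # New: hand the --fuse-mha flag to the graph pass that replaces the
        # attention subgraph with the cuDNN fused dot-product-attention kernel.
        build_strategy.fuse_dot_product_attention = args.fuse_mha

    if use_distributed_fused_lamb:
        build_strategy.fuse_all_reduce_ops = False

    return build_strategy, exec_strategy
```

Note that the fusion flag is only consulted inside the `if args.amp:` branch, which matches the changelog entry describing it as an optimization of AMP training.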

PaddlePaddle/LanguageModeling/BERT/run_pretraining.py

+2 −1

@@ -81,7 +81,8 @@ def main():
         log_dir=None if args.output_dir is None else
         os.path.join(args.output_dir, 'lddl_log'),
         log_level=logging.WARNING,
-        start_epoch=0 if progress is None else progress.get("epoch", 0), )
+        start_epoch=0 if progress is None else progress.get("epoch", 0),
+        sequence_length_alignment=64)

     if args.amp:
         optimizer.amp_init(device)
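
The functional change here is the new `sequence_length_alignment=64` argument to the LDDL data loader call. Conceptually, as the name suggests, each batch is padded so that its sequence length is a multiple of 64, matching the bin size mentioned in the changelog and giving the cuDNN fused MHA kernel aligned input shapes. The following is an illustration of that idea only, not LDDL code.

```python
# Illustration of sequence-length alignment (not LDDL's implementation).
import numpy as np


def align_batch(token_ids, alignment=64, pad_id=0):
    """Pad variable-length sequences to a common length rounded up to `alignment`."""
    longest = max(len(seq) for seq in token_ids)
    aligned = ((longest + alignment - 1) // alignment) * alignment  # round up
    batch = np.full((len(token_ids), aligned), pad_id, dtype=np.int64)
    for row, seq in enumerate(token_ids):
        batch[row, :len(seq)] = seq
    return batch


example = align_batch([[101, 2023, 102], [101, 7592, 2088, 102]])
print(example.shape)  # (2, 64): the longest sequence (4 tokens) is padded to 64
```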
