
Commit 9a61730

support fairseq export (#278)
* fix torch fake quant positions
* fix torch decoder self attn quant position
* rename scripts
* modify lightseq arguments
* add find-unused-parameters for ls_torch_fairseq training
* finetune quant model from pretrained fp16 model
* fairseq generate using sacrebleu
* support native fairseq export
* polish export code
* support converting pb to hdf5
* support ls_torch_fairseq_quant export (stage 1)
* fix typo
* fix fake quant relu compute bug
* fix export bug
* delete useless proto keys
* add ls_torch_fairseq ptq export, fix encdec_attn kv quant bug
* fix qat export bug
* modify ptq act_clip_max
* support fairseq generate using lightseq inference
* support native fairseq ptq export
* modify README.md
1 parent 6988914 commit 9a61730

32 files changed: +2244 −586 lines

examples/inference/python/README.md

Lines changed: 62 additions & 23 deletions
````diff
@@ -1,71 +1,105 @@
-## Examples of exporting models for LightSeq inference
+# Examples of exporting models for LightSeq inference
 
-### Switch to the current directory
+## Switch to the current directory
 ```shell
 cd examples/inference/python
 ```
 
-### Export models
+## Export models
+### Hugging Face
 1. Hugging Face BART
 
 Export Hugging Face BART models to protobuf/hdf5 format.
 ```shell
-python export/hf_bart_export.py
+python export/huggingface/hf_bart_export.py
 ```
 2. Hugging Face BERT
 
 Export Hugging Face BERT models to hdf5 format.
 ```shell
-python export/hf_bert_export.py
+python export/huggingface/hf_bert_export.py
 ```
 3. Hugging Face GPT2
 
 Export Hugging Face GPT2 models to hdf5 format.
 ```shell
-python export/hf_gpt2_export.py
+python export/huggingface/hf_gpt2_export.py
 ```
-4. Fairseq Transformer using LightSeq training library
+### Native Fairseq
+1. Native Fairseq Transformer
+
+Export native Fairseq Transformer models to protobuf/hdf5 format. Refer to the `examples/training/fairseq` directory for more training details.
+```shell
+python export/fairseq/native_fs_transformer_export.py -m checkpoint_best.pt
+```
+
+2. Native Fairseq Transformer using PTQ
+
+Export native Fairseq Transformer models using PTQ to protobuf/hdf5 format. Refer to the `examples/training/fairseq` directory for more training details.
+```shell
+python export/fairseq/native_fs_transformer_export.py -m checkpoint_best.pt
+```
+
+3. Native Fairseq MoE Transformer
+
+Export Fairseq MoE models to protobuf/hdf5 format.
+```shell
+python export/fairseq/fs_moe_export.py
+```
+
+### Fairseq Transformer + LightSeq
+1. Fairseq Transformer using LightSeq training library
 
 Export Fairseq Transformer models training with LightSeq to protobuf/hdf5 format. Refer to the `examples/training/fairseq` directory for more training details.
 ```shell
-python export/ls_fs_transformer_export.py
+python export/fairseq/ls_fs_transformer_export.py -m checkpoint_best.pt
 ```
-5. Fairseq Transformer using LightSeq training library with int8 quantization
 
-Export Fairseq Transformer models training with LightSeq to protobuf format, and then using int8 quantization to speedup inference. Refer to the `examples/training/fairseq` directory for more training details.
+2. Fairseq Transformer using LightSeq training library with PTQ
+
+Export Fairseq Transformer models training with LightSeq to protobuf format, and then using PTQ to speedup inference. Refer to the `examples/training/fairseq` directory for more training details.
 ```shell
-python export/ls_fs_transformer_ptq_export.py
+python export/fairseq/ls_fs_transformer_ptq_export.py -m checkpoint_best.pt
 ```
-**You can compare the speeds between fp16 and int8 inference using above 4th and 5th examples.**
 
-6. LightSeq Transformer
+### LightSeq Transformer
+
+1. LightSeq Transformer
 
 Export LightSeq Transformer models to protobuf/hdf5 format. Refer to the `examples/training/custom` directory for more training details.
 ```shell
 python export/ls_transformer_export.py
 ```
-7. LightSeq Transformer using int8 quantization
+2. LightSeq Transformer using PTQ
 
-Export LightSeq fp16/fp32 Transformer models to int8 protobuf format, and then using int8 quantization to speedup inference. Refer to the `examples/training/custom` directory for more training details. Note that in this example, we do not need to finetune the models using fake-quantization.
+Export LightSeq fp16/fp32 Transformer models to int8 protobuf format, and then using PTQ to speedup inference. Refer to the `examples/training/custom` directory for more training details. Note that in this example, we do not need to finetune the models using fake-quantization.
 ```shell
 python export/ls_transformer_ptq_export.py
 ```
-**You can compare the speeds between fp16 and int8 inference using above 6th and 7th examples.**
 
-8. Fairseq Transformer
+### Fairseq Transformer + custom Torch layers
+1. Fairseq Transformer using custom Torch layers
 
-Export Fairseq Transformer models to protobuf/hdf5 format.
+Export Fairseq Transformer models training using custom Torch layers to protobuf/hdf5 format. Refer to the `examples/training/fairseq` directory for more training details.
 ```shell
-python export/fs_transformer_export.py
+python export/fairseq/ls_torch_fs_transformer_export.py -m checkpoint_best.pt
 ```
-9. Fairseq MoE
 
-Export Fairseq MoE models to protobuf/hdf5 format.
+2. Fairseq Transformer using custom Torch layers and PTQ
+
+Export PTQ Fairseq Transformer models training using custom Torch layers to protobuf/hdf5 format. Refer to the `examples/training/fairseq` directory for more training details.
+```shell
+python export/fairseq/ls_torch_fs_transformer_ptq_export.py -m checkpoint_best.pt
+```
+
+3. Quantized Fairseq Transformer using custom Torch layers
+
+Export quantized Fairseq Transformer models training using custom Torch layers to protobuf/hdf5 format. Refer to the `examples/training/fairseq` directory for more training details.
 ```shell
-python export/fs_moe_export.py
+python export/fairseq/ls_torch_fs_quant_transformer_export.py -m checkpoint_best.pt
 ```
 
-### Inference using LightSeq
+## Inference using LightSeq
 1. BART
 ```shell
 python test/ls_bart.py
@@ -78,3 +112,8 @@ python test/ls_bert.py
 ```shell
 python test/ls_gpt2.py
 ```
+
+4. Fairseq based models using LightSeq inference
+```shell
+bash test/ls_fairseq.sh --model ${model_path}
+```
````
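
The test scripts listed above wrap essentially the same verification flow that the export scripts in this commit run after exporting. A minimal Python sketch of that flow, assuming a protobuf model exported by one of the commands above; the file name, batch size, and token IDs are illustrative placeholders taken from the export scripts' defaults:

```python
import lightseq.inference as lsi

# Illustrative values: the export script writes "<model_name>.pb" next to the
# checkpoint, and the source IDs must come from the training vocabulary.
pb_path = "checkpoint_best.pb"
max_batch_size = 8

# Load the exported Transformer and run batched inference on token IDs.
model = lsi.Transformer(pb_path, max_batch_size)
src = [[63, 47, 65, 1507, 88, 74, 10, 2057, 362, 9, 284, 6, 2, 1, 1, 1]]
print(model.infer(src))
```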

examples/inference/python/export/ls_fs_transformer_export.py renamed to examples/inference/python/export/fairseq/ls_fs_transformer_export.py

Lines changed: 26 additions & 11 deletions
```diff
@@ -2,9 +2,10 @@
 Export Fairseq Transformer models training with LightSeq to protobuf/hdf5 format.
 Refer to the `examples/training/fairseq` directory for more training details.
 """
+import argparse
 import torch
 import h5py
-from proto.transformer_pb2 import Transformer
+from export.proto.transformer_pb2 import Transformer
 from lightseq.training import (
     export_ls_config,
     export_ls_embedding,
@@ -60,8 +61,8 @@ def export_ls_fs_transformer(ckpt_path, out_path, save_pb=True):
     else:
         file = h5py.File(out_path, "w")
     encoder_state_dict, decoder_state_dict = _extract_weight(state_dict)
-    export_ls_embedding(file, encoder_state_dict, 1024, True, save_pb)
-    export_ls_embedding(file, decoder_state_dict, 1024, False, save_pb)
+    export_ls_embedding(file, encoder_state_dict, 300, True, save_pb)
+    export_ls_embedding(file, decoder_state_dict, 300, False, save_pb)
     export_ls_encoder(
         file,
         encoder_state_dict,
@@ -81,9 +82,9 @@ def export_ls_fs_transformer(ckpt_path, out_path, save_pb=True):
     export_ls_config(
         file,
         args.encoder_attention_heads,
+        1,
         2,
         2,
-        6,
         args.encoder_layers,
         args.decoder_layers,
         save_pb=save_pb,
@@ -96,19 +97,33 @@ def export_ls_fs_transformer(ckpt_path, out_path, save_pb=True):
         file.close()
 
 
+def parse_args():
+    parser = argparse.ArgumentParser(description="export fairseq checkpoint", usage="")
+    parser.add_argument(
+        "--model",
+        "-m",
+        type=str,
+        default="checkpoint_best.pt",
+        help="path of fairseq checkpoint",
+    )
+    args = parser.parse_args()
+    return args
+
+
 if __name__ == "__main__":
-    ckpt_path = "checkpoint_best.pt"
-    pb_path = "transformer.pb"
-    hdf5_path = "transformer.hdf5"
+    args = parse_args()
+    model_name = ".".join(args.model.split(".")[:-1])
+    pb_path = f"{model_name}.pb"
+    hdf5_path = f"{model_name}.hdf5"
     print("export to pb model >>>>>>")
-    export_ls_fs_transformer(ckpt_path, pb_path)
+    export_ls_fs_transformer(args.model, pb_path)
     print("export to hdf5 model >>>>>>")
-    export_ls_fs_transformer(ckpt_path, hdf5_path, save_pb=False)
-    src = [[63, 47, 65, 1507, 88, 74, 10, 2057, 362, 9, 284, 6, 2]]
+    export_ls_fs_transformer(args.model, hdf5_path, save_pb=False)
+    src = [[63, 47, 65, 1507, 88, 74, 10, 2057, 362, 9, 284, 6, 2, 1, 1, 1]]
     pb_model = lsi.Transformer(pb_path, 8)
     pb_output = pb_model.infer(src)
     hdf5_model = lsi.Transformer(hdf5_path, 8)
     hdf5_output = hdf5_model.infer(src)
-    # Expected result: [23, 550, 34, 118, 148, 2939, 4, 42, 32, 37, 6]
+    # Expected result: [23, 550, 34, 118, 148, 2939, 4, 42, 32, 37, 6, 224, 10, 179, 5, 2]
     print("pb results:", pb_output)
     print("hdf5 results:", hdf5_output)
```

examples/inference/python/export/ls_fs_transformer_ptq_export.py renamed to examples/inference/python/export/fairseq/ls_fs_transformer_ptq_export.py

Lines changed: 25 additions & 10 deletions
```diff
@@ -3,9 +3,10 @@
 and then using int8 quantization to speedup inference.
 Refer to the `examples/training/fairseq` directory for more training details.
 """
+import argparse
 import torch
 import h5py
-from proto.quant_transformer_pb2 import QuantTransformer
+from export.proto.quant_transformer_pb2 import QuantTransformer
 from lightseq.training import (
     export_ls_config,
     export_ls_embedding_ptq,
@@ -47,7 +48,7 @@ def export_fs_weights(file, state_dict, save_pb=True):
     file.trg_embedding.shared_bias[:] = dec_shared_b
 
 
-def export_ls_fs_transformer(ckpt_path, out_path, save_pb=True):
+def export_ls_fs_transformer_ptq(ckpt_path, out_path, save_pb=True):
     with open(ckpt_path, "rb") as fin:
         ckpt_file = torch.load(fin)
         args = ckpt_file["args"]
@@ -58,14 +59,14 @@ def export_ls_fs_transformer(ckpt_path, out_path, save_pb=True):
     export_ls_embedding_ptq(
         file,
         encoder_state_dict,
-        1024,
+        300,
         True,
         save_pb=save_pb,
     )
     export_ls_embedding_ptq(
         file,
         decoder_state_dict,
-        1024,
+        300,
         False,
         save_pb=save_pb,
     )
@@ -90,9 +91,9 @@ def export_ls_fs_transformer(ckpt_path, out_path, save_pb=True):
     export_ls_config(
         file,
         args.encoder_attention_heads,
+        1,
         2,
         2,
-        6,
         args.encoder_layers,
         args.decoder_layers,
         save_pb=save_pb,
@@ -102,13 +103,27 @@ def export_ls_fs_transformer(ckpt_path, out_path, save_pb=True):
         fout.write(file.SerializeToString())
 
 
+def parse_args():
+    parser = argparse.ArgumentParser(description="export fairseq checkpoint", usage="")
+    parser.add_argument(
+        "--model",
+        "-m",
+        type=str,
+        default="checkpoint_best.pt",
+        help="path of fairseq checkpoint",
+    )
+    args = parser.parse_args()
+    return args
+
+
 if __name__ == "__main__":
-    ckpt_path = "checkpoint_best.pt"
-    pb_path = "quant_transformer.pb"
+    args = parse_args()
+    model_name = ".".join(args.model.split(".")[:-1])
+    pb_path = f"{model_name}_ptq.pb"
     print("export to pb model >>>>>>")
-    export_ls_fs_transformer(ckpt_path, pb_path)
-    src = [[63, 47, 65, 1507, 88, 74, 10, 2057, 362, 9, 284, 6, 2]]
+    export_ls_fs_transformer_ptq(args.model, pb_path)
+    src = [[63, 47, 65, 1507, 88, 74, 10, 2057, 362, 9, 284, 6, 2, 1, 1, 1]]
     pb_model = lsi.QuantTransformer(pb_path, 8)
     pb_output = pb_model.infer(src)
-    # FP16 result: [23, 550, 34, 118, 148, 2939, 4, 42, 32, 37, 6]
+    # FP16 result: [23, 550, 34, 118, 148, 2939, 4, 42, 32, 37, 6, 224, 10, 179, 5, 2]
     print("pb results:", pb_output)
```