This repository was archived by the owner on Jul 21, 2025. It is now read-only.

Commit cc729e1

Authored by godweiyang, Taka152, and duanrenchong
LightSeq QAT (#307)
* ls embedding support qat
* [WIP] ls transformer qat
* fix fairseq transformer cli shape bug of output projection
* ln_bw_i8 test passed!
* test with_mean of ln_i8
* ls encoder attn add qat
* dropout_relu_bias_i8 passed!
* dropout_gelu_bias unit test passed!
* dropout_relu_bias_bwd_i8 passed!
* dropout_gelu_bias_bwd_i8 unit test passed!
* format
* dropout_gelu_bias_bwd_i8 unit test passed!
* format
* polish unit test
* [WIP] ls encoder qat test
* quant_bias_add_transform_20314, quant_transform4d_0213 unit test passed!
* fix unit test bug
* [WIP] ls encoder qat unit test
* fix bug
* set default module to disable quant, fix bugs in examples
* fix encoder bug
* encoder qat test pass
* decoder qat forward test pass
* fix bug in encoder bw
* fix bug of cmax grad
* fix bug of act mask
* fix bug in tensor quantizer
* fix cmax grad bug
* [WIP] decoder support qat
* ls decoder qat pass
* ls encoder qat pass
* add unit test for quant bert encoder
* fix memory bug
* fix cmax grad bug in huggingface
* quant bert enc fw&bw test passed!
* fix hf cmax export bug
* fix fairseq out_proj bug
* fix fairseq shell bug
* fix decoder mem bug
* modify initial lr of fairseq quant training
* decoupled qat code
* modify huggingface training scripts
* add cmax grad
* delete enc_kv output quant
* modify ffn2gemm quant like inference
* fuse dequantize
* fix post ln mem bug
* add decoder self attn qkv cache quant
* export quant model (stage 1)
* export quant model (stage 2)
* export quant model (stage 3)
* support vit quant train
* add gradient clip
* fix hf export bug
* fix quant gpt bug
* support quant gpt training
* modify huggingface training scripts
* support ls bert, gpt export
* support custom quant transformer export
* optimize ffn fake quant and dcmax
* support quant gpt export
* support quant vit export
* add quant linear layer
* fix quant linear layer bug
* support quant vit infer
* speedup cublas igemm on A100 (by huxingwu)
* optimize ls_quant_dropout_act_bias_bwd_kernel
* polish training gemm algo code
* support gemm best algo search on different GPUs and shapes
* search in the range (min_bsz, 512, 1) and (512, max_bsz, 32)
* add configs_sm75/h512_i2048_b1-10016.json
* support col32 igemm
* add configs_sm75/h768_i3072_b1-10016.json
* add configs_sm80/h512_i2048_b1-10016.json
* add configs_sm75/h1024_i4096_b1-10016.json
* add configs_sm80/h768_i3072_b1-10016.json
* fix syntax error
* add configs_sm80/h1024_i4096_b1-10016.json
* modify gemm test config format
* merge all the configs into one
* support search of all shapes which are not in the config
* polish the merged config
* add cublas_algo_map cpp code
* move get_sm func to lightseq kernels
* move gemm_test to lightseq ops
* modify default config dir, fix algo_map bug
* fix col32 bug
* col major igemm becomes default
* fix dcmax kernel bug
* loosen cuda 11.6 requirement
* add vit cpp example
* fix bug from col32 gemm and a100 tuned col gemm
* support training encoder qkv_linear auto-tune gemm (in comment)
* add required header file
* dynamically use col32 or col4 on different GPUs
* fix multidefinition bug
* fix weight transform col32 bug
* add best algo for inference gemm (in comments)
* support easy benchmark for gpt and transformer
* support benchmark for huggingface
* fix embedding clip_max bug
* ls quant linear support more shapes
* fix quant linear bug
* fix quant linear bug
* update pad function for older torch
* fix quant linear bug
* remove redundant code
* fix export bug
* fix format
* fix custom train&infer bug
* fix quant infer size overflow
* fix ls gpt export bug (extra_decode_length)
* fix hf bart cmax init and state
* fix max-batch-tokens bug of bart predict

Co-authored-by: Ying Xiong <xiongying.taka@bytedance.com>
Co-authored-by: duanrenchong <duanrenchong@bytedance.com>
1 parent ae569c2 commit cc729e1
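
Much of the work listed above revolves around fake quantization with a learnable clip value ("cmax"). As a rough orientation only, here is a minimal PyTorch-style sketch of that idea; the module name, the straight-through trick, and the initial clip value are illustrative assumptions, not LightSeq's actual fused CUDA implementation:

```python
import torch


class FakeQuant(torch.nn.Module):
    """Symmetric int8 fake quantization with a learnable clip max ("cmax").

    Illustrative only: LightSeq fuses this into CUDA kernels such as the
    quantized LayerNorm and dropout-bias-activation ops listed above.
    """

    def __init__(self, init_cmax: float = 16.0):
        super().__init__()
        self.cmax = torch.nn.Parameter(torch.tensor(init_cmax))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = self.cmax / 127.0
        # Clip to [-cmax, cmax]; gradients w.r.t. cmax come from the clipped elements.
        clipped = torch.clamp(x, min=-self.cmax, max=self.cmax)
        # Round to the int8 grid, then dequantize back to float.
        quantized = torch.round(clipped / scale) * scale
        # Straight-through estimator: quantized values forward, clipped gradient backward.
        return clipped + (quantized - clipped).detach()


if __name__ == "__main__":
    fq = FakeQuant()
    y = fq(torch.randn(4, 8, requires_grad=True))
    y.sum().backward()  # both the input and cmax receive gradients
```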

File tree: 149 files changed, +33917 / -922 lines


CMakeLists.txt

Lines changed: 2 additions & 9 deletions

@@ -1,15 +1,8 @@
 cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
 project(LightSeq LANGUAGES C CXX CUDA)
 
-set(CMAKE_CUDA_ARCHITECTURES
-    60
-    61
-    70
-    75
-    80
-    86
-    87)
-find_package(CUDA 11.6 REQUIRED)
+set(CMAKE_CUDA_ARCHITECTURES 60 61 70 75 80 86)
+find_package(CUDA 11 REQUIRED)
 
 option(FP16_MODE "inference with fp16" OFF)
 option(DEBUG_MODE "debug computation result" OFF)
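
The architecture list above (SM 60 through 86) lines up with the per-GPU GEMM configs this commit adds (configs_sm75/..., configs_sm80/...). Below is a hedged sketch of how one might map the running GPU to such a config directory, using PyTorch only to query the compute capability; the directory naming follows the commit message, but the selection logic is a guess, not LightSeq's actual loader:

```python
import torch


def gemm_config_dir(base: str = "configs") -> str:
    """Guess the per-SM GEMM config directory (e.g. configs_sm75, configs_sm80).

    The names mirror the JSON files added in this commit, such as
    configs_sm75/h512_i2048_b1-10016.json; the fallback policy below is
    purely illustrative.
    """
    major, minor = torch.cuda.get_device_capability()
    sm = major * 10 + minor  # 75 on T4, 80 on A100, ...
    for known in (80, 75):  # only these two architectures ship tuned configs here
        if sm >= known:
            return f"{base}_sm{known}"
    return base  # other GPUs: fall back to the untuned / searched configs


if __name__ == "__main__":
    if torch.cuda.is_available():
        print(gemm_config_dir())
    else:
        print("no CUDA device visible")
```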

docs/guide.md

Lines changed: 2 additions & 2 deletions

@@ -123,8 +123,8 @@ The main code is as follows (some parameters are omitted). Complete code is avai
 ```python
 model = Transformer()
 encoder_state_dict, decoder_state_dict = _extract_weight(state_dict)
-export_ls_embedding(model, encoder_state_dict, is_encoder=True)
-export_ls_embedding(model, encoder_state_dict, is_encoder=False)
+export_ls_embedding(model, encoder_state_dict, max_length, emb_dim, is_encoder=True)
+export_ls_embedding(model, encoder_state_dict, max_length, emb_dim, is_encoder=False)
 export_ls_encoder(model, encoder_state_dict)
 export_ls_decoder(model, decoder_state_dict)
 export_fs_weights(model, state_dict)
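
For readers updating their own export scripts, a short sketch of the new call pattern, mirroring the guide lines above. The `max_length` and `emb_dim` values are placeholders, and the exact import path of `export_ls_embedding` depends on the installed LightSeq version, so it is only noted in a comment:

```python
# Assumes export_ls_embedding has been imported from LightSeq's export utilities
# (module path varies by LightSeq version), and that `model` and the state dicts
# are prepared as in the guide above.
max_length = 256  # placeholder: maximum sequence length of the embedding table
emb_dim = 512     # placeholder: hidden size of the exported model

# The embedding shape is now passed explicitly instead of being inferred.
export_ls_embedding(model, encoder_state_dict, max_length, emb_dim, is_encoder=True)
export_ls_embedding(model, encoder_state_dict, max_length, emb_dim, is_encoder=False)
```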

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+#!/bin/bash
+
+SCRIPT=$(realpath "$0")
+CUR_DIR=$(dirname "$SCRIPT")
+
+model_full_name=facebook/bart-base
+model_name=$(echo $model_full_name | cut -d "/" -f 2)
+all_log=$CUR_DIR/${model_name}_bench.log
+res_log=$CUR_DIR/${model_name}_bench.txt
+if [ -f $res_log ]; then
+    rm $res_log
+fi
+if [ -f $all_log ]; then
+    rm $all_log
+fi
+echo "batch_size input_seq_len output_seq_len beam_size latency" >>$res_log
+
+for batch_size in 1 8 32; do
+    for beam_size in 1 4 32; do
+        for input_seq_len in 8 16 32 64; do
+            output_seq_len=$input_seq_len
+            cd $CUR_DIR/python
+
+            python3 generate_model.py --model_name $model_full_name --sampling_method beam_search \
+                --beam_size $beam_size --input_seq_len $input_seq_len --output_seq_len=$output_seq_len
+            model_path=$(realpath lightseq_${model_name}_bench.hdf5)
+
+            cd $CUR_DIR/../../build
+            ./examples/inference/cpp/transformer_example \
+                $model_path $batch_size $input_seq_len |& tee temp.log
+
+            cat temp.log >>$all_log
+            latency=$(tail -n 5 temp.log | head -n 1 | awk '{print $4}')
+            echo "$batch_size $input_seq_len $output_seq_len $beam_size $latency" >>$res_log
+            rm temp.log
+        done
+    done
+done
+pip3 install tabulate
+tabulate --header $res_log
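
The last two lines of the script render the results file with the `tabulate` command-line tool. If you prefer to post-process the benchmark table in Python, a small equivalent is sketched below; the file name assumes the `${model_name}_bench.txt` pattern from the script (here `bart-base_bench.txt`):

```python
from tabulate import tabulate  # installed by the script via `pip3 install tabulate`

# One header row ("batch_size input_seq_len output_seq_len beam_size latency"),
# then one whitespace-separated row per benchmarked configuration.
with open("bart-base_bench.txt") as f:
    header = f.readline().split()
    rows = [line.split() for line in f if line.strip()]

print(tabulate(rows, headers=header, tablefmt="github"))
```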

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+#!/bin/bash
+
+SCRIPT=$(realpath "$0")
+CUR_DIR=$(dirname "$SCRIPT")
+
+model_full_name=gpt2
+model_name=$model_full_name
+all_log=$CUR_DIR/${model_name}_bench.log
+res_log=$CUR_DIR/${model_name}_bench.txt
+if [ -f $res_log ]; then
+    rm $res_log
+fi
+if [ -f $all_log ]; then
+    rm $all_log
+fi
+echo "batch_size input_seq_len output_seq_len topk latency" >>$res_log
+
+for batch_size in 1 8 32; do
+    for topk in 1 4 32; do
+        for input_seq_len in 118 86 22; do
+            output_seq_len=$((150 - $input_seq_len))
+            cd $CUR_DIR/python
+
+            python3 generate_model.py --model_name $model_full_name --sampling_method topk \
+                --topk $topk --input_seq_len $input_seq_len --output_seq_len=$output_seq_len
+            model_path=$(realpath lightseq_${model_name}_bench.hdf5)
+
+            cd $CUR_DIR/../../build
+            ./examples/inference/cpp/gpt_example \
+                $model_path $batch_size $input_seq_len |& tee temp.log
+
+            cat temp.log >>$all_log
+            latency=$(tail -n 3 temp.log | head -n 1 | awk '{print $4}')
+            echo "$batch_size $input_seq_len $output_seq_len $topk $latency" >>$res_log
+            rm temp.log
+        done
+    done
+done
+pip3 install tabulate
+tabulate --header $res_log

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+#!/bin/bash
+
+SCRIPT=$(realpath "$0")
+CUR_DIR=$(dirname "$SCRIPT")
+
+model_full_name=facebook/bart-base
+model_name=$(echo $model_full_name | cut -d "/" -f 2)
+all_log=$CUR_DIR/quant_${model_name}_bench.log
+res_log=$CUR_DIR/quant_${model_name}_bench.txt
+if [ -f $all_log ]; then
+    rm $all_log
+fi
+if [ -f $res_log ]; then
+    rm $res_log
+fi
+echo "batch_size input_seq_len output_seq_len beam_size latency" >>$res_log
+
+for batch_size in 1 8 32; do
+    for beam_size in 1 4 32; do
+        for input_seq_len in 16 32 64; do
+            output_seq_len=$input_seq_len
+            cd $CUR_DIR/python
+
+            python3 generate_model.py --model_name $model_full_name --sampling_method beam_search \
+                --beam_size $beam_size --input_seq_len $input_seq_len --output_seq_len=$output_seq_len
+            model_path=$(realpath lightseq_${model_name}_bench.hdf5)
+
+            cd $CUR_DIR/../../build
+            ./examples/inference/cpp/quant_transformer_example \
+                $model_path $batch_size $input_seq_len |& tee temp.log
+
+            cat temp.log >> $all_log
+            latency=$(tail -n 5 temp.log | head -n 1 | awk '{print $4}')
+            echo "$batch_size $input_seq_len $output_seq_len $beam_size $latency" >>$res_log
+            rm temp.log
+        done
+    done
+done
+
+pip3 install tabulate
+tabulate --header $res_log

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+#!/bin/bash
+
+SCRIPT=$(realpath "$0")
+CUR_DIR=$(dirname "$SCRIPT")
+
+model_full_name=/tmp/quant/test-clm/pytorch_model.bin
+model_name=quant_gpt2
+all_log=$CUR_DIR/${model_name}_bench.log
+res_log=$CUR_DIR/${model_name}_bench.txt
+if [ -f $res_log ]; then
+    rm $res_log
+fi
+if [ -f $all_log ]; then
+    rm $all_log
+fi
+echo "batch_size input_seq_len output_seq_len topk latency" >>$res_log
+
+for batch_size in 1 8 32; do
+    for topk in 1 4 32; do
+        for input_seq_len in 118 86 22; do
+            output_seq_len=$((150 - $input_seq_len))
+            cd $CUR_DIR/python
+
+            python3 generate_model.py --model_name $model_full_name --sampling_method topk \
+                --topk $topk --input_seq_len $input_seq_len --output_seq_len=$output_seq_len --enable_quant true
+            model_path=$(realpath lightseq_${model_name}_bench.hdf5)
+
+            cd $CUR_DIR/../../build
+            ./examples/inference/cpp/quant_gpt_example \
+                $model_path $batch_size $input_seq_len |& tee temp.log
+
+            cat temp.log >>$all_log
+            latency=$(tail -n 3 temp.log | head -n 1 | awk '{print $4}')
+            echo "$batch_size $input_seq_len $output_seq_len $topk $latency" >>$res_log
+            rm temp.log
+        done
+    done
+done
+pip3 install tabulate
+tabulate --header $res_log

examples/inference/cpp/CMakeLists.txt

Lines changed: 6 additions & 0 deletions

@@ -20,3 +20,9 @@ target_link_libraries(quant_gpt_example PUBLIC liblightseq)
 
 add_executable(transformer_decoder_example decoder_example.cc.cu)
 target_link_libraries(transformer_decoder_example PUBLIC transformer_model)
+
+add_executable(vit_example vit_example.cc)
+target_link_libraries(vit_example PUBLIC liblightseq)
+
+add_executable(quant_vit_example quant_vit_example.cc)
+target_link_libraries(quant_vit_example PUBLIC liblightseq)

examples/inference/cpp/gpt_example.cc

Lines changed: 15 additions & 6 deletions

@@ -10,17 +10,16 @@ int main(int argc, char* argv[]) {
   std::string model_weights_path = argv[1];
   std::vector<int> example_input = {40, 1842, 345, 11, 475, 345, 910, 326};
   int eg_seq_len = example_input.size();
-  int max_batch_size = 128;
+
   int batch_size = 1;
   int batch_seq_len = eg_seq_len;
 
   if (argc == 4) {
     batch_size = atoi(argv[2]);
     batch_seq_len = atoi(argv[3]);
   }
-  if (batch_size > max_batch_size) {
-    throw std::runtime_error("batch_size exceeds the maximum (128)!");
-  }
+
+  int max_batch_size = std::max(8, batch_size);
 
   std::vector<int> host_input;
   for (int i = 0; i < batch_size; ++i) {
@@ -39,6 +38,7 @@ int main(int argc, char* argv[]) {
       d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len,
       cudaMemcpyHostToDevice));
 
+  model->benchmark_mode(true);
   model->set_input_ptr(0, d_input);
   model->set_input_shape(0, {batch_size, batch_seq_len});
 
@@ -56,13 +56,22 @@ int main(int argc, char* argv[]) {
   lightseq::cuda::CHECK_GPU_ERROR(cudaStreamSynchronize(0));
   std::cout << "infer preprocessing finished" << std::endl;
 
+  std::chrono::duration<double> elapsed;
+  int iter = 0;
   /* ---step5. infer and log--- */
-  for (int i = 0; i < 10; i++) {
+  for (int i = 0; i < 20; i++) {
     auto start = std::chrono::high_resolution_clock::now();
     model->Infer();
-    lightseq::cuda::print_time_duration(start, "one infer time", 0);
+    auto finish = std::chrono::high_resolution_clock::now();
+    if (i >= 5) {
+      iter++;
+      elapsed += finish - start;
+    }
   }
 
+  std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
+            << " ms" << std::endl;
+
   for (int i = 0; i < model->get_output_size(); i++) {
     const int* d_output;
     d_output = static_cast<const int*>(model->get_output_ptr(i));
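
The timing change above is the same across the example binaries touched by this commit (the two quantized examples below get an identical loop): run 20 iterations, discard the first 5 as warm-up, and report the mean latency of the rest. Here is a small Python sketch of that pattern, handy when measuring through Python bindings rather than the C++ examples; `infer` is a placeholder for whatever callable you are timing:

```python
import time


def benchmark(infer, iters: int = 20, warmup: int = 5) -> float:
    """Mean latency in milliseconds: run `iters` times, skip the first
    `warmup` runs, and average the remainder (mirrors the C++ loop above)."""
    elapsed = 0.0
    timed = 0
    for i in range(iters):
        start = time.perf_counter()
        infer()
        if i >= warmup:
            elapsed += time.perf_counter() - start
            timed += 1
    return elapsed * 1000.0 / timed


if __name__ == "__main__":
    # Placeholder workload standing in for model->Infer().
    print(f"latency: {benchmark(lambda: sum(range(100_000))):.3f} ms")
```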

examples/inference/cpp/quant_gpt_example.cc

Lines changed: 15 additions & 3 deletions

@@ -10,14 +10,16 @@ int main(int argc, char* argv[]) {
   std::string model_weights_path = argv[1];
   std::vector<int> example_input = {40, 1842, 345, 11, 475, 345, 910, 326};
   int eg_seq_len = example_input.size();
-  int max_batch_size = 128;
   int batch_size = 1;
   int batch_seq_len = eg_seq_len;
 
   if (argc == 4) {
     batch_size = atoi(argv[2]);
     batch_seq_len = atoi(argv[3]);
   }
+
+  int max_batch_size = std::max(4, batch_size);
+
   if (batch_size > max_batch_size) {
     throw std::runtime_error("batch_size exceeds the maximum (128)!");
   }
@@ -39,6 +41,7 @@ int main(int argc, char* argv[]) {
       d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len,
       cudaMemcpyHostToDevice));
 
+  model->benchmark_mode(true);
   model->set_input_ptr(0, d_input);
   model->set_input_shape(0, {batch_size, batch_seq_len});
 
@@ -56,13 +59,22 @@ int main(int argc, char* argv[]) {
   lightseq::cuda::CHECK_GPU_ERROR(cudaStreamSynchronize(0));
   std::cout << "infer preprocessing finished" << std::endl;
 
+  std::chrono::duration<double> elapsed;
+  int iter = 0;
   /* ---step5. infer and log--- */
-  for (int i = 0; i < 10; i++) {
+  for (int i = 0; i < 20; i++) {
     auto start = std::chrono::high_resolution_clock::now();
     model->Infer();
-    lightseq::cuda::print_time_duration(start, "one infer time", 0);
+    auto finish = std::chrono::high_resolution_clock::now();
+    if (i >= 5) {
+      iter++;
+      elapsed += finish - start;
+    }
  }
 
+  std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
+            << " ms" << std::endl;
+
   for (int i = 0; i < model->get_output_size(); i++) {
     const int* d_output;
     d_output = static_cast<const int*>(model->get_output_ptr(i));

examples/inference/cpp/quant_transformer_example.cc

Lines changed: 14 additions & 7 deletions

@@ -12,17 +12,14 @@ int main(int argc, char* argv[]) {
   std::vector<int> example_input = {63, 47, 65, 1507, 88, 74,
                                     10, 2057, 362, 9, 284, 6};
   int eg_seq_len = example_input.size();
-  int max_batch_size = 128;
   int batch_size = 1;
   int batch_seq_len = eg_seq_len;
 
   if (argc == 4) {
     batch_size = atoi(argv[2]);
     batch_seq_len = atoi(argv[3]);
   }
-  if (batch_size > max_batch_size) {
-    throw std::runtime_error("batch_size exceeds the maximum (128)!");
-  }
+  int max_batch_size = std::max(4, batch_size);
 
   std::vector<int> host_input;
   for (int i = 0; i < batch_size; ++i) {
@@ -41,6 +38,7 @@ int main(int argc, char* argv[]) {
      d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len,
      cudaMemcpyHostToDevice));
 
+  model->benchmark_mode(true);
   model->set_input_ptr(0, d_input);
   model->set_input_shape(0, {batch_size, batch_seq_len});
 
@@ -58,13 +56,22 @@ int main(int argc, char* argv[]) {
   lightseq::cuda::CHECK_GPU_ERROR(cudaStreamSynchronize(0));
   std::cout << "infer preprocessing finished" << std::endl;
 
+  std::chrono::duration<double> elapsed;
+  int iter = 0;
   /* ---step5. infer and log--- */
   for (int i = 0; i < 20; i++) {
     auto start = std::chrono::high_resolution_clock::now();
     model->Infer();
-    lightseq::cuda::print_time_duration(start, "one infer time", 0);
+    auto finish = std::chrono::high_resolution_clock::now();
+    if (i >= 5) {
+      iter++;
+      elapsed += finish - start;
+    }
   }
 
+  std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
+            << " ms" << std::endl;
+
   for (int i = 0; i < model->get_output_size(); i++) {
     const void* d_output;
     d_output = static_cast<const float*>(model->get_output_ptr(i));
@@ -76,9 +83,9 @@ int main(int argc, char* argv[]) {
     std::cout << std::endl;
 
     if (!i)
-      lightseq::cuda::print_vec((int*)d_output, "output", 15);
+      lightseq::cuda::print_vec((int*)d_output, "output", batch_size);
     else
-      lightseq::cuda::print_vec((float*)d_output, "output", 5);
+      lightseq::cuda::print_vec((float*)d_output, "output", batch_size);
   }
 
   // const int* res = model.get_result_ptr();
