-
Notifications
You must be signed in to change notification settings - Fork 331
/
Copy pathls_quant_gpt2.py
252 lines (204 loc) · 8.43 KB
/
ls_quant_gpt2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import time
import torch
from torch import nn
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
import lightseq.inference as lsi
from lightseq.training.ops.pytorch.quantization import (
qat_mode,
QuantLinear,
TensorQuantizer,
weight_quant_config,
emb_quant_config,
)
from lightseq.training.ops.pytorch.torch_transformer_layers import (
TransformerDecoderLayer,
)
from export.util import parse_args
def ls_gpt2(model, inputs, generation_method="topk"):
    """Run one LightSeq GPT-2 call and measure how long it takes.

    Args:
        model: Object exposing ``sample(inputs)`` and ``ppl(inputs)``.
        inputs: Token ids, handed to the model unchanged.
        generation_method: ``"topk"`` or ``"topp"`` calls ``model.sample``;
            ``"ppl"`` takes element 0 of what ``model.ppl`` returns. Any
            other value leaves the result as ``None``.

    Returns:
        A ``(results, elapsed_seconds)`` tuple.
    """
    # Let previously queued CUDA work finish so it is not timed here.
    torch.cuda.synchronize()
    tic = time.perf_counter()

    if generation_method in ("topk", "topp"):
        results = model.sample(inputs)
    elif generation_method == "ppl":
        results = model.ppl(inputs)[0]
    else:
        results = None

    # Wait for the work launched above before stopping the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - tic
    return results, elapsed
def compute_hf_ppl(model, inputs):
    """Score ``inputs`` with a HuggingFace-style LM in windows of 512 tokens.

    Each window is moved to ``cuda:0`` and passed to ``model`` together with
    a copy of itself as ``labels``. Element 0 of the model output is
    multiplied by the number of target tokens in the window and accumulated.

    Args:
        model: Callable as ``model(input_ids, labels=...)``; element 0 of its
            return value is used (presumably the loss -- confirm for
            non-HuggingFace models).
        inputs: 2-D tensor of token ids; axis 1 is the one being windowed.

    Returns:
        NumPy value of the accumulated weighted outputs divided by the end
        index of the last window. Note that no exponential is applied here.
    """
    window = 512  # largest slice fed to the model at once
    step = 512  # distance between consecutive window starts
    seq_len = inputs.size(1)

    last_end = 0
    weighted = []
    for start in range(0, seq_len, step):
        ctx_begin = max(start + step - window, 0)
        last_end = min(start + step, seq_len)
        n_targets = last_end - start

        chunk = inputs[:, ctx_begin:last_end].to("cuda:0")
        labels = chunk.clone()
        # Everything before the last ``n_targets`` positions gets label -100
        # (presumably the loss ignore index -- confirm against the model).
        labels[:, :-n_targets] = -100

        with torch.no_grad():
            model_out = model(chunk, labels=labels)
            weighted.append(model_out[0] * n_targets)

    score = torch.stack(weighted).sum() / last_end
    return score.cpu().numpy()
def hf_gpt2(model, inputs, tokenizer, generation_method="topk"):
    """Run one HuggingFace GPT-2 call on ``cuda:0`` and time it.

    Args:
        model: Object exposing ``generate``; for ``"ppl"`` it is forwarded to
            ``compute_hf_ppl``.
        inputs: Tensor of token ids; moved to ``cuda:0`` before timing starts.
        tokenizer: Supplies ``eos_token_id``, used as the pad token id during
            generation.
        generation_method: ``"topk"`` or ``"topp"`` calls ``model.generate``
            with ``max_length=50``; ``"ppl"`` calls ``compute_hf_ppl``. Any
            other value leaves the result as ``None``.

    Returns:
        A ``(results, elapsed_seconds)`` tuple.
    """
    device_inputs = inputs.to("cuda:0")

    # Synchronise around the call so only this call's GPU work is measured.
    torch.cuda.synchronize()
    tic = time.perf_counter()

    if generation_method == "ppl":
        results = compute_hf_ppl(model, device_inputs)
    elif generation_method in ("topk", "topp"):
        results = model.generate(
            device_inputs,
            max_length=50,
            pad_token_id=tokenizer.eos_token_id,
        )
    else:
        results = None

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - tic
    return results, elapsed
def ls_generate(model, tokenizer, inputs):
    """Generate with the LightSeq model, then print the timing and the text.

    Args:
        model: LightSeq model passed on to ``ls_gpt2``.
        tokenizer: Provides ``batch_decode`` for turning ids into strings.
        inputs: Token ids passed on to ``ls_gpt2``.
    """
    print("=========lightseq=========")
    print("lightseq generating...")

    # Default generation method of ls_gpt2 is used here.
    token_ids, elapsed = ls_gpt2(model, inputs)
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    print(f"lightseq time: {elapsed}s")
    print("lightseq results:")
    for text in decoded:
        print(text)
def hf_generate(model, tokenizer, inputs):
    """Generate with the HuggingFace model, then print the timing and the text.

    Args:
        model: HuggingFace model passed on to ``hf_gpt2``.
        tokenizer: Passed on to ``hf_gpt2`` and used for ``batch_decode``.
        inputs: Token ids passed on to ``hf_gpt2``.
    """
    print("=========huggingface=========")
    print("huggingface generating...")

    # Default generation method of hf_gpt2 is used here.
    token_ids, elapsed = hf_gpt2(model, inputs, tokenizer)
    decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    print(f"huggingface time: {elapsed}s")
    print("huggingface results:")
    for text in decoded:
        print(text)
def ls_ppl(model, tokenizer, inputs):
    """Compute the LightSeq "ppl" result and print it with its timing.

    Args:
        model: LightSeq model passed on to ``ls_gpt2``.
        tokenizer: Not used by this function.
        inputs: Token ids passed on to ``ls_gpt2``.
    """
    print("=========lightseq=========")
    print("lightseq calculating ppl...")

    ppl_value, elapsed = ls_gpt2(model, inputs, generation_method="ppl")

    print(f"lightseq time: {elapsed}s")
    print("lightseq results:")
    print(ppl_value)
def hf_ppl(model, tokenizer, inputs):
    """Compute the HuggingFace "ppl" result and print it with its timing.

    Args:
        model: HuggingFace model passed on to ``hf_gpt2``.
        tokenizer: Passed on to ``hf_gpt2``.
        inputs: Token ids passed on to ``hf_gpt2``.
    """
    print("=========huggingface=========")
    print("huggingface calculating ppl...")

    ppl_value, elapsed = hf_gpt2(
        model, inputs, tokenizer, generation_method="ppl"
    )

    print(f"huggingface time: {elapsed}s")
    print("huggingface results:")
    print(ppl_value)
def warmup(
    ls_tokenizer, hf_tokenizer, ls_model, hf_model, sentences, generation_method
):
    """Tokenize ``sentences`` with both tokenizers and run one warm-up pass.

    Args:
        ls_tokenizer: Tokenizer used for the LightSeq inputs.
        hf_tokenizer: Tokenizer used for the HuggingFace inputs.
        ls_model: LightSeq model.
        hf_model: HuggingFace model.
        sentences: Batch of input strings.
        generation_method: ``"topk"``/``"topp"`` runs LightSeq generation
            only; ``"ppl"`` runs both ppl paths; anything else runs nothing
            after tokenization.
    """

    def _encode(tokenizer):
        # Padded PyTorch batch; only the id tensor is kept.
        return tokenizer(sentences, return_tensors="pt", padding=True)["input_ids"]

    ls_inputs = _encode(ls_tokenizer)
    hf_inputs = _encode(hf_tokenizer)

    if generation_method == "ppl":
        ls_ppl(ls_model, ls_tokenizer, ls_inputs)
        hf_ppl(hf_model, hf_tokenizer, hf_inputs)
    elif generation_method in ("topk", "topp"):
        ls_generate(ls_model, ls_tokenizer, ls_inputs)
        # The matching HuggingFace generation call is disabled in this script:
        # hf_generate(hf_model, hf_tokenizer, hf_inputs)
class GptEmbedding(nn.Embedding):
    """``nn.Embedding`` whose lookup result is passed through a quantizer."""

    def __init__(self, *args, **kwargs):
        """Build the embedding (arguments go to ``nn.Embedding``) and its quantizer."""
        super().__init__(*args, **kwargs)
        # Quantizer built from the imported ``emb_quant_config``.
        self.emb_quant = TensorQuantizer(emb_quant_config)

    def forward(self, input_ids):
        """Look up ``input_ids`` and return the quantizer's output for them."""
        embedded = super().forward(input_ids)
        return self.emb_quant(embedded)
def gen_gpt_enc_config(config):
    """Build a ``TransformerDecoderLayer`` config from a GPT-2 style config.

    Args:
        config: Object providing ``max_position_embeddings``, ``hidden_size``,
            ``num_attention_heads``, ``attn_pdrop``, ``resid_pdrop`` and
            ``num_hidden_layers``.

    Returns:
        Whatever ``TransformerDecoderLayer.get_config`` returns for these
        settings.
    """
    layer_kwargs = {
        # Fixed settings used by this script.
        "max_batch_tokens": 8192,
        # Sizes taken from the model config; the FFN is 4x the hidden size.
        "max_seq_len": config.max_position_embeddings,
        "hidden_size": config.hidden_size,
        "intermediate_size": 4 * config.hidden_size,
        "nhead": config.num_attention_heads,
        # ``resid_pdrop`` feeds both the activation and hidden dropout.
        "attn_prob_dropout_ratio": config.attn_pdrop,
        "activation_dropout_ratio": config.resid_pdrop,
        "hidden_dropout_ratio": config.resid_pdrop,
        "pre_layer_norm": True,
        "fp16": True,
        "local_rank": 0,
        "nlayer": config.num_hidden_layers,
        "activation_fn": "gelu",
        "has_cross_attn": False,
    }
    return TransformerDecoderLayer.get_config(**layer_kwargs)
class LSHFGptEncoderLayer(TransformerDecoderLayer):
    """``TransformerDecoderLayer`` callable with a HuggingFace-style block signature."""

    def __init__(self, *args, **kwargs):
        """Forward every argument to ``TransformerDecoderLayer``."""
        super().__init__(*args, **kwargs)

    def forward(self, hidden_states, attention_mask=None, *args, **kwargs):
        """Run the parent layer on ``hidden_states`` with a reshaped mask.

        Extra positional and keyword arguments are accepted and ignored.
        """
        if attention_mask is None:
            # NOTE(review): created with the default device and dtype, not
            # those of ``hidden_states`` -- confirm the parent layer copes.
            ls_attention_mask = torch.zeros(hidden_states.size()[:2])
        else:
            # NOTE(review): squeeze() without a dim drops every size-1 axis,
            # which includes the batch axis when the batch size is 1.
            ls_attention_mask = attention_mask.squeeze()
        return super().forward(hidden_states, ls_attention_mask)
def inject_ls_layer(model, config):
    """Swap the embedding, transformer blocks and LM head for quantized ones.

    ``model`` is modified in place.

    Args:
        model: Model with ``transformer.wte``, ``transformer.h`` and
            ``lm_head`` attributes.
        config: Config providing ``vocab_size``, ``hidden_size``, ``n_embd``
            and ``num_hidden_layers``.
    """
    transformer = model.transformer

    # Token embedding with a quantizer on its output.
    quant_wte = GptEmbedding(config.vocab_size, config.hidden_size)
    transformer.wte = quant_wte
    quant_wte.apply(qat_mode)

    # Replace every block; each layer gets its own freshly built config.
    for layer_idx in range(config.num_hidden_layers):
        layer = LSHFGptEncoderLayer(gen_gpt_enc_config(config)).cuda()
        transformer.h[layer_idx] = layer
        layer.apply(qat_mode)

    # The LM head shares the embedding's weight and its quantizer.
    lm_head = QuantLinear(config.n_embd, config.vocab_size, bias=False)
    lm_head.weight = quant_wte.weight
    lm_head.weight_quant = quant_wte.emb_quant
    model.lm_head = lm_head
def main():
    """Compare LightSeq quantized GPT-2 inference with the HuggingFace model.

    Reads the command-line arguments, builds both tokenizers, builds the
    HuggingFace model with the quantized layers injected and loads its
    checkpoint, builds the LightSeq model, runs one warm-up pass and then
    one measured pass of the selected method.
    """
    import os  # local import: only needed to derive the checkpoint path

    args = parse_args()
    # Unknown methods fall back to top-k sampling.
    if args.generation_method not in ["topk", "topp", "ppl"]:
        args.generation_method = "topk"

    # The checkpoint is the model path with its extension replaced by
    # ``.bin``. splitext only strips a real extension of the final path
    # component, so an extension-less path or a dot in a directory name no
    # longer truncates the stem (previously "model" produced ".bin").
    model_name = os.path.splitext(args.model)[0]
    ckpt_path = f"{model_name}.bin"

    print("initializing gpt2 config...")
    config = GPT2Config.from_pretrained("gpt2")

    print("initializing gpt2 tokenizer...")
    ls_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # LightSeq uses len(tokenizer) as the pad token id by default, so add a
    # dedicated pad token at the end of the vocabulary.
    ls_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    print(f"lightseq tokenizer pad token id: {ls_tokenizer.pad_token_id}")

    hf_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # Use EOS as PAD for HuggingFace: this avoids the warning described in
    # https://huggingface.co/blog/how-to-generate and avoids having to
    # resize the model embedding.
    hf_tokenizer.pad_token = hf_tokenizer.eos_token
    print(f"huggingface tokenizer pad token id: {hf_tokenizer.pad_token_id}")

    print("creating huggingface model...")
    hf_model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
    inject_ls_layer(hf_model, config)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # strict=False: keys that do not match the model are not an error.
    hf_model.load_state_dict(state_dict, strict=False)
    hf_model.to("cuda:0")
    hf_model.eval()

    print("creating lightseq model...")
    ls_model = lsi.QuantGpt(args.model, max_batch_size=16)

    # LightSeq GPT perplexity supports batched inference over inputs of
    # different lengths, but sampling does not, so every sentence is the same.
    sentences = [
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
    ]

    print("====================START warmup====================")
    warmup(
        ls_tokenizer,
        hf_tokenizer,
        ls_model,
        hf_model,
        sentences,
        args.generation_method,
    )
    print("====================END warmup====================")

    print("tokenizing the sentences...")
    ls_inputs = ls_tokenizer(sentences, return_tensors="pt", padding=True)["input_ids"]
    hf_inputs = hf_tokenizer(sentences, return_tensors="pt", padding=True)["input_ids"]

    if args.generation_method == "topk" or args.generation_method == "topp":
        ls_generate(ls_model, ls_tokenizer, ls_inputs)
        # The matching HuggingFace generation call is disabled in this script:
        # hf_generate(hf_model, hf_tokenizer, hf_inputs)
    elif args.generation_method == "ppl":
        ls_ppl(ls_model, ls_tokenizer, ls_inputs)
        hf_ppl(hf_model, hf_tokenizer, hf_inputs)
# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()