Commit baa0a99

Add files via upload

1 parent 44d2bfc commit baa0a99

File tree

1 file changed: +292 -0 lines changed

train.py

+292
@@ -0,0 +1,292 @@
import argparse
import logging
import math

import gluonnlp as nlp
import mxnet as mx
import pandas as pd
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.mxnet_kogpt2 import get_mxnet_kogpt2_model
from kogpt2.utils import get_tokenizer
from mxnet import gluon, nd
from mxnet.gluon import nn

parser = argparse.ArgumentParser(description='Simsimi based on KoGPT-2')

parser.add_argument('--num-epoch',
                    type=int,
                    default=1,
                    help='number of iterations to train (default: 1)')

parser.add_argument('--max-seq-len',
                    type=int,
                    default=32,
                    help='max sentence length on input (default: 32)')

parser.add_argument('--batch-size',
                    type=int,
                    default=64,
                    help='batch size for training (default: 64)')

parser.add_argument('--chat',
                    action='store_true',
                    default=False,
                    help='response generation on given user input')

parser.add_argument('--sentiment',
                    type=str,
                    default='0',
                    help='sentiment for system. 0 is neutral, 1 is negative, 2 is positive.')

parser.add_argument('--model_params',
                    type=str,
                    default='kogpt2_chat.params',
                    help='model binary for starting chat')

parser.add_argument('--train',
                    action='store_true',
                    default=False,
                    help='train the model on the chat data (default: False)')

parser.add_argument(
    '--accumulate',
    type=int,
    default=1,
    help='accumulate gradient to achieve the same result with a large batch size')

opt = parser.parse_args()
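
# Example invocations (illustrative; assumes the Chatbot_data CSV and the
# KoGPT-2 weights referenced below are available locally):
#   python train.py --train --batch-size 64 --num-epoch 2
#   python train.py --chat --sentiment 0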

logger = logging.getLogger()
logger.setLevel(logging.INFO)

U_TKN = '<usr>'
S_TKN = '<sys>'
BOS = '<s>'
EOS = '</s>'
MASK = '<unused0>'
SENT = '<unused1>'


class ChatDataset(gluon.data.Dataset):
    def __init__(self, chats, tok_path, vocab, max_len=32):
        self._data = chats
        self._tok_path = tok_path
        self.tokenizer = None
        self.first = True
        self.q_token = U_TKN
        self.a_token = S_TKN
        self.sent_token = SENT
        self.bos = BOS
        self.eos = EOS
        self.maskt = MASK
        self.vocab = vocab
        self.max_len = max_len
        self.padder = nlp.data.PadSequence(
            max_len, pad_val=self.vocab[self.vocab.padding_token])

    def _activate_sp(self):
        # build the SentencePiece tokenizer lazily, so each DataLoader
        # worker constructs its own copy
        self.tokenizer = nlp.data.SentencepieceTokenizer(self._tok_path, 0, 0)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        if self.tokenizer is None:
            self._activate_sp()
        turn = self._data.iloc[idx]
        q = turn['Q']
        a = turn['A']
        sentiment = str(turn['label'])
        q_toked = [
            self.q_token,
        ] + self.tokenizer(q) + [
            self.eos,
        ] + [self.sent_token] + self.tokenizer(sentiment) + [
            self.eos,
        ]
        q_len = len(q_toked)
        a_toked = [
            self.a_token,
        ] + self.tokenizer(a) + [
            self.eos,
        ]
        a_len = len(a_toked)
        if q_len + a_len > self.max_len:
            a_len = self.max_len - q_len
            if a_len <= 0:
                # the question alone is too long: keep only its tail
                q_toked = q_toked[-(int(self.max_len / 2)):]
                q_len = len(q_toked)
                a_len = self.max_len - q_len
                assert a_len > 0
            a_toked = a_toked[:a_len]
            a_len = len(a_toked)
            assert a_len == len(a_toked), f'{a_len} ==? {len(a_toked)}'
        # [<mask>, <mask>, ..., <mask>, ..., A.., <eos>, <pad>...]
        labels = [
            self.maskt,
        ] * q_len + a_toked[1:]
        if self.first:
            logging.info("contexts : {}".format(q))
            logging.info("toked ctx: {}".format(q_toked))
            logging.info("response : {}".format(a))
            logging.info("toked response : {}".format(a_toked))
            logging.info('labels {}'.format(labels))
            self.first = False
        mask = [0] * q_len + [1] * a_len + [0] * (self.max_len - q_len - a_len)
        return (self.padder(self.vocab[q_toked + a_toked]), nd.array(mask),
                self.padder(self.vocab[labels]))

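# Layout sketch for one encoded example (token strings are illustrative, not
# real SentencePiece pieces). With max_len=16, Q='hi', label '0', A='hello':
#   tokens = ['<usr>', 'hi', '</s>', '<unused1>', '0', '</s>',
#             '<sys>', 'hello', '</s>', '<pad>', ...]          # padded to 16
#   labels = ['<unused0>'] * 6 + ['hello', '</s>', '<pad>', ...]
#   mask   = [0]*6 + [1]*3 + [0]*7
# Only positions with mask == 1 (the answer) contribute to the training loss.

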
class KoGPT2Chat(nn.HybridBlock):
    def __init__(self, kogpt2, prefix=None, params=None):
        super(KoGPT2Chat, self).__init__(prefix=prefix, params=params)
        self.kogpt2 = kogpt2

    def hybrid_forward(self, F, inputs):
        # (batch, seq_len, vocab_size) logits; the model's state output is discarded
        output, _ = self.kogpt2(inputs)
        return output

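# Minimal shape sketch (assumes the KoGPT-2 LM head returns vocab logits, as
# the argmax decoding in chat() below relies on; illustrative only):
#   logits = KoGPT2Chat(model)(nd.ones((1, 8)))  # -> (1, 8, vocab_size)

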
if mx.context.num_gpus() > 0:
    ctx = mx.gpu()
else:
    ctx = mx.cpu()


def train():
    tok_path = get_tokenizer()
    model, vocab = get_mxnet_kogpt2_model(ctx=ctx)
    # tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)

    data = pd.read_csv("Chatbot_data/ChatbotData.csv")

    max_len = opt.max_seq_len
    train_set = ChatDataset(data, tok_path, vocab, max_len=max_len)
    batch_size = opt.batch_size

    train_dataloader = mx.gluon.data.DataLoader(train_set,
                                                batch_size=batch_size,
                                                num_workers=5,
                                                shuffle=True)
    kogptqa = KoGPT2Chat(model)
    kogptqa.hybridize()

    # softmax cross entropy loss for classification
    loss_function = gluon.loss.SoftmaxCrossEntropyLoss()
    loss_function.hybridize()

    num_epochs = opt.num_epoch
    lr = 5e-5
    trainer = gluon.Trainer(kogptqa.collect_params(), 'bertadam', {
        'learning_rate': lr,
        'epsilon': 1e-8,
        'wd': 0.01
    })
    # do not apply weight decay to LayerNorm and bias parameters
    for _, v in kogptqa.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [
        p for p in kogptqa.collect_params().values() if p.grad_req != 'null'
    ]
    # learning rate warmup
    accumulate = opt.accumulate
    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_examples = len(train_set)
    num_train_steps = int(num_train_examples / step_size * num_epochs)
    warmup_ratio = 0.1
    num_warmup_steps = int(num_train_steps * warmup_ratio)
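    # Worked example (numbers illustrative): with num_train_steps=1000 and
    # warmup_ratio=0.1, num_warmup_steps=100, so at step 50 the learning rate
    # is 5e-5 * 50 / 100 = 2.5e-5; after step 100 it decays linearly toward 0
    # at step 1000 (see the per-step update in the epoch loop below).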
    step_num = 0
    all_model_params = kogptqa.collect_params()

    log_interval = 200
    neg = -1e18
    # Set grad_req if gradient accumulation is required
    if accumulate and accumulate > 1:
        for p in params:
            p.grad_req = 'add'
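
    # With, e.g., --batch-size 64 --accumulate 4 (values illustrative), the
    # 'add' grad_req sums gradients over 4 consecutive batches and
    # trainer.update(4) rescales them by 1/4, approximating a single step
    # with an effective batch size of 256.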

    for epoch_id in range(num_epochs):
        step_loss = 0
        for batch_id, (token_ids, mask, label) in enumerate(train_dataloader):
            if step_num < num_warmup_steps:
                new_lr = lr * step_num / num_warmup_steps
            else:
                non_warmup_steps = step_num - num_warmup_steps
                offset = non_warmup_steps / (num_train_steps -
                                             num_warmup_steps)
                new_lr = lr - offset * lr
            trainer.set_learning_rate(new_lr)
            with mx.autograd.record():
                # move data to the GPU or CPU context
                token_ids = token_ids.as_in_context(ctx)
                mask = mask.as_in_context(ctx)
                label = label.as_in_context(ctx)
                # forward computation
                out = kogptqa(token_ids)
                # where mask == 0, every vocab logit becomes -1e18; nd.where
                # routes the gradient at those positions into the constant
                # tensor, so only answer tokens train the model
                masked_out = nd.where(
                    mask.expand_dims(axis=2).repeat(repeats=out.shape[2],
                                                    axis=2), out,
                    neg * nd.ones_like(out))
                # loss for responses excluding MASK and PAD
                ls = loss_function(masked_out, label).sum() / mask.sum()
            # backward computation
            ls.backward()
            if not accumulate or (batch_id + 1) % accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, 1)
                trainer.update(accumulate if accumulate else 1)
                step_num += 1
                if accumulate and accumulate > 1:
                    # set grad to zero for gradient accumulation
                    all_model_params.zero_grad()
            step_loss += ls.asscalar()
            if step_num % log_interval == 0 and step_num > 0:
                # the loss is natural-log cross entropy, so perplexity = exp(loss)
                print(
                    '[Step {} Epoch {} Batch {}/{}] loss={:.4f}, lr={:.10f}, perplexity={:.3f}'
                    .format(step_num, epoch_id + 1, batch_id + 1,
                            len(train_dataloader), step_loss / log_interval,
                            trainer.learning_rate,
                            math.exp(step_loss / log_interval)))
                step_loss = 0
    logging.info('saving model file to {}'.format(opt.model_params))
    kogptqa.save_parameters(opt.model_params)


def chat(model_params, sent='0'):
    tok_path = get_tokenizer()
    model, vocab = get_mxnet_kogpt2_model(ctx=ctx)
    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
    kogptqa = KoGPT2Chat(model)
    kogptqa.load_parameters(model_params, ctx=ctx)
    sent_tokens = tok(sent)
    while True:
        q = input('user > ').strip()
        if q == 'quit':
            break
        q_tok = tok(q)
        a = ''
        a_tok = []
        while True:
            input_ids = mx.nd.array([vocab[U_TKN]] + vocab[q_tok] +
                                    vocab[EOS, SENT] + vocab[sent_tokens] +
                                    vocab[EOS, S_TKN] +
                                    vocab[a_tok]).expand_dims(axis=0)
            pred = kogptqa(input_ids.as_in_context(ctx))
            # greedy decoding: append the argmax token at the last position
            gen = vocab.to_tokens(
                mx.nd.argmax(
                    pred,
                    axis=-1).squeeze().astype('int').asnumpy().tolist())[-1]
            if gen == EOS:
                break
            # '▁' is the SentencePiece word-boundary marker
            a += gen.replace('▁', ' ')
            a_tok = tok(a)
        print("Simsimi > {}".format(a.strip()))


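# Note: the inner generation loop above stops only when EOS is produced. A
# defensive variant (illustrative, not part of the original) would also cap
# the answer length, e.g.:
#   for _ in range(opt.max_seq_len):
#       ...
#       if gen == EOS:
#           break

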
if __name__ == "__main__":
    if opt.train:
        train()
    if opt.chat:
        chat(opt.model_params, opt.sentiment)
