import argparse
import logging
import math

import gluonnlp as nlp
import mxnet as mx
import pandas as pd
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.mxnet_kogpt2 import get_mxnet_kogpt2_model
from kogpt2.utils import get_tokenizer
from mxnet import gluon, nd
from mxnet.gluon import nn
parser = argparse.ArgumentParser(description='Simsimi based on KoGPT-2')

parser.add_argument('--num-epoch',
                    type=int,
                    default=1,
                    help='number of epochs to train (default: 1)')

parser.add_argument('--max-seq-len',
                    type=int,
                    default=32,
                    help='max sentence length on input (default: 32)')

parser.add_argument('--batch-size',
                    type=int,
                    default=64,
                    help='batch size for training (default: 64)')

parser.add_argument('--chat',
                    action='store_true',
                    default=False,
                    help='response generation on given user input')

parser.add_argument('--sentiment',
                    type=str,
                    default='0',
                    help='sentiment for system. 0 is neutral, 1 is negative, 2 is positive.')

parser.add_argument('--model_params',
                    type=str,
                    default='kogpt2_chat.params',
                    help='model binary for starting chat')

parser.add_argument('--train',
                    action='store_true',
                    default=False,
                    help='train the model on the chatbot dataset (default: False)')

parser.add_argument(
    '--accumulate',
    type=int,
    default=1,
    help='accumulate gradient to achieve the same result with a large batch size')

opt = parser.parse_args()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Special tokens: <usr>/<sys> mark the speaker turns, <unused0> masks
# context positions in the labels, and <unused1> introduces the sentiment tag.
U_TKN = '<usr>'
S_TKN = '<sys>'
BOS = '<s>'
EOS = '</s>'
MASK = '<unused0>'
SENT = '<unused1>'

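
# Each chatbot turn is packed into a single sequence
#   <usr> Q </s> <unused1> sentiment </s> <sys> A </s>
# and a sample is (padded token ids, loss mask, padded labels): the mask is 1
# only over response positions, and the labels replace the context with
# <unused0> so only the response is learned.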
class ChatDataset(gluon.data.Dataset):
    def __init__(self, chats, tok_path, vocab, max_len=32):
        self._data = chats
        self._tok_path = tok_path
        self.tokenizer = None
        self.first = True
        self.q_token = U_TKN
        self.a_token = S_TKN
        self.sent_token = SENT
        self.bos = BOS
        self.eos = EOS
        self.maskt = MASK
        self.vocab = vocab
        self.max_len = max_len
        self.padder = nlp.data.PadSequence(
            max_len, pad_val=self.vocab[self.vocab.padding_token])

    def _activate_sp(self):
        # Instantiate the tokenizer lazily so each DataLoader worker process
        # builds its own SentencePiece instance.
        self.tokenizer = nlp.data.SentencepieceTokenizer(self._tok_path, 0, 0)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        if self.tokenizer is None:
            self._activate_sp()
        turn = self._data.iloc[idx]
        q = turn['Q']
        a = turn['A']
        sentiment = str(turn['label'])
        q_toked = ([self.q_token] + self.tokenizer(q) + [self.eos] +
                   [self.sent_token] + self.tokenizer(sentiment) +
                   [self.eos])
        q_len = len(q_toked)
        a_toked = [self.a_token] + self.tokenizer(a) + [self.eos]
        a_len = len(a_toked)
        if q_len + a_len > self.max_len:
            # Truncate: first shorten the response; if the context alone
            # already exceeds the budget, keep only its last max_len/2 tokens.
            a_len = self.max_len - q_len
            if a_len <= 0:
                q_toked = q_toked[-(int(self.max_len / 2)):]
                q_len = len(q_toked)
                a_len = self.max_len - q_len
                assert a_len > 0
            a_toked = a_toked[:a_len]
            a_len = len(a_toked)
            assert a_len == len(a_toked), f'{a_len} ==? {len(a_toked)}'
        # labels: [<mask>, <mask>, ..., <mask>, A..., <eos>, <pad>, ...]
        labels = [self.maskt] * q_len + a_toked[1:]
        if self.first:
            logging.info("contexts : {}".format(q))
            logging.info("toked ctx: {}".format(q_toked))
            logging.info("response : {}".format(a))
            logging.info("toked response : {}".format(a_toked))
            logging.info('labels {}'.format(labels))
            self.first = False
        mask = [0] * q_len + [1] * a_len + [0] * (self.max_len - q_len - a_len)
        return (self.padder(self.vocab[q_toked + a_toked]), nd.array(mask),
                self.padder(self.vocab[labels]))

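
# Thin HybridBlock wrapper around KoGPT2: returns only the logits
# (batch, seq_len, vocab_size) and discards the model's second output.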
class KoGPT2Chat(nn.HybridBlock):
    def __init__(self, kogpt2, prefix=None, params=None):
        super(KoGPT2Chat, self).__init__(prefix=prefix, params=params)
        self.kogpt2 = kogpt2

    def hybrid_forward(self, F, inputs):
        # (batch, seq_len, vocab_size) logits
        output, _ = self.kogpt2(inputs)
        return output

# Use a GPU when available, otherwise fall back to CPU.
if mx.context.num_gpus() > 0:
    ctx = mx.gpu()
else:
    ctx = mx.cpu()

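
# Fine-tune KoGPT2 as a dialogue language model: BERTAdam with weight decay,
# 10% linear LR warmup followed by linear decay, optional gradient
# accumulation, global-norm clipping, and a loss restricted to response
# tokens via the mask produced by ChatDataset.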
def train():
    tok_path = get_tokenizer()
    model, vocab = get_mxnet_kogpt2_model(ctx=ctx)
    # tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)

    data = pd.read_csv("Chatbot_data/ChatbotData.csv")

    max_len = opt.max_seq_len
    train_set = ChatDataset(data, tok_path, vocab, max_len=max_len)
    batch_size = opt.batch_size

    train_dataloader = mx.gluon.data.DataLoader(train_set,
                                                batch_size=batch_size,
                                                num_workers=5,
                                                shuffle=True)
    kogptqa = KoGPT2Chat(model)
    kogptqa.hybridize()

    # per-token softmax cross entropy over the vocabulary
    loss_function = gluon.loss.SoftmaxCrossEntropyLoss()
    loss_function.hybridize()

    num_epochs = opt.num_epoch
    lr = 5e-5
    trainer = gluon.Trainer(kogptqa.collect_params(), 'bertadam', {
        'learning_rate': lr,
        'epsilon': 1e-8,
        'wd': 0.01
    })
    # Do not apply weight decay to LayerNorm and bias parameters.
    for _, v in kogptqa.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [
        p for p in kogptqa.collect_params().values() if p.grad_req != 'null'
    ]
    # learning rate warmup setup
    accumulate = opt.accumulate
    step_size = batch_size * accumulate if accumulate else batch_size
    num_train_examples = len(train_set)
    num_train_steps = int(num_train_examples / step_size * num_epochs)
    warmup_ratio = 0.1
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    step_num = 0
    all_model_params = kogptqa.collect_params()

    log_interval = 200
    # large negative logit used to blank out non-response positions
    neg = -1e18
    # Set grad_req to 'add' if gradient accumulation is required
    if accumulate and accumulate > 1:
        for p in params:
            p.grad_req = 'add'

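    # LR schedule implemented inline below: linear warmup to `lr` over the
    # first 10% of steps, then linear decay back to zero, i.e.
    #   lr(t) = lr * t / num_warmup_steps                                   if t < num_warmup_steps
    #   lr(t) = lr * (1 - (t - num_warmup_steps) / (num_train_steps - num_warmup_steps))  otherwise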
    for epoch_id in range(num_epochs):
        step_loss = 0
        for batch_id, (token_ids, mask, label) in enumerate(train_dataloader):
            if step_num < num_warmup_steps:
                new_lr = lr * step_num / num_warmup_steps
            else:
                non_warmup_steps = step_num - num_warmup_steps
                offset = non_warmup_steps / (num_train_steps -
                                             num_warmup_steps)
                new_lr = lr - offset * lr
            trainer.set_learning_rate(new_lr)
            with mx.autograd.record():
                # load data to GPU or CPU
                token_ids = token_ids.as_in_context(ctx)
                mask = mask.as_in_context(ctx)
                label = label.as_in_context(ctx)
                # forward computation
                out = kogptqa(token_ids)
                # Non-response positions get a constant logit over the whole
                # vocabulary, so they contribute no gradient; the loss is
                # normalized by the number of response tokens.
                masked_out = nd.where(
                    mask.expand_dims(axis=2).repeat(repeats=out.shape[2],
                                                    axis=2), out,
                    neg * nd.ones_like(out))
                # loss for responses, excluding MASK and PAD positions
                ls = loss_function(masked_out, label).sum() / mask.sum()
                # backward computation
                ls.backward()
            if not accumulate or (batch_id + 1) % accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, 1)
                trainer.update(accumulate if accumulate else 1)
                step_num += 1
                if accumulate and accumulate > 1:
                    # set grad to zero for gradient accumulation
                    all_model_params.zero_grad()
            step_loss += ls.asscalar()
            if step_num % log_interval == 0 and step_num > 0:
                print(
                    '[Step {} Epoch {} Batch {}/{}] loss={:.4f}, lr={:.10f}, perplexity={:.3f}'
                    .format(step_num, epoch_id + 1, batch_id + 1,
                            len(train_dataloader),
                            step_loss / log_interval, trainer.learning_rate,
                            math.exp(step_loss / log_interval)))
                step_loss = 0
    logging.info('saving model file to {}'.format(opt.model_params))
    kogptqa.save_parameters(opt.model_params)

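
# Interactive loop with greedy decoding: the full context
#   <usr> Q </s> <unused1> sentiment </s> <sys> A-so-far
# is re-encoded each step, the argmax token at the last position is appended,
# and generation stops when the model emits </s>.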
def chat(model_params, sent='0'):
    tok_path = get_tokenizer()
    model, vocab = get_mxnet_kogpt2_model(ctx=ctx)
    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
    kogptqa = KoGPT2Chat(model)
    kogptqa.load_parameters(model_params, ctx=ctx)
    sent_tokens = tok(sent)
    while True:
        q = input('user > ').strip()
        if q == 'quit':
            break
        q_tok = tok(q)
        a = ''
        a_tok = []
        while True:
            input_ids = mx.nd.array([vocab[U_TKN]] + vocab[q_tok] +
                                    vocab[EOS, SENT] + vocab[sent_tokens] +
                                    vocab[EOS, S_TKN] +
                                    vocab[a_tok]).expand_dims(axis=0)
            pred = kogptqa(input_ids.as_in_context(ctx))
            gen = vocab.to_tokens(
                mx.nd.argmax(
                    pred,
                    axis=-1).squeeze().astype('int').asnumpy().tolist())[-1]
            if gen == EOS:
                break
            # SentencePiece uses '▁' (U+2581) as the word-boundary marker.
            a += gen.replace('▁', ' ')
            a_tok = tok(a)
        print("Simsimi > {}".format(a.strip()))

if __name__ == "__main__":
    if opt.train:
        train()
    if opt.chat:
        chat(opt.model_params, opt.sentiment)
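
# Example usage (the script's filename is not shown in the diff; train.py is
# assumed here):
#   python train.py --train                 # fine-tune; saves kogpt2_chat.params
#   python train.py --chat --sentiment 0    # chat with the saved parameters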