@@ -129,7 +129,6 @@ def collate_fn(batch):
129
129
training_stats = json .load (file )
130
130
best_result = training_stats [- 1 ]['Best result' ]
131
131
num_step = training_stats [- 1 ]['Step' ]
132
- best_checkpoint = model
133
132
134
133
for epoch_i in range (0 , args .num_epochs ):
135
134
print ("" )
@@ -228,7 +227,7 @@ def collate_fn(batch):
228
227
teacher_logits = teacher_logits
229
228
)
230
229
231
- generated_lm_texts = model .generate_expl (
230
+ gen_preds = model .generate_expl (
232
231
input_ids = b_input_ids ,
233
232
attention_mask = b_attention_mask ,
234
233
num_beams = 4 ,
@@ -240,61 +239,56 @@ def collate_fn(batch):
240
239
cls_preds = cls_outputs .argmax (dim = - 1 ).cpu ().numpy ()
241
240
cls_labels = b_teacher_cls .cpu ().numpy ()
242
241
243
- generated_lm_texts = decoder_tokenizer .batch_decode (generated_lm_texts , skip_special_tokens = True )
244
- b_lm_labels = b_lm_labels .detach ().clone ()
245
- b_lm_labels [ b_lm_labels == - 100 ] = decoder_tokenizer .pad_token_id
246
- b_lm_labels = decoder_tokenizer .batch_decode (b_lm_labels , skip_special_tokens = True )
242
+ gen_preds = decoder_tokenizer .batch_decode (gen_preds , skip_special_tokens = True )
243
+ gen_labels = b_lm_labels .detach ().clone ()
244
+ gen_labels [ gen_labels == - 100 ] = decoder_tokenizer .pad_token_id
245
+ gen_labels = decoder_tokenizer .batch_decode (gen_labels , skip_special_tokens = True )
247
246
248
247
# Conditional Decoding Constraint
249
- generated_lm_texts = ["none" if cls_pred == "normal" else expl for cls_pred , expl in zip (cls_preds , generated_lm_texts )]
248
+ gen_preds = ["none" if cls_pred == 0 else expl for cls_pred , expl in zip (cls_preds , gen_preds )]
250
249
251
250
eval_cls_preds .extend (cls_preds )
252
251
eval_cls_labels .extend (cls_labels )
253
- eval_gen_preds .extend (generated_lm_texts )
254
- eval_gen_labels .extend (b_lm_labels )
252
+ eval_gen_preds .extend ([ text . split ( 'SEP>' )[ - 1 ]. strip () for text in gen_preds ] )
253
+ eval_gen_labels .extend ([ text . split ( 'SEP>' )[ - 1 ]. strip () for text in gen_labels ] )
255
254
256
255
total_eval_lm_loss += lm_loss .item ()
257
256
total_eval_cls_loss += cls_loss .item ()
258
257
total_eval_kl_loss += kl_loss .item ()
259
258
260
- print ("generated_lm_texts: " , generated_lm_texts )
261
- print ("groundtruth_lm_texts: " , b_lm_labels )
262
- generation_scores = compute_generation_scores (generated_lm_texts , b_lm_labels )
263
- bleu = generation_scores [0 ]
264
- rouge = generation_scores [1 ]
265
- meteor = generation_scores [2 ]
266
- bertscore = generation_scores [3 ]
267
- acc , f1 = compute_classification_scores (cls_labels , cls_preds )
259
+ print ("gen_preds: " , gen_preds )
260
+ print ("gen_labels: " , gen_labels )
261
+ acc , f1 = compute_classification_scores (eval_cls_labels , eval_cls_preds )
262
+ bleu , rouge , meteor , bertscore = compute_generation_scores (eval_gen_labels , eval_gen_preds )
268
263
269
264
print ()
270
265
print (" Average valid LM: {0:.4f}" .format (total_eval_lm_loss / len (validation_dataloader )))
271
266
print (" Average valid CLS: {0:.4f}" .format (total_eval_cls_loss / len (validation_dataloader )))
272
267
print (" Average valid KL: {0:.4f}" .format (total_eval_kl_loss / len (validation_dataloader )))
273
268
274
- print (f"Epoch { epoch_i + 1 } step { num_step } generation evaluations: BLEU-4: { bleu } , ROUGE-L: { rouge } , METEOR: { meteor } , BERTSCORE: { bertscore } " )
275
269
print (f"Epoch { epoch_i + 1 } step { num_step } classification evaluations: Acc: { acc } , F1: { f1 } " )
270
+ print (f"Epoch { epoch_i + 1 } step { num_step } generation evaluations: BLEU-4: { bleu } , ROUGE-L: { rouge } , METEOR: { meteor } , BERTSCORE: { bertscore } " )
276
271
277
272
# Record all statistics at this epoch
278
273
training_stats .append ({
279
274
'Step' : num_step ,
280
275
'Best result' : best_result ,
281
276
'Avg train LM loss' : total_train_lm_loss / (num_step + 1 ),
282
277
'Avg train CLS loss' : total_train_cls_loss / (num_step + 1 ),
283
- 'Avg valid KL loss' : total_train_kl_loss / (num_step + 1 ),
278
+ 'Avg train KL loss' : total_train_kl_loss / (num_step + 1 ),
284
279
'Avg valid LM loss' : total_eval_lm_loss / len (validation_dataloader ),
285
280
'Avg valid CLS loss' : total_eval_cls_loss / len (validation_dataloader ),
286
281
'Avg valid KL loss' : total_eval_kl_loss / len (validation_dataloader ),
287
- 'Generation evaluation' : f"BLEU-4: { bleu } , ROUGE-L: { rouge } , METEOR: { meteor } , BERTSCORE: { bertscore } " ,
288
282
'CLS evaluation' : f"Acc: { acc } , F1: { f1 } " ,
283
+ 'Generation evaluation' : f"BLEU-4: { bleu } , ROUGE-L: { rouge } , METEOR: { meteor } , BERTSCORE: { bertscore } " ,
289
284
})
290
285
291
286
if total_eval_lm_loss / len (validation_dataloader ) < best_result :
292
287
print (f"New best checkpoint at Epoch { epoch_i + 1 } , Train_step { num_step } " )
293
288
best_result = total_eval_lm_loss / len (validation_dataloader )
294
289
training_stats [- 1 ]["Best result" ] = best_result
295
290
296
- best_checkpoint = model
297
- best_checkpoint .save_checkpoint (os .path .join (args .output_dir , "best_ckpt" ), is_best = True )
291
+ model .save_checkpoint (os .path .join (args .output_dir , "best_ckpt" ), is_best = True )
298
292
model .save_checkpoint (args .output_dir , is_best = False ,
299
293
optimizer = optimizer , scheduler = scheduler , training_stats = training_stats )
300
294
print ("Successfully saved checkpoint." )
@@ -321,7 +315,7 @@ def collate_fn(batch):
321
315
parser .add_argument ('--num_epochs' , type = int , default = 10 , help = "Number of epochs" )
322
316
parser .add_argument ("--train_batch_size" , type = int , default = 16 , help = "Training batch size" )
323
317
parser .add_argument ("--valid_batch_size" , type = int , default = 32 , help = "Validation batch size" )
324
- parser .add_argument ("--learning_rate" , type = float , default = 1e -5 , help = "Learning rate" )
318
+ parser .add_argument ("--learning_rate" , type = float , default = 2e -5 , help = "Learning rate" )
325
319
parser .add_argument ('--warmup_steps' , type = int , default = 100 )
326
320
parser .add_argument ('--max_length' , type = int , default = 256 )
327
321
parser .add_argument ("--logging_steps" , type = int , default = 500 , help = "Number of steps between logging" )
0 commit comments