Skip to content

Commit 0b378b5

Browse files
committed
fix typo
1 parent e06c959 commit 0b378b5

File tree

2 files changed

+20
-26
lines changed

2 files changed

+20
-26
lines changed

toxcl.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ def __init__(self, decoder_model, decoder_tokenizer=None, tg_model=None, tg_toke
2222
self.tg_tokenizer = tg_tokenizer
2323

2424
self.num_labels = num_labels
25-
self.classifier = nn.Linear(hidden_size, num_labels)
25+
self.classifier = nn.Linear(hidden_size, num_labels).to(self.device)
2626
self.loss_fct = nn.BCELoss()
2727
self.kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
2828
self.activation = nn.Softmax(dim=-1)
@@ -32,7 +32,7 @@ def classify(self, input_ids, attention_mask=None):
3232
last_hidden_state = outputs.encoder_last_hidden_state
3333

3434
cls_token_emb = torch.mean(last_hidden_state, dim=1).squeeze()
35-
logits = self.classifier(cls_token_emb).squeeze().to(self.device)
35+
logits = self.classifier(cls_token_emb).squeeze()
3636
logits = logits.view(-1, self.num_labels)
3737
return self.activation(logits)
3838

@@ -74,7 +74,7 @@ def generate_e2e(self, prompts, apply_constraints=True, tg_generation_params=Non
7474

7575
# (3) Conditional Decoding Constraint
7676
if apply_constraints:
77-
explainations = ["none" if cls_pred == "normal" else expl for cls_pred, expl in zip(prediction_labels, decoded_explanations)]
77+
decoded_explanations = ["none" if cls_pred == "normal" else expl for cls_pred, expl in zip(prediction_labels, decoded_explanations)]
7878

7979
return dict(target_groups=decoded_tg_outputs, detections=prediction_labels, explanations=decoded_explanations)
8080

train.py

Lines changed: 17 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,6 @@ def collate_fn(batch):
129129
training_stats = json.load(file)
130130
best_result = training_stats[-1]['Best result']
131131
num_step = training_stats[-1]['Step']
132-
best_checkpoint = model
133132

134133
for epoch_i in range(0, args.num_epochs):
135134
print("")
@@ -228,7 +227,7 @@ def collate_fn(batch):
228227
teacher_logits=teacher_logits
229228
)
230229

231-
generated_lm_texts = model.generate_expl(
230+
gen_preds = model.generate_expl(
232231
input_ids=b_input_ids,
233232
attention_mask=b_attention_mask,
234233
num_beams=4,
@@ -240,61 +239,56 @@ def collate_fn(batch):
240239
cls_preds = cls_outputs.argmax(dim=-1).cpu().numpy()
241240
cls_labels = b_teacher_cls.cpu().numpy()
242241

243-
generated_lm_texts = decoder_tokenizer.batch_decode(generated_lm_texts, skip_special_tokens=True)
244-
b_lm_labels = b_lm_labels.detach().clone()
245-
b_lm_labels[b_lm_labels == -100] = decoder_tokenizer.pad_token_id
246-
b_lm_labels = decoder_tokenizer.batch_decode(b_lm_labels, skip_special_tokens=True)
242+
gen_preds = decoder_tokenizer.batch_decode(gen_preds, skip_special_tokens=True)
243+
gen_labels = b_lm_labels.detach().clone()
244+
gen_labels[gen_labels == -100] = decoder_tokenizer.pad_token_id
245+
gen_labels = decoder_tokenizer.batch_decode(gen_labels, skip_special_tokens=True)
247246

248247
# Conditional Decoding Constraint
249-
generated_lm_texts = ["none" if cls_pred == "normal" else expl for cls_pred, expl in zip(cls_preds, generated_lm_texts)]
248+
gen_preds = ["none" if cls_pred == 0 else expl for cls_pred, expl in zip(cls_preds, gen_preds)]
250249

251250
eval_cls_preds.extend(cls_preds)
252251
eval_cls_labels.extend(cls_labels)
253-
eval_gen_preds.extend(generated_lm_texts)
254-
eval_gen_labels.extend(b_lm_labels)
252+
eval_gen_preds.extend([text.split('SEP>')[-1].strip() for text in gen_preds])
253+
eval_gen_labels.extend([text.split('SEP>')[-1].strip() for text in gen_labels])
255254

256255
total_eval_lm_loss += lm_loss.item()
257256
total_eval_cls_loss += cls_loss.item()
258257
total_eval_kl_loss += kl_loss.item()
259258

260-
print("generated_lm_texts: ", generated_lm_texts)
261-
print("groundtruth_lm_texts: ", b_lm_labels)
262-
generation_scores = compute_generation_scores(generated_lm_texts, b_lm_labels)
263-
bleu = generation_scores[0]
264-
rouge = generation_scores[1]
265-
meteor = generation_scores[2]
266-
bertscore = generation_scores[3]
267-
acc, f1 = compute_classification_scores(cls_labels, cls_preds)
259+
print("gen_preds: ", gen_preds)
260+
print("gen_labels: ", gen_labels)
261+
acc, f1 = compute_classification_scores(eval_cls_labels, eval_cls_preds)
262+
bleu, rouge, meteor, bertscore = compute_generation_scores(eval_gen_labels, eval_gen_preds)
268263

269264
print()
270265
print(" Average valid LM: {0:.4f}".format(total_eval_lm_loss/len(validation_dataloader)))
271266
print(" Average valid CLS: {0:.4f}".format(total_eval_cls_loss/len(validation_dataloader)))
272267
print(" Average valid KL: {0:.4f}".format(total_eval_kl_loss/len(validation_dataloader)))
273268

274-
print(f"Epoch {epoch_i + 1} step {num_step} generation evaluations: BLEU-4: {bleu}, ROUGE-L: {rouge}, METEOR: {meteor}, BERTSCORE: {bertscore}")
275269
print(f"Epoch {epoch_i + 1} step {num_step} classification evaluations: Acc: {acc}, F1: {f1}")
270+
print(f"Epoch {epoch_i + 1} step {num_step} generation evaluations: BLEU-4: {bleu}, ROUGE-L: {rouge}, METEOR: {meteor}, BERTSCORE: {bertscore}")
276271

277272
# Record all statistics at this epoch
278273
training_stats.append({
279274
'Step': num_step,
280275
'Best result': best_result,
281276
'Avg train LM loss': total_train_lm_loss / (num_step+1),
282277
'Avg train CLS loss': total_train_cls_loss / (num_step+1),
283-
'Avg valid KL loss': total_train_kl_loss / (num_step+1),
278+
'Avg train KL loss': total_train_kl_loss / (num_step+1),
284279
'Avg valid LM loss': total_eval_lm_loss/len(validation_dataloader),
285280
'Avg valid CLS loss': total_eval_cls_loss/len(validation_dataloader),
286281
'Avg valid KL loss': total_eval_kl_loss/len(validation_dataloader),
287-
'Generation evaluation': f"BLEU-4: {bleu}, ROUGE-L: {rouge}, METEOR: {meteor}, BERTSCORE: {bertscore}",
288282
'CLS evaluation': f"Acc: {acc}, F1: {f1}",
283+
'Generation evaluation': f"BLEU-4: {bleu}, ROUGE-L: {rouge}, METEOR: {meteor}, BERTSCORE: {bertscore}",
289284
})
290285

291286
if total_eval_lm_loss/len(validation_dataloader) < best_result:
292287
print(f"New best checkpoint at Epoch {epoch_i + 1}, Train_step {num_step}")
293288
best_result = total_eval_lm_loss/len(validation_dataloader)
294289
training_stats[-1]["Best result"] = best_result
295290

296-
best_checkpoint = model
297-
best_checkpoint.save_checkpoint(os.path.join(args.output_dir, "best_ckpt"), is_best=True)
291+
model.save_checkpoint(os.path.join(args.output_dir, "best_ckpt"), is_best=True)
298292
model.save_checkpoint(args.output_dir, is_best=False,
299293
optimizer=optimizer, scheduler=scheduler, training_stats=training_stats)
300294
print("Successfully saved checkpoint.")
@@ -321,7 +315,7 @@ def collate_fn(batch):
321315
parser.add_argument('--num_epochs', type=int, default=10, help="Number of epochs")
322316
parser.add_argument("--train_batch_size", type=int, default=16, help="Training batch size")
323317
parser.add_argument("--valid_batch_size", type=int, default=32, help="Validation batch size")
324-
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate")
318+
parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
325319
parser.add_argument('--warmup_steps', type=int, default=100)
326320
parser.add_argument('--max_length', type=int, default=256)
327321
parser.add_argument("--logging_steps", type=int, default=500, help="Number of steps between logging")

0 commit comments

Comments (0)