diff --git a/src/transformerkp/tagger/args.py b/src/transformerkp/tagger/args.py
index c4df67d..0a96844 100644
--- a/src/transformerkp/tagger/args.py
+++ b/src/transformerkp/tagger/args.py
@@ -10,7 +10,7 @@ class KETrainingArguments(TrainingArguments):
     A custom training argument class for keyphrase extraction, extending HF's TrainingArguments class.
     """
 
-    score_aggregation_method: bool = field(
+    score_aggregation_method: Optional[str] = field(
         default="avg",
         metadata={
             "help": "which method among avg, max and first to use while calculating confidence score of a keyphrase. "
diff --git a/src/transformerkp/tagger/models/crf.py b/src/transformerkp/tagger/models/crf.py
index c6f8bbd..11da7ad 100644
--- a/src/transformerkp/tagger/models/crf.py
+++ b/src/transformerkp/tagger/models/crf.py
@@ -181,16 +181,19 @@ def forward(
         if mask is None:
             mask = torch.ones(*tags.size(), dtype=torch.bool)
         else:
-            # The code below fails in weird ways if this isn't a bool tensor, so we make sure.
+            # tokens tagged with -100 will have zero attention
             mask[tags == -100] = 0
+            # The code below fails in weird ways if this isn't a bool tensor, so we make sure.
             mask = mask.to(torch.bool)
 
         is_masked = tags == -100
-        tags[is_masked] = self.label2id[self.id2label[0]]
+        # to make sure label of tagged token is always "O".
+        tags[is_masked] = self.label2id["O"]
         log_denominator = self._input_likelihood(inputs, mask)
         log_numerator = self._joint_likelihood(inputs, tags, mask)
-        # tags[is_masked] = -100
+        # revert the tags for fair evaluation of entity
+        tags[is_masked] = -100
 
         return torch.sum(log_numerator - log_denominator)
 
     def viterbi_tags(
diff --git a/src/transformerkp/tagger/models/models.py b/src/transformerkp/tagger/models/models.py
index cc61f2a..f23c518 100644
--- a/src/transformerkp/tagger/models/models.py
+++ b/src/transformerkp/tagger/models/models.py
@@ -66,25 +66,25 @@ def forward(
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
 
-        loss = None
-        if labels is not None:
-            loss = -1.0 * self.crf(logits, labels.clone(), attention_mask)
-        best_path = self.crf.viterbi_tags(logits, mask=attention_mask)
+        best_path = self.crf.viterbi_tags(logits=logits, mask=attention_mask)
         # ignore score of path, just store the tags value
         best_path = [x for x, _ in best_path]
-        logits *= 0.0
+        class_prob = logits * 0.0
         for i, path in enumerate(best_path):
             for j, tag in enumerate(path):
-                # j+ 1 to ignore clf token at begning
-                logits[i, j + 1 - attention_mask[0], int(tag)] = 1.0
+                class_prob[i, j, int(tag)] = 1.0
+
+        loss = None
+        if labels is not None:
+            loss = -1.0 * self.crf(logits, labels, attention_mask)
 
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (class_prob,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
             loss=loss,
-            logits=logits,
+            logits=class_prob,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
@@ -147,25 +147,25 @@ def forward(
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
 
-        loss = None
-        if labels is not None:
-            loss = -1.0 * self.crf(logits, labels.clone(), attention_mask)
-        best_path = self.crf.viterbi_tags(logits, mask=attention_mask)
+        best_path = self.crf.viterbi_tags(logits=logits, mask=attention_mask)
         # ignore score of path, just store the tags value
         best_path = [x for x, _ in best_path]
-        logits *= 0.0
+        class_prob = logits * 0.0
         for i, path in enumerate(best_path):
             for j, tag in enumerate(path):
-                # j+ 1 to ignore clf token at begning
-                logits[i, j + 1 - attention_mask[0], int(tag)] = 1.0
+                class_prob[i, j, int(tag)] = 1.0
+
+        loss = None
+        if labels is not None:
+            loss = -1.0 * self.crf(logits, labels, attention_mask)
 
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (class_prob,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
             loss=loss,
-            logits=logits,
+            logits=class_prob,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
diff --git a/src/transformerkp/tagger/tagger.py b/src/transformerkp/tagger/tagger.py
index 1612f77..b0061d7 100644
--- a/src/transformerkp/tagger/tagger.py
+++ b/src/transformerkp/tagger/tagger.py
@@ -40,12 +40,20 @@ def __init__(
         self.use_crf = (
             self.config.use_crf if hasattr(self.config, "use_crf") else use_crf
         )
+        self.config.use_crf = self.use_crf
+        self.config.label2id = LABELS_TO_ID
+        self.config.id2label = ID_TO_LABELS
         self.tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name if tokenizer_name else model_name_or_path,
             use_fast=True,
             add_prefix_space=True,
         )
 
+        # set pad token if none
+        pad_token_none = self.tokenizer.pad_token == None
+        if pad_token_none:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
         self.trainer = trainer
         self.data_collator = data_collator
         self.model_type = (
@@ -132,17 +140,7 @@ def train(
         # Set seed before initializing model.
         set_seed(training_args.seed)
 
-        # set pad token if none
-        pad_token_none = self.tokenizer.pad_token == None
-        if pad_token_none:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        self.config.use_crf = self.use_crf
-        self.config.label2id = LABELS_TO_ID
-        self.config.id2label = ID_TO_LABELS
-        if pad_token_none:
-            self.config.pad_token_id = self.config.eos_token_id
-        # initialize data collator
+        # set data collator
         data_collator = (
             self.data_collator
             if self.data_collator
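
For reference (not part of the patch): a minimal standalone sketch of the decode-then-one-hot logic the revised forward methods use. The `best_path` values below are fabricated for illustration; in the actual code they come from `self.crf.viterbi_tags(logits=logits, mask=attention_mask)`, and the untouched `logits` are reused for the CRF loss.

# Illustrative sketch only -- `best_path` is made up here; in the patch it is
# produced by the CRF's Viterbi decoder.
import torch

batch_size, seq_len, num_labels = 2, 4, 3
logits = torch.randn(batch_size, seq_len, num_labels)
best_path = [[0, 1, 2, 0], [1, 2, 0, 0]]  # decoded label ids per token

# Build the one-hot "class probabilities" in a fresh tensor instead of zeroing
# `logits` in place, so the raw scores stay available for the CRF loss.
class_prob = logits * 0.0
for i, path in enumerate(best_path):
    for j, tag in enumerate(path):
        # path positions already line up with tokenizer positions,
        # so no "+1 for the [CLS] token" offset is applied
        class_prob[i, j, int(tag)] = 1.0

# mirrors `loss = -1.0 * self.crf(logits, labels, attention_mask)` in the diff,
# which now runs on the unmodified logits
print(class_prob[0])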