Fixes in CRF #9

Open · wants to merge 4 commits into main

Showing changes from all commits.
2 changes: 1 addition & 1 deletion src/transformerkp/tagger/args.py
@@ -10,7 +10,7 @@ class KETrainingArguments(TrainingArguments):
     A custom training argument class for keyphrase extraction, extending HF's TrainingArguments class.
     """
 
-    score_aggregation_method: bool = field(
+    score_aggregation_method: Optional[str] = field(
         default="avg",
         metadata={
             "help": "which method among avg, max and first to use while calculating confidence score of a keyphrase. "
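The fix above corrects a mistyped dataclass field: the argument names an aggregation strategy ("avg", "max", or "first"), so annotating it as `bool` was wrong. A minimal standalone sketch of the corrected pattern; the `__post_init__` guard is an illustrative addition, not part of the PR:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class KETrainingArguments:
    """Keyphrase-extraction training arguments (standalone sketch)."""

    score_aggregation_method: Optional[str] = field(
        default="avg",
        metadata={
            "help": "which method among avg, max and first to use while "
            "calculating confidence score of a keyphrase."
        },
    )

    def __post_init__(self):
        # Illustrative guard: fail fast on an unsupported aggregation method.
        if self.score_aggregation_method not in (None, "avg", "max", "first"):
            raise ValueError(
                f"unknown score_aggregation_method: {self.score_aggregation_method!r}"
            )
```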
9 changes: 6 additions & 3 deletions src/transformerkp/tagger/models/crf.py
@@ -181,16 +181,19 @@ def forward(
         if mask is None:
             mask = torch.ones(*tags.size(), dtype=torch.bool)
         else:
-            # The code below fails in weird ways if this isn't a bool tensor, so we make sure.
+            # tokens tagged with -100 will have zero attention
+            mask[tags == -100] = 0
+            # The code below fails in weird ways if this isn't a bool tensor, so we make sure.
             mask = mask.to(torch.bool)
         is_masked = tags == -100
-        tags[is_masked] = self.label2id[self.id2label[0]]
+        # to make sure label of tagged token is always "O".
+        tags[is_masked] = self.label2id["O"]
 
         log_denominator = self._input_likelihood(inputs, mask)
 
         log_numerator = self._joint_likelihood(inputs, tags, mask)
-        # tags[is_masked] = -100
+        # revert the tags for fair evaluation of entity
+        tags[is_masked] = -100
         return torch.sum(log_numerator - log_denominator)
 
     def viterbi_tags(
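This `forward` change teaches the CRF about HuggingFace's `-100` ignore-index: such tokens (sub-word pieces, padding, special tokens) are zeroed out of the mask and temporarily relabelled as "O" so the CRF never indexes a label with -100, and the original -100 labels are restored after the likelihood is computed so downstream evaluation sees the real tags. It also replaces the fragile `self.label2id[self.id2label[0]]` lookup, which silently assumed that id 0 is "O", with a direct `self.label2id["O"]`. A minimal sketch of that pre/post-processing, with a stub where the real likelihood terms would be computed:

```python
import torch

IGNORE_INDEX = -100  # HF convention: tokens excluded from the loss
O_LABEL_ID = 0       # assumed id of the "O" tag in this sketch

def prepare_and_restore(tags: torch.Tensor, mask: torch.Tensor):
    """Sketch of the PR's pre/post-processing around the CRF likelihood."""
    is_masked = tags == IGNORE_INDEX
    mask = mask.clone().to(torch.bool)
    mask[is_masked] = False       # ignored tokens get zero attention
    tags[is_masked] = O_LABEL_ID  # any valid id works; "O" is the neutral tag

    # ... log_numerator and log_denominator would be computed here ...

    tags[is_masked] = IGNORE_INDEX  # revert for fair downstream evaluation
    return tags, mask

tags = torch.tensor([[1, 2, -100, 0]])
mask = torch.ones_like(tags)
tags, mask = prepare_and_restore(tags, mask)
print(tags)  # tensor([[   1,    2, -100,    0]])
print(mask)  # tensor([[ True,  True, False,  True]])
```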
36 changes: 18 additions & 18 deletions src/transformerkp/tagger/models/models.py
@@ -66,25 +66,25 @@ def forward(

         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
-        loss = None
-        if labels is not None:
-            loss = -1.0 * self.crf(logits, labels.clone(), attention_mask)
-        best_path = self.crf.viterbi_tags(logits, mask=attention_mask)
+        best_path = self.crf.viterbi_tags(logits=logits, mask=attention_mask)
         # ignore score of path, just store the tags value
         best_path = [x for x, _ in best_path]
-        logits *= 0.0
+        class_prob = logits * 0.0
         for i, path in enumerate(best_path):
             for j, tag in enumerate(path):
-                # j+ 1 to ignore clf token at begning
-                logits[i, j + 1 - attention_mask[0], int(tag)] = 1.0
+                class_prob[i, j, int(tag)] = 1.0
+
+        loss = None
+        if labels is not None:
+            loss = -1.0 * self.crf(logits, labels, attention_mask)
 
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (class_prob,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
             loss=loss,
-            logits=logits,
+            logits=class_prob,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
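Two things are fixed in this block. The old code decoded into `logits` itself: it zeroed the tensor in place and wrote ones at `j + 1 - attention_mask[0]`, an index that mixes an integer with an entire mask row while trying to skip the leading [CLS] token. The rewrite leaves `logits` untouched, fills a separate `class_prob` tensor at the plain `(i, j, tag)` positions, and computes the loss afterwards from the intact logits; `labels.clone()` is also dropped because `crf.forward` now restores the -100 labels itself (see crf.py above). A small sketch of the decode-to-one-hot step:

```python
import torch

def paths_to_class_prob(logits: torch.Tensor, best_path: list) -> torch.Tensor:
    """Turn Viterbi tag sequences into a one-hot output tensor.

    Mirrors the pattern in the diff: a fresh tensor of the same shape as
    `logits` is filled with 1.0 at each token's decoded tag, so `logits`
    survives for the CRF loss.
    """
    class_prob = logits * 0.0  # same shape and device, all zeros
    for i, path in enumerate(best_path):
        for j, tag in enumerate(path):
            class_prob[i, j, int(tag)] = 1.0
    return class_prob

logits = torch.randn(1, 4, 3)      # (batch, seq_len, num_labels)
best_path = [[0, 2, 1, 0]]         # one decoded tag sequence per batch item
class_prob = paths_to_class_prob(logits, best_path)
print(class_prob[0].argmax(-1))    # tensor([0, 2, 1, 0])
```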
@@ -147,25 +147,25 @@ def forward(
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
 
-        loss = None
-        if labels is not None:
-            loss = -1.0 * self.crf(logits, labels.clone(), attention_mask)
-        best_path = self.crf.viterbi_tags(logits, mask=attention_mask)
+        best_path = self.crf.viterbi_tags(logits=logits, mask=attention_mask)
         # ignore score of path, just store the tags value
         best_path = [x for x, _ in best_path]
-        logits *= 0.0
+        class_prob = logits * 0.0
         for i, path in enumerate(best_path):
             for j, tag in enumerate(path):
-                # j+ 1 to ignore clf token at begning
-                logits[i, j + 1 - attention_mask[0], int(tag)] = 1.0
+                class_prob[i, j, int(tag)] = 1.0
+
+        loss = None
+        if labels is not None:
+            loss = -1.0 * self.crf(logits, labels, attention_mask)
 
         if not return_dict:
-            output = (logits,) + outputs[2:]
+            output = (class_prob,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return TokenClassifierOutput(
             loss=loss,
-            logits=logits,
+            logits=class_prob,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
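The second model class receives the identical rewrite. Condensed into one place, the fixed forward pass looks roughly like this (an illustrative stand-in, not the library's actual class; `crf` is assumed to expose the `forward` and `viterbi_tags` signatures used in the diff):

```python
import torch
from torch import nn

class CRFTokenHeadSketch(nn.Module):
    """Illustrative shape of the fixed forward pass."""

    def __init__(self, hidden_size: int, num_labels: int, crf: nn.Module):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.crf = crf

    def forward(self, sequence_output, attention_mask, labels=None):
        logits = self.classifier(self.dropout(sequence_output))

        # Decode first, into a separate tensor, so `logits` stays intact.
        best_path = [p for p, _ in self.crf.viterbi_tags(logits=logits, mask=attention_mask)]
        class_prob = logits * 0.0
        for i, path in enumerate(best_path):
            for j, tag in enumerate(path):
                class_prob[i, j, int(tag)] = 1.0

        # Loss afterwards, from the untouched logits. The CRF returns a
        # log-likelihood (larger is better), so it is negated to get a loss;
        # no labels.clone() needed since crf.forward restores -100 itself.
        loss = None
        if labels is not None:
            loss = -1.0 * self.crf(logits, labels, attention_mask)
        return loss, class_prob
```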
20 changes: 9 additions & 11 deletions src/transformerkp/tagger/tagger.py
@@ -40,12 +40,20 @@ def __init__(
         self.use_crf = (
             self.config.use_crf if hasattr(self.config, "use_crf") else use_crf
         )
+        self.config.use_crf = self.use_crf
+        self.config.label2id = LABELS_TO_ID
+        self.config.id2label = ID_TO_LABELS
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name if tokenizer_name else model_name_or_path,
             use_fast=True,
             add_prefix_space=True,
         )
+        # set pad token if none
+        pad_token_none = self.tokenizer.pad_token == None
+        if pad_token_none:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
         self.trainer = trainer
         self.data_collator = data_collator
         self.model_type = (
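Moving this setup from `train()` into `__init__` means the tokenizer and config carry the CRF flag and label mappings no matter which entry point (train, evaluate, predict) runs first. The pad-token fallback itself is the usual trick for GPT-style checkpoints that ship without one; a standalone sketch, where "gpt2" is just an example checkpoint:

```python
from transformers import AutoTokenizer

# Load a tokenizer that has no pad token out of the box.
tokenizer = AutoTokenizer.from_pretrained(
    "gpt2", use_fast=True, add_prefix_space=True
)

# Reuse the end-of-sequence token for padding so batched encoding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(tokenizer.pad_token)  # '<|endoftext|>' for GPT-2
```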
@@ -132,17 +140,7 @@ def train(
         # Set seed before initializing model.
         set_seed(training_args.seed)
 
-        # set pad token if none
-        pad_token_none = self.tokenizer.pad_token == None
-        if pad_token_none:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        self.config.use_crf = self.use_crf
-        self.config.label2id = LABELS_TO_ID
-        self.config.id2label = ID_TO_LABELS
-        if pad_token_none:
-            self.config.pad_token_id = self.config.eos_token_id
-        # initialize data collator
+        # set data collator
         data_collator = (
             self.data_collator
             if self.data_collator
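With the setup relocated, `train()` only has to pick a data collator: the caller-supplied one if present, otherwise a default whose construction is collapsed out of this diff view. A sketch of that fallback pattern, assuming `DataCollatorForTokenClassification` as a plausible default for a tagging task (it pads labels with -100, matching the ignore-index handling above):

```python
from transformers import AutoTokenizer, DataCollatorForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

user_collator = None  # stands in for self.data_collator
data_collator = (
    user_collator
    if user_collator
    else DataCollatorForTokenClassification(tokenizer=tokenizer)
)
```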