#52: debug trainer finished
cjopengler committed Nov 11, 2021
1 parent 3dc0242 commit 5e9897b
Showing 7 changed files with 20 additions and 16 deletions.
9 changes: 6 additions & 3 deletions data/mrc_ner/config/config.json
@@ -16,6 +16,7 @@
     "__name_space__": "mrc_ner",
     "tokenizer": {
         "__type__": "BertTokenizer",
+        "__name_space__": "__easytext__",
         "bert_dir": "/Users/panxu/MyProjects/github/easytext/data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch"
     },
@@ -25,7 +26,7 @@
     "model": {
-        "__type__": "BertRnnWithCrf",
+        "__type__": "MRCNer",
         "__name_space__": "mrc_ner",
         "bert_dir": "/Users/panxu/MyProjects/github/easytext/data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch",
         "dropout": 0.1
@@ -52,13 +53,15 @@
         "weight_decay": 0.01
     },

-    "lr_scheduler": {
+    "#lr_scheduler": {
         "__type__": "MRCLrScheduler",
         "__name_space__": "mrc_ner",
         "max_lr": 0.00002,
         "final_div_factor": 10000,
-        "total_steps": null,
+        "total_steps": null
     },

+    "lr_scheduler": null,
+    "grad_rescaled": null,

     "process_group_parameter": {
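The commented-out block (the "#lr_scheduler" key) together with the explicit "lr_scheduler": null suggests the trainer now runs without a scheduler and treats a null component as disabled. A minimal sketch of that convention, assuming a hypothetical load_component helper (not easytext's actual loader API):

    import json

    def load_component(config: dict, key: str):
        # assumption: a null entry means "component disabled";
        # a real loader would dispatch on spec["__type__"] and spec["__name_space__"]
        spec = config.get(key)
        if spec is None:
            return None
        return spec

    config = json.loads("""
    {
        "#lr_scheduler": {"__type__": "MRCLrScheduler", "max_lr": 0.00002},
        "lr_scheduler": null
    }
    """)

    assert load_component(config, "lr_scheduler") is None   # scheduler disabled
    assert "#lr_scheduler" in config                        # old settings kept as an inert entry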
8 changes: 4 additions & 4 deletions mrc/data/bert_model_collate.py
@@ -89,7 +89,7 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
             metadata["positions"] = zip(start_positions, end_positions)

             start_positions = [(query_offset + start_position) for start_position in start_positions]
-            start_position_labels = torch.zeros(batch_max_length)
+            start_position_labels = torch.zeros(batch_max_length, dtype=torch.long)

             for start_position in start_positions:
                 if start_position < batch_max_length - 1:
@@ -98,17 +98,17 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
             batch_start_position_labels.append(start_position_labels)

             end_positions = [(query_offset + end_position) for end_position in end_positions]
-            end_position_labels = torch.zeros(batch_max_length)
+            end_position_labels = torch.zeros(batch_max_length, dtype=torch.long)

             for end_position in end_positions:

                 if end_position < batch_max_length - 1:
                     end_position_labels[end_position] = 1

-            batch_end_position_labels.append(torch.tensor(end_position_labels, dtype=torch.long))
+            batch_end_position_labels.append(end_position_labels)

             # match position
-            match_positions = torch.zeros(size=(batch_max_length, batch_max_length))
+            match_positions = torch.zeros(size=(batch_max_length, batch_max_length), dtype=torch.long)

             for start_position, end_position in zip(start_positions, end_positions):
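All three fixes in this file share one theme: allocate the label tensors with dtype=torch.long up front instead of re-wrapping a finished tensor in torch.tensor(...), which copies the data and raises a UserWarning on recent PyTorch. A minimal, self-contained sketch of the new pattern (values are illustrative):

    import torch

    batch_max_length = 8
    start_positions = [2, 5]

    # torch.zeros defaults to float32; request integer labels directly
    start_position_labels = torch.zeros(batch_max_length, dtype=torch.long)
    for start_position in start_positions:
        if start_position < batch_max_length - 1:
            start_position_labels[start_position] = 1

    # the old code did: torch.tensor(start_position_labels, dtype=torch.long),
    # i.e. an extra copy plus a "To copy construct from a tensor..." warning
    assert start_position_labels.dtype == torch.long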
1 change: 1 addition & 0 deletions mrc/label_decoder/mrc_label_index_decoder.py
@@ -28,6 +28,7 @@ def __call__(self,
                  match_logits: torch.Tensor,
                  mask: torch.BoolTensor) -> torch.LongTensor:

+        mask = mask.bool()
         batch_size, seq_len = start_logits.size()

         # match label pred, [batch_size, seq_len, seq_len]
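Normalizing the mask with mask.bool() matters because upstream code may hand over a 0/1 long tensor, and indexing with a long tensor silently means something different from masking. A small sketch of the hazard the cast avoids (tensors are illustrative):

    import torch

    scores = torch.tensor([10.0, 20.0, 30.0])
    mask = torch.tensor([1, 1, 0])        # 0/1 long mask, as a collate might produce

    # a long "mask" is interpreted as gather indices: picks rows 1, 1, 0 -- wrong
    assert scores[mask].tolist() == [20.0, 20.0, 10.0]

    # a bool mask selects the unmasked entries, as intended
    assert scores[mask.bool()].tolist() == [10.0, 20.0]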
6 changes: 3 additions & 3 deletions mrc/loss/mrc_bce_loss.py
@@ -40,15 +40,15 @@ def __call__(self, model_outputs: MRCNerOutput, golden_label: Dict[str, torch.Tensor]

         batch_size, sequence_length = model_outputs.start_logits.size()

-        start_loss = self.loss(model_outputs.start_logits, golden_label["start_position_labels"])
+        start_loss = self.loss(model_outputs.start_logits, golden_label["start_position_labels"].float())
         # compute the mean
         start_loss = (start_loss * mask).sum() / mask.sum()

-        end_loss = self.loss(model_outputs.end_logits, golden_label["end_position_labels"])
+        end_loss = self.loss(model_outputs.end_logits, golden_label["end_position_labels"].float())
         end_loss = (end_loss * mask).sum() / mask.sum()

         match_loss = self.loss(model_outputs.match_logits.view(batch_size, -1),
-                               golden_label["match_position_labels"].view(batch_size, -1))
+                               golden_label["match_position_labels"].float().view(batch_size, -1))

         match_label_row_mask = mask.bool().unsqueeze(-1).expand(-1, -1, sequence_length)
         match_label_col_mask = mask.bool().unsqueeze(-2).expand(-1, sequence_length, -1)
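The .float() casts follow from the collate change above: the labels are now long tensors, but BCEWithLogitsLoss requires float targets. A minimal sketch of the masked-mean pattern used here, assuming self.loss wraps nn.BCEWithLogitsLoss(reduction="none") (the element-wise mask multiply only works with an unreduced loss):

    import torch
    from torch import nn

    loss_fn = nn.BCEWithLogitsLoss(reduction="none")

    logits = torch.randn(2, 4)
    labels = torch.randint(0, 2, (2, 4))          # long labels from the collate
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])

    per_token = loss_fn(logits, labels.float())   # long targets raise a RuntimeError
    masked_mean = (per_token * mask).sum() / mask.sum()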
2 changes: 1 addition & 1 deletion mrc/metric/mrc_f1_metric.py
@@ -43,7 +43,7 @@ def __call__(self, prediction_match_labels: torch.Tensor, gold_match_labels: torch.Tensor
         :param mask: mask
         :return: metric dict
         """
-
+        mask = mask.bool()
         batch_size, seq_length = mask.size()

         match_label_mask = (mask.unsqueeze(-1).expand(-1, -1, seq_length)
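With the mask as torch.bool, the lines that follow expand the 1D token mask into the [batch_size, seq_length, seq_length] span mask. A standalone sketch of that broadcasting step; the final triu filter (keep only spans with start <= end) is the usual MRC-NER convention and an assumption here:

    import torch

    mask = torch.tensor([[1, 1, 0]]).bool()                # [batch, seq]
    seq_length = mask.size(-1)

    row = mask.unsqueeze(-1).expand(-1, -1, seq_length)    # start token unmasked
    col = mask.unsqueeze(-2).expand(-1, seq_length, -1)    # end token unmasked
    match_label_mask = row & col                           # [batch, seq, seq]
    match_label_mask = torch.triu(match_label_mask, diagonal=0)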
6 changes: 3 additions & 3 deletions mrc/metric/mrc_metric.py
@@ -36,18 +36,18 @@ def __init__(self):
         self.model_label_decoder = MRCModelLabelDecoder()
         self.mrc_f1_metric = MRCF1Metric(labels=list())

-    def __call__(self, model_outputs: MRCNerOutput, golden_label_dict: Dict[str, Tensor]) -> Tuple[Dict, ModelTargetMetric]:
+    def __call__(self, model_outputs: MRCNerOutput, golden_labels: Dict[str, Tensor]) -> Tuple[Dict, ModelTargetMetric]:
         """
         compute the metric
         :param model_outputs:
-        :param golden_label_dict: start_position_labels, end_position_labels, batch_match_positions
+        :param golden_labels: start_position_labels, end_position_labels, batch_match_positions
         :return:
         """
         model_outputs: MRCNerOutput = model_outputs

         match_prediction_labels = self.model_label_decoder.decode_label_index(model_outputs=model_outputs)

-        match_golden_labels = golden_label_dict["match_position_labels"]
+        match_golden_labels = golden_labels["match_position_labels"]

         # compute overall F1
         mask = model_outputs.mask.detach()
4 changes: 2 additions & 2 deletions mrc/optimizer/mrc_optimizer.py
@@ -29,11 +29,11 @@ def create(self, model: Model) -> "Optimizer":
         no_decay = ["bias", "LayerNorm.weight"]
         parameters = [
             {
-                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                 "weight_decay": self.weight_decay,
             },
             {
-                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                 "weight_decay": 0.0,
             }
         ]
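The fix replaces self.model with the model argument that create() actually receives. The grouping itself is the standard BERT recipe: no weight decay on biases or LayerNorm weights. A runnable sketch with a stand-in module (TinyModel is illustrative, not part of the repo):

    import torch
    from torch import nn
    from torch.optim import AdamW

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)
            self.LayerNorm = nn.LayerNorm(4)   # named so the filter below matches

    model = TinyModel()
    weight_decay = 0.01

    no_decay = ["bias", "LayerNorm.weight"]
    parameters = [
        {   # decayed: everything except biases and LayerNorm weights
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {   # not decayed
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(parameters, lr=2e-5)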
