Skip to content

Commit 5e9897b

Browse files
committed
#52: debug trainer finished
1 parent 3dc0242 commit 5e9897b

File tree

7 files changed

+20
-16
lines changed

7 files changed

+20
-16
lines changed

data/mrc_ner/config/config.json

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"__name_space__": "mrc_ner",
1717
"tokenizer": {
1818
"__type__": "BertTokenizer",
19+
"__name_space__": "__easytext__",
1920
"bert_dir": "/Users/panxu/MyProjects/github/easytext/data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch"
2021
},
2122

@@ -25,7 +26,7 @@
2526

2627

2728
"model": {
28-
"__type__": "BertRnnWithCrf",
29+
"__type__": "MRCNer",
2930
"__name_space__": "mrc_ner",
3031
"bert_dir": "/Users/panxu/MyProjects/github/easytext/data/pretrained/bert/chinese_roberta_wwm_large_ext_pytorch",
3132
"dropout": 0.1
@@ -52,13 +53,15 @@
5253
"weight_decay": 0.01
5354
},
5455

55-
"lr_scheduler": {
56+
"#lr_scheduler": {
5657
"__type__": "MRCLrScheduler",
5758
"__name_space__": "mrc_ner",
5859
"max_lr": 0.00002,
5960
"final_div_factor": 10000,
60-
"total_steps": null,
61+
"total_steps": null
6162
},
63+
64+
"lr_scheduler": null,
6265
"grad_rescaled": null,
6366

6467
"process_group_parameter": {

mrc/data/bert_model_collate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
8989
metadata["positions"] = zip(start_positions, end_positions)
9090

9191
start_positions = [(query_offset + start_position) for start_position in start_positions]
92-
start_position_labels = torch.zeros(batch_max_length)
92+
start_position_labels = torch.zeros(batch_max_length, dtype=torch.long)
9393

9494
for start_position in start_positions:
9595
if start_position < batch_max_length - 1:
@@ -98,17 +98,17 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
9898
batch_start_position_labels.append(start_position_labels)
9999

100100
end_positions = [(query_offset + end_position) for end_position in end_positions]
101-
end_position_labels = torch.zeros(batch_max_length)
101+
end_position_labels = torch.zeros(batch_max_length, dtype=torch.long)
102102

103103
for end_position in end_positions:
104104

105105
if end_position < batch_max_length - 1:
106106
end_position_labels[end_position] = 1
107107

108-
batch_end_position_labels.append(torch.tensor(end_position_labels, dtype=torch.long))
108+
batch_end_position_labels.append(end_position_labels)
109109

110110
# match position
111-
match_positions = torch.zeros(size=(batch_max_length, batch_max_length))
111+
match_positions = torch.zeros(size=(batch_max_length, batch_max_length), dtype=torch.long)
112112

113113
for start_position, end_position in zip(start_positions, end_positions):
114114

mrc/label_decoder/mrc_label_index_decoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __call__(self,
2828
match_logits: torch.Tensor,
2929
mask: torch.BoolTensor) -> torch.LongTensor:
3030

31+
mask = mask.bool()
3132
batch_size, seq_len = start_logits.size()
3233

3334
# match label pred, [batch_size, seq_len, seq_len]

mrc/loss/mrc_bce_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ def __call__(self, model_outputs: MRCNerOutput, golden_label: Dict[str, torch.Te
4040

4141
batch_size, sequence_length = model_outputs.start_logits.size()
4242

43-
start_loss = self.loss(model_outputs.start_logits, golden_label["start_position_labels"])
43+
start_loss = self.loss(model_outputs.start_logits, golden_label["start_position_labels"].float())
4444
# 计算得到 mean
4545
start_loss = (start_loss * mask).sum() / mask.sum()
4646

47-
end_loss = self.loss(model_outputs.end_logits, golden_label["end_position_labels"])
47+
end_loss = self.loss(model_outputs.end_logits, golden_label["end_position_labels"].float())
4848
end_loss = (end_loss * mask).sum() / mask.sum()
4949

5050
match_loss = self.loss(model_outputs.match_logits.view(batch_size, -1),
51-
golden_label["match_position_labels"].view(batch_size, -1))
51+
golden_label["match_position_labels"].float().view(batch_size, -1))
5252

5353
match_label_row_mask = mask.bool().unsqueeze(-1).expand(-1, -1, sequence_length)
5454
match_label_col_mask = mask.bool().unsqueeze(-2).expand(-1, sequence_length, -1)

mrc/metric/mrc_f1_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __call__(self, prediction_match_labels: torch.Tensor, gold_match_labels: tor
4343
:param mask: mask
4444
:return: metric dict
4545
"""
46-
46+
mask = mask.bool()
4747
batch_size, seq_length = mask.size()
4848

4949
match_label_mask = (mask.unsqueeze(-1).expand(-1, -1, seq_length)

mrc/metric/mrc_metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,18 @@ def __init__(self):
3636
self.model_label_decoder = MRCModelLabelDecoder()
3737
self.mrc_f1_metric = MRCF1Metric(labels=list())
3838

39-
def __call__(self, model_outputs: MRCNerOutput, golden_label_dict: Dict[str, Tensor]) -> Tuple[Dict, ModelTargetMetric]:
39+
def __call__(self, model_outputs: MRCNerOutput, golden_labels: Dict[str, Tensor]) -> Tuple[Dict, ModelTargetMetric]:
4040
"""
4141
计算 metric
4242
:param model_outputs:
43-
:param golden_label_dict: start_position_labels, end_position_labels, batch_match_positions
43+
:param golden_labels: start_position_labels, end_position_labels, batch_match_positions
4444
:return:
4545
"""
4646
model_outputs: MRCNerOutput = model_outputs
4747

4848
match_prediction_labels = self.model_label_decoder.decode_label_index(model_outputs=model_outputs)
4949

50-
match_golden_labels = golden_label_dict["match_position_labels"]
50+
match_golden_labels = golden_labels["match_position_labels"]
5151

5252
# 计算 overall f1
5353
mask = model_outputs.mask.detach()

mrc/optimizer/mrc_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ def create(self, model: Model) -> "Optimizer":
2929
no_decay = ["bias", "LayerNorm.weight"]
3030
parameters = [
3131
{
32-
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
32+
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
3333
"weight_decay": self.weight_decay,
3434
},
3535
{
36-
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
36+
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
3737
"weight_decay": 0.0,
3838
}
3939
]

0 commit comments

Comments
 (0)