Skip to content

Commit

Permalink
#52: 完成 metric 测试
Browse files Browse the repository at this point in the history
  • Loading branch information
cjopengler committed Nov 9, 2021
1 parent 4f4e484 commit 68f374e
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 17 deletions.
2 changes: 1 addition & 1 deletion mrc/label_decoder/mrc_label_index_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __call__(self,
match_logits: torch.Tensor,
mask: torch.BoolTensor) -> torch.LongTensor:

batch_size, seq_len, _ = start_logits.size()
batch_size, seq_len = start_logits.size()

# match label pred, [batch_size, seq_len, seq_len]
match_preds = match_logits > 0
Expand Down
17 changes: 10 additions & 7 deletions mrc/metric/mrc_f1_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,16 @@ def __call__(self, prediction_match_labels: torch.Tensor, gold_match_labels: tor
prediction_match_labels = prediction_match_labels & match_label_mask
gold_match_labels = gold_match_labels & match_label_mask

true_positives = (gold_match_labels & prediction_match_labels).long().sum()
false_positives = (~gold_match_labels & prediction_match_labels).long().sum()
false_negatives = (gold_match_labels & ~prediction_match_labels).long().sum()

self._true_positives[MRCF1Metric.All] += true_positives
self._false_positives[MRCF1Metric.All] += false_positives
self._false_negatives[MRCF1Metric.All] += false_negatives
true_positive_value = (gold_match_labels & prediction_match_labels).long().sum()
true_positives = {MRCF1Metric.All: true_positive_value}
false_positive_value = (~gold_match_labels & prediction_match_labels).long().sum()
false_positives = {MRCF1Metric.All: false_positive_value}
false_negative_value = (gold_match_labels & ~prediction_match_labels).long().sum()
false_negatives = {MRCF1Metric.All: false_negative_value}

self._true_positives[MRCF1Metric.All] += true_positive_value
self._false_positives[MRCF1Metric.All] += false_positive_value
self._false_negatives[MRCF1Metric.All] += false_negative_value

return self._metric(true_positives=true_positives,
false_positives=false_positives,
Expand Down
6 changes: 3 additions & 3 deletions mrc/metric/mrc_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MrcModelMetricAdapter(ModelMetricAdapter):

def __init__(self):
self.model_label_decoder = MRCModelLabelDecoder()
self.mrc_f1_metric = MRCF1Metric()
self.mrc_f1_metric = MRCF1Metric(labels=list())

def __call__(self, model_outputs: MRCNerOutput, golden_label_dict: Dict[str, Tensor]) -> Tuple[Dict, ModelTargetMetric]:
"""
Expand All @@ -47,13 +47,13 @@ def __call__(self, model_outputs: MRCNerOutput, golden_label_dict: Dict[str, Ten

match_prediction_labels = self.model_label_decoder.decode_label_index(model_outputs=model_outputs)

match_golend_labels = golden_label_dict["match_position_labels"]
match_golden_labels = golden_label_dict["match_position_labels"]

# 计算 overall f1
mask = model_outputs.mask.detach()

metric_dict = self.mrc_f1_metric(prediction_match_labels=match_prediction_labels,
gold_match_labels=match_golend_labels,
gold_match_labels=match_golden_labels,
mask=mask)

target_metric = ModelTargetMetric(metric_name=MRCF1Metric.F1_OVERALL,
Expand Down
17 changes: 13 additions & 4 deletions mrc/models/mrc_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,26 +92,35 @@ def forward(self,
sequence_length = sequence_output.size(1)

start_logits = self.start_classifier(sequence_output)
# 最后一个维度去掉
# 最后一个维度去掉 (B, seq_len)
start_logits = start_logits.squeeze(-1)

assert len(start_logits.size()) == 2

end_logits = self.end_classifier(sequence_output)

# (B, seq_len)
end_logits = end_logits.squeeze(-1)

assert len(end_logits.size()) == 2

# 将每一个 i 与 j 连接在一起, 所以是 N*N的拼接,使用了 expand, 进行 两个方向的扩展
# 产生一个 match matrix
# 对于每一个 i 都与 j concat 在一起
# [batch, seq_len, seq_len, hidden]
# [B, seq_len, seq_len, hidden]
start_extend = sequence_output.unsqueeze(2).expand(-1, -1, sequence_length, -1)

# [batch, seq_len, seq_len, hidden]
# [B, seq_len, seq_len, hidden]
end_extend = sequence_output.unsqueeze(1).expand(-1, sequence_length, -1, -1)

# [batch, seq_len, seq_len, hidden*2]
# [B, seq_len, seq_len, hidden*2]
match_matrix = torch.cat([start_extend, end_extend], 3)

# (B, seq_len, seq_len)
match_logits = self.match_classifier(match_matrix).squeeze(-1)

assert len(match_logits.size()) == 3

return MRCNerOutput(start_logits=start_logits,
end_logits=end_logits,
match_logits=match_logits,
Expand Down
12 changes: 12 additions & 0 deletions mrc/tests/metric/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env python 3
# -*- coding: utf-8 -*-

#
# Copyright (c) 2021 PanXu, Inc. All Rights Reserved
#
"""
brief
Authors: PanXu
Date: 2021/11/09 08:17:00
"""
111 changes: 111 additions & 0 deletions mrc/tests/metric/test_mrc_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#!/usr/bin/env python 3
# -*- coding: utf-8 -*-

#
# Copyright (c) 2021 PanXu, Inc. All Rights Reserved
#
"""
测试 mrc metric
Authors: PanXu
Date: 2021/11/09 08:26:00
"""
import logging
import torch

from easytext.utils.json_util import json2str

from mrc.models import MRCNerOutput
from mrc.metric import MrcModelMetricAdapter
from mrc.metric import MRCF1Metric

from mrc.tests import ASSERT


def test_mrc_metric():
start_logits = torch.tensor([[1, 1, -1, 1, -1, 1],
[1, -1, -1, -1, 1, -1]])

end_logits = torch.tensor([[1, 1, 1, 1, -1, 1],
[1, 1, 1, -1, 1, 1]])

# (1, 1), (1, 2), (1, 3), (1, 5)
# (3, 3), (3, 5)
# (5, 5)
################################
# (4, 4), (4, 5)
match_logits = torch.tensor([
[
[1, 1, 1, 1, 1, -1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, -1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]
],
[
[1, -1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, -1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]
]
])

mask = torch.tensor([
[False, True, True, True, True, True],
[False, True, True, True, True, True]
])

model_outputs = MRCNerOutput(start_logits=start_logits,
end_logits=end_logits,
match_logits=match_logits,
mask=mask)

# (1, 1), (1, 2), (1, 3), (1, 5)
# (2, 2), (2, 5)
# (3, 3), (3, 4), (3, 5)
# (4, 4), (4, 5)
# (5, 5)
###################
# (1, 1), (1, 3), (1, 4)
# (2, 2), (2, 3,) (2, 4) (2, 5)
# (3, 3), (3, 5)
# (4, 4), (4, 5)
# (5, 5)
golden_match_logits = torch.tensor([
[
[1, 0, 1, 0, 1, 0],
[1, 1, 1, 1, 0, 1],
[1, 1, 1, 0, 0, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]
],
[
[1, 1, 1, 0, 1, 1],
[1, 1, 0, 1, 1, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]
]
])


golden_label_dict = {"match_position_labels": golden_match_logits}
mrc_metric = MrcModelMetricAdapter()

metric_dict, target_metric = mrc_metric(model_outputs=model_outputs, golden_label_dict=golden_label_dict)

logging.info(f"metric dict: {json2str(metric_dict)}\ntarget metric: {json2str(target_metric)}")

expect_precision = 9/9
expect_recall = 9/24
ASSERT.assertAlmostEqual(expect_precision, metric_dict[MRCF1Metric.PRECISION_OVERALL])
ASSERT.assertAlmostEqual(expect_recall, metric_dict[MRCF1Metric.RECALL_OVERALL])

ASSERT.assertEqual(MRCF1Metric.F1_OVERALL, target_metric.name)
ASSERT.assertAlmostEqual(metric_dict[MRCF1Metric.F1_OVERALL], target_metric.value)

4 changes: 2 additions & 2 deletions mrc/tests/paper/bert_query_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None):
batch_size, seq_len, hid_size = sequence_heatmap.size()

# 得到 start logits
start_logits = self.start_outputs(sequence_heatmap).squeeze(-1) # [batch, seq_len, 1]
start_logits = self.start_outputs(sequence_heatmap).squeeze(-1) # [batch, seq_len]

# 得到 end logits
end_logits = self.end_outputs(sequence_heatmap).squeeze(-1) # [batch, seq_len, 1]
end_logits = self.end_outputs(sequence_heatmap).squeeze(-1) # [batch, seq_len]

# 将每一个 i 与 j 连接在一起, 所以是 N*N的拼接,使用了 expand, 进行 两个方向的扩展
# start 和 end 的 logits 有必要存在吗???? 还是说,增加了一些约束?
Expand Down

0 comments on commit 68f374e

Please sign in to comment.