Commit
#52: Fix the offset issue caused by the bert tokenizer
cjopengler committed Nov 22, 2021
1 parent 5e9897b commit a006c77
Showing 7 changed files with 75 additions and 39 deletions.
11 changes: 1 addition & 10 deletions data/dataset/mrc_msra_ner/sample.json
@@ -1,14 +1,5 @@
[
{
"context": "藏 书 本 来 就 是 所 有 传 统 收 藏 门 类 中 的 第 一 大 户 , 只 是 我 们 结 束 温 饱 的 时 间 太 短 而 已 。",
"end_position": [],
"entity_label": "NT",
"impossible": true,
"qas_id": "1.3",
"query": "组织包括公司,政府党派,学校,政府,新闻机构",
"span_position": [],
"start_position": []
},

{
"context": "因 有 关 日 寇 在 京 掠 夺 文 物 详 情 , 藏 界 较 为 重 视 , 也 是 我 们 收 藏 北 京 史 料 中 的 要 件 之 一 。",
"end_position": [
6 changes: 3 additions & 3 deletions data/mrc_ner/config/config.json
@@ -2,13 +2,13 @@
"training_dataset": {
"__type__": "MSRAFlatNerDataset",
"__name_space__": "mrc_ner",
"dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/sample.json"
"dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/mrc-ner.train"
},

"validation_dataset": {
"__type__": "MSRAFlatNerDataset",
"__name_space__": "mrc_ner",
"dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/sample.json"
"dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/mrc-ner.dev"
},

"model_collate": {
@@ -84,6 +84,6 @@

"devices": ["cpu"],
"serialize_dir": "/Users/panxu/MyProjects/github/easytext/data/mrc_ner/serialize",
"train_batch_size": 4,
"train_batch_size": 1,
"test_batch_size": 8
}
18 changes: 15 additions & 3 deletions easytext/utils/bert_tokenizer.py
@@ -10,13 +10,25 @@
Authors: PanXu
Date: 2020/11/03 18:11:00
"""

from transformers import BertTokenizer
from typing import List, Tuple
from transformers import BertTokenizerFast

from easytext.component.register import ComponentRegister
from easytext.component.component_builtin_key import ComponentBuiltinKey


@ComponentRegister.register(typename="BertTokenizer", name_space=ComponentBuiltinKey.EASYTEXT_NAME_SPACE)
def bert_tokenizer(bert_dir: str):
return BertTokenizer.from_pretrained(bert_dir)
return BertTokenizerFast.from_pretrained(bert_dir)


def mapping_label(token_offset_mapping: List[Tuple[int, int]], labels: List) -> List:
"""
The bert tokenizer transforms token indices, so the transformed indices need to be matched back up with the labels.
pre-transform index = label index -> post-transform index -> post-transform label index
:param token_offset_mapping: the offset_mapping returned by the bert tokenizer, [(begin, end), ...]; its list index is
the pre-transform token index
:param labels: the list of labels
:return:
"""

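The hunk ends at the docstring, so the body of mapping_label is not shown here. As a rough illustration of the contract it describes, and of why the switch to BertTokenizerFast matters (only fast tokenizers return offset_mapping), here is a minimal sketch; the function body and the label scheme are assumptions, not the committed code:

```python
from typing import List, Tuple

from transformers import BertTokenizerFast


def mapping_label_sketch(token_offset_mapping: List[Tuple[int, int]], labels: List) -> List:
    """One plausible reading of mapping_label: align char-level labels to tokens."""
    mapped_labels = list()
    for token_start, token_end in token_offset_mapping:
        # (0, 0) marks special tokens such as [CLS] and [SEP]
        if token_start == token_end == 0:
            mapped_labels.append(None)
            continue
        # the token covers original characters [token_start, token_end),
        # so it inherits the label of its first character
        mapped_labels.append(labels[token_start])
    return mapped_labels


tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
encoding = tokenizer("藏书本来", return_offsets_mapping=True)
# encoding["offset_mapping"] == [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (0, 0)]
print(mapping_label_sketch(encoding["offset_mapping"], ["O", "O", "B-NT", "I-NT"]))
# -> [None, 'O', 'O', 'B-NT', 'I-NT', None]
```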
73 changes: 53 additions & 20 deletions mrc/data/bert_model_collate.py
@@ -39,24 +39,25 @@ def __init__(self, tokenizer: BertTokenizer, max_length: int = 128):
def __call__(self, instances: List[Instance]) -> MRCModelInputs:

batch_size = len(instances)
# compute the maximum length of the current batch
# 3 accounts for the three special tokens: CLS, SEP, SEP
batch_max_length = max(len(instance["context"] + instance["query"]) + 3 for instance in instances)

batch_max_length = min(batch_max_length, self._max_length)

batch_text_pairs = [[instance["query"], instance["context"]] for instance in instances]
batch_text_pairs = [(instance["query"], instance["context"]) for instance in instances]

batch_inputs = self._tokenizer.batch_encode_plus(batch_text_or_text_pairs=batch_text_pairs,
truncation=True,
padding=True,
max_length=batch_max_length,
max_length=self._max_length,
return_length=True,
add_special_tokens=True,
return_special_tokens_mask=True,
return_offsets_mapping=True,
# return_token_type_ids=True,
return_tensors="pt")

batch_token_ids = batch_inputs["input_ids"]
batch_token_type_ids = batch_inputs["token_type_ids"]
batch_max_len = max(batch_inputs["length"])

batch_offset_mapping = batch_inputs["offset_mapping"]
batch_special_tokens_mask = batch_inputs["special_tokens_mask"]

# flipping special_tokens_mask (0 -> 1, 1 -> 0) yields the sequence mask with CLS and SEP removed
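# e.g. special_tokens_mask  = [1, 0, 0, 1, 0, 0, 0, 1]  for [CLS] q q [SEP] c c c [SEP]
# 1 - special_tokens_mask   = [0, 1, 1, 0, 1, 1, 1, 0]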
@@ -74,9 +75,11 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
batch_metadata = list()

# remap the start/end positions to account for tokenizer offsets
for instance in instances:
for instance, token_ids, token_type_ids, offset_mapping in zip(instances,
batch_token_ids,
batch_token_type_ids,
batch_offset_mapping):

query_offset = 1 + len(instance["query"]) + 1 # CLS + query + SEP
start_positions = instance.get("start_positions", None)
end_positions = instance.get("end_positions", None)

@@ -86,33 +89,63 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
batch_metadata.append(metadata)

if start_positions is not None and end_positions is not None:

# because indices in the offsets follow the [start, end) convention
end_positions = [end_pos + 1 for end_pos in end_positions]
instance["end_positions"] = end_positions

# query and context are concatenated, so start_position and end_position must be remapped
origin_offset2token_idx_start = {}
origin_offset2token_idx_end = {}

for token_idx in range(len(token_ids)):
# tokens belonging to the query need to be filtered out
if token_type_ids[token_idx] == 0:
continue

# get each token's start and end offsets
token_start, token_end = offset_mapping[token_idx]
token_start = token_start.item()
token_end = token_end.item()

# skip [CLS] or [SEP]; (0, 0) in the offsets denotes CLS or SEP
if token_start == token_end == 0:
continue

# token_start is the actual position in the context, i.e. what start_position and end_position refer to
# token_idx is the index after query and context are concatenated, so it is the remapped start_position
origin_offset2token_idx_start[token_start] = token_idx
origin_offset2token_idx_end[token_end] = token_idx

# map the start_positions from the raw data to positions after the query/context concatenation
new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions]
new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions]

metadata["positions"] = zip(start_positions, end_positions)

start_positions = [(query_offset + start_position) for start_position in start_positions]
start_position_labels = torch.zeros(batch_max_length, dtype=torch.long)
start_position_labels = torch.zeros(batch_max_len, dtype=torch.long)

for start_position in start_positions:
if start_position < batch_max_length - 1:
for start_position in new_start_positions:
if start_position < batch_max_len - 1:
start_position_labels[start_position] = 1

batch_start_position_labels.append(start_position_labels)

end_positions = [(query_offset + end_position) for end_position in end_positions]
end_position_labels = torch.zeros(batch_max_length, dtype=torch.long)
end_position_labels = torch.zeros(batch_max_len, dtype=torch.long)

for end_position in end_positions:
for end_position in new_end_positions:

if end_position < batch_max_length - 1:
if end_position < batch_max_len - 1:
end_position_labels[end_position] = 1

batch_end_position_labels.append(end_position_labels)

# match position
match_positions = torch.zeros(size=(batch_max_length, batch_max_length), dtype=torch.long)
match_positions = torch.zeros(size=(batch_max_len, batch_max_len), dtype=torch.long)

for start_position, end_position in zip(start_positions, end_positions):
for start_position, end_position in zip(new_start_positions, new_end_positions):

if start_position < batch_max_length - 1 and end_position < batch_max_length - 1:
if start_position < batch_max_len - 1 and end_position < batch_max_len - 1:
match_positions[start_position, end_position] = 1

batch_match_positions.append(match_positions)
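To see the remapping end to end, here is a standalone sketch of the same offset logic outside the collate class. It assumes character-indexed, inclusive start/end positions in the raw data; the query, context, and entity span are made up for illustration:

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")

query = "组织包括公司"        # hypothetical query
context = "北京大学在北京"    # hypothetical context
start_positions, end_positions = [0], [3]   # entity "北京大学", end inclusive

inputs = tokenizer(query, context,
                   return_offsets_mapping=True,
                   return_tensors="pt")
token_type_ids = inputs["token_type_ids"][0]
offset_mapping = inputs["offset_mapping"][0]

# offsets index as [start, end), so make the inclusive end exclusive first
end_positions = [end + 1 for end in end_positions]

origin_offset2token_idx_start = {}
origin_offset2token_idx_end = {}
for token_idx in range(len(token_type_ids)):
    # token_type_id 0 covers [CLS], the query, and the first [SEP]
    if token_type_ids[token_idx] == 0:
        continue
    token_start, token_end = offset_mapping[token_idx].tolist()
    # (0, 0) is the trailing [SEP]
    if token_start == token_end == 0:
        continue
    origin_offset2token_idx_start[token_start] = token_idx
    origin_offset2token_idx_end[token_end] = token_idx

new_start_positions = [origin_offset2token_idx_start[s] for s in start_positions]
new_end_positions = [origin_offset2token_idx_end[e] for e in end_positions]
# the new positions index into the joint [CLS] query [SEP] context [SEP] sequence
```

For character-tokenized Chinese the mapping is close to a constant shift by the query length, but it diverges as soon as the tokenizer merges or drops characters, which is the kind of offset drift this commit guards against.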
2 changes: 2 additions & 0 deletions mrc/launcher.py
@@ -29,6 +29,7 @@
from easytext.data import Vocabulary, LabelVocabulary, PretrainedVocabulary
from easytext.data import GloveLoader, SGNSLoader
from easytext.utils import log_util
from easytext.utils.seed_util import set_seed
from easytext.trainer import Launcher, Config
from easytext.distributed import ProcessGroupParameter
from easytext.utils.json_util import json2str
@@ -129,6 +130,7 @@ def _start(self, rank: Optional[int], world_size: int, device: torch.device) ->


if __name__ == '__main__':
set_seed()
log_util.config(level=logging.INFO)

parser = ArgumentParser()
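set_seed from easytext.utils.seed_util is not part of this diff; a typical implementation looks like the sketch below (an assumption about its behavior, not necessarily easytext's actual code):

```python
import os
import random

import numpy as np
import torch


def set_seed(seed: int = 42):
    """Seed all RNGs that affect training so that runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
```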
2 changes: 1 addition & 1 deletion mrc/tests/metric/test_mrc_metric.py
@@ -97,7 +97,7 @@ def test_mrc_metric():
golden_label_dict = {"match_position_labels": golden_match_logits}
mrc_metric = MrcModelMetricAdapter()

metric_dict, target_metric = mrc_metric(model_outputs=model_outputs, golden_label_dict=golden_label_dict)
metric_dict, target_metric = mrc_metric(model_outputs=model_outputs, golden_labels=golden_label_dict)

logging.info(f"metric dict: {json2str(metric_dict)}\ntarget metric: {json2str(target_metric)}")

2 changes: 0 additions & 2 deletions mrc/tests/models/test_mrc_ner.py
@@ -106,8 +106,6 @@ def test_mrc_ner(mrc_msra_ner_dataset, paper_mrc_msra_ner_dataset):

ASSERT.assertTrue(is_tensor_equal(paper_match_labels, match_labels, epsilon=0))



logging.info(f"begin mrc ner")
set_seed()

