diff --git a/data/dataset/mrc_msra_ner/sample.json b/data/dataset/mrc_msra_ner/sample.json index 0cb24fd..917da6c 100644 --- a/data/dataset/mrc_msra_ner/sample.json +++ b/data/dataset/mrc_msra_ner/sample.json @@ -1,14 +1,5 @@ [ - { - "context": "藏 书 本 来 就 是 所 有 传 统 收 藏 门 类 中 的 第 一 大 户 , 只 是 我 们 结 束 温 饱 的 时 间 太 短 而 已 。", - "end_position": [], - "entity_label": "NT", - "impossible": true, - "qas_id": "1.3", - "query": "组织包括公司,政府党派,学校,政府,新闻机构", - "span_position": [], - "start_position": [] - }, + { "context": "因 有 关 日 寇 在 京 掠 夺 文 物 详 情 , 藏 界 较 为 重 视 , 也 是 我 们 收 藏 北 京 史 料 中 的 要 件 之 一 。", "end_position": [ diff --git a/data/mrc_ner/config/config.json b/data/mrc_ner/config/config.json index a6b5af0..4d16361 100644 --- a/data/mrc_ner/config/config.json +++ b/data/mrc_ner/config/config.json @@ -2,13 +2,13 @@ "training_dataset": { "__type__": "MSRAFlatNerDataset", "__name_space__": "mrc_ner", - "dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/sample.json" + "dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/mrc-ner.train" }, "validation_dataset": { "__type__": "MSRAFlatNerDataset", "__name_space__": "mrc_ner", - "dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/sample.json" + "dataset_file_path": "/Users/panxu/MyProjects/github/easytext/data/dataset/mrc_msra_ner/mrc-ner.dev" }, "model_collate": { @@ -84,6 +84,6 @@ "devices": ["cpu"], "serialize_dir": "/Users/panxu/MyProjects/github/easytext/data/mrc_ner/serialize", - "train_batch_size": 4, + "train_batch_size": 1, "test_batch_size": 8 } diff --git a/easytext/utils/bert_tokenizer.py b/easytext/utils/bert_tokenizer.py index b21a3ab..75604d0 100644 --- a/easytext/utils/bert_tokenizer.py +++ b/easytext/utils/bert_tokenizer.py @@ -10,8 +10,8 @@ Authors: PanXu Date: 2020/11/03 18:11:00 """ - -from transformers import BertTokenizer +from typing import List, Tuple +from transformers import BertTokenizerFast from easytext.component.register import ComponentRegister from easytext.component.component_builtin_key import ComponentBuiltinKey @@ -19,4 +19,16 @@ @ComponentRegister.register(typename="BertTokenizer", name_space=ComponentBuiltinKey.EASYTEXT_NAME_SPACE) def bert_tokenizer(bert_dir: str): - return BertTokenizer.from_pretrained(bert_dir) + return BertTokenizerFast.from_pretrained(bert_dir) + + +def mapping_label(token_offset_mapping: List[Tuple[int, int]], labels: List) -> List: + """ + bert tokenizer 会对 index 进行转化, 所以需要将其转换后的 index 与 label 对应起来。 + 转换前 index = label index -> 转换后 index -> 转换后的 label index + :param token_offset_mapping: bert tokenizer 返回的 offset_mapping, [(being, end), ...], 其 list index 就是 + 转换前的 token index + :param labels: 标签列表 + :return: + """ + diff --git a/mrc/data/bert_model_collate.py b/mrc/data/bert_model_collate.py index fe231a9..76b0b30 100644 --- a/mrc/data/bert_model_collate.py +++ b/mrc/data/bert_model_collate.py @@ -39,24 +39,25 @@ def __init__(self, tokenizer: BertTokenizer, max_length: int = 128): def __call__(self, instances: List[Instance]) -> MRCModelInputs: batch_size = len(instances) - # 获取当前 batch 最大长度 - # 3 表示: CLS, SEP, SEP 3个 special token - batch_max_length = max(len(instance["context"] + instance["query"]) + 3 for instance in instances) - batch_max_length = min(batch_max_length, self._max_length) - - batch_text_pairs = [[instance["query"], instance["context"]] for instance in instances] + batch_text_pairs = [(instance["query"], instance["context"]) for instance in instances] batch_inputs = self._tokenizer.batch_encode_plus(batch_text_or_text_pairs=batch_text_pairs, truncation=True, padding=True, - max_length=batch_max_length, + max_length=self._max_length, return_length=True, add_special_tokens=True, return_special_tokens_mask=True, + return_offsets_mapping=True, # return_token_type_ids=True, return_tensors="pt") + batch_token_ids = batch_inputs["input_ids"] + batch_token_type_ids = batch_inputs["token_type_ids"] + batch_max_len = max(batch_inputs["length"]) + + batch_offset_mapping = batch_inputs["offset_mapping"] batch_special_tokens_mask = batch_inputs["special_tokens_mask"] # 将special_tokens_mask 0->1, 1->0, 就变成了 sequence 去掉 CLS 和 SEP 的 mask 了 @@ -74,9 +75,11 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs: batch_metadata = list() # start, end position 处理偏移 - for instance in instances: + for instance, token_ids, token_type_ids, offset_mapping in zip(instances, + batch_token_ids, + batch_token_type_ids, + batch_offset_mapping): - query_offset = 1 + len(instance["query"]) + 1 # CLS + query + SEP start_positions = instance.get("start_positions", None) end_positions = instance.get("end_positions", None) @@ -86,33 +89,63 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs: batch_metadata.append(metadata) if start_positions is not None and end_positions is not None: + + # 是因为在 offset 中, 对于 index 的设置,就是 [start, end) + end_positions = [end_pos + 1 for end_pos in end_positions] + instance["end_positions"] = end_positions + + # 因为 query 和 context 拼接在一起了,所以 start_position 和 end_position 的位置要重新映射 + origin_offset2token_idx_start = {} + origin_offset2token_idx_end = {} + + for token_idx in range(len(token_ids)): + # query 的需要过滤 + if token_type_ids[token_idx] == 0: + continue + + # 获取每一个 token_start 和 end + token_start, token_end = offset_mapping[token_idx] + token_start = token_start.item() + token_end = token_end.item() + + # skip [CLS] or [SEP], offset 中 (0, 0) 表示的就是 CLS 或者 SEP + if token_start == token_end == 0: + continue + + # token_start 对应的就是 context 中的实际位置,与 start_position 与 end_position 是对应的 + # token_idx 是 query 和 context 拼接在一起后的 index,所以 这就是 start_position 映射后的位置 + origin_offset2token_idx_start[token_start] = token_idx + origin_offset2token_idx_end[token_end] = token_idx + + # 将原始数据中的 start_positions 映射到 拼接 query context 之后的位置 + new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions] + new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions] + metadata["positions"] = zip(start_positions, end_positions) - start_positions = [(query_offset + start_position) for start_position in start_positions] - start_position_labels = torch.zeros(batch_max_length, dtype=torch.long) + start_position_labels = torch.zeros(batch_max_len, dtype=torch.long) - for start_position in start_positions: - if start_position < batch_max_length - 1: + for start_position in new_start_positions: + if start_position < batch_max_len - 1: start_position_labels[start_position] = 1 batch_start_position_labels.append(start_position_labels) - end_positions = [(query_offset + end_position) for end_position in end_positions] - end_position_labels = torch.zeros(batch_max_length, dtype=torch.long) + end_position_labels = torch.zeros(batch_max_len, dtype=torch.long) - for end_position in end_positions: + for end_position in new_end_positions: - if end_position < batch_max_length - 1: + if end_position < batch_max_len - 1: end_position_labels[end_position] = 1 batch_end_position_labels.append(end_position_labels) # match position - match_positions = torch.zeros(size=(batch_max_length, batch_max_length), dtype=torch.long) + match_positions = torch.zeros(size=(batch_max_len, batch_max_len), dtype=torch.long) - for start_position, end_position in zip(start_positions, end_positions): + for start_position, end_position in zip(new_start_positions, new_end_positions): - if start_position < batch_max_length - 1 and end_position < batch_max_length - 1: + if start_position < batch_max_len - 1 and end_position < batch_max_len - 1: match_positions[start_position, end_position] = 1 batch_match_positions.append(match_positions) diff --git a/mrc/launcher.py b/mrc/launcher.py index a461337..4970f2a 100644 --- a/mrc/launcher.py +++ b/mrc/launcher.py @@ -29,6 +29,7 @@ from easytext.data import Vocabulary, LabelVocabulary, PretrainedVocabulary from easytext.data import GloveLoader, SGNSLoader from easytext.utils import log_util +from easytext.utils.seed_util import set_seed from easytext.trainer import Launcher, Config from easytext.distributed import ProcessGroupParameter from easytext.utils.json_util import json2str @@ -129,6 +130,7 @@ def _start(self, rank: Optional[int], world_size: int, device: torch.device) -> if __name__ == '__main__': + set_seed() log_util.config(level=logging.INFO) parser = ArgumentParser() diff --git a/mrc/tests/metric/test_mrc_metric.py b/mrc/tests/metric/test_mrc_metric.py index 818ff13..a05508b 100644 --- a/mrc/tests/metric/test_mrc_metric.py +++ b/mrc/tests/metric/test_mrc_metric.py @@ -97,7 +97,7 @@ def test_mrc_metric(): golden_label_dict = {"match_position_labels": golden_match_logits} mrc_metric = MrcModelMetricAdapter() - metric_dict, target_metric = mrc_metric(model_outputs=model_outputs, golden_label_dict=golden_label_dict) + metric_dict, target_metric = mrc_metric(model_outputs=model_outputs, golden_labels=golden_label_dict) logging.info(f"metric dict: {json2str(metric_dict)}\ntarget metric: {json2str(target_metric)}") diff --git a/mrc/tests/models/test_mrc_ner.py b/mrc/tests/models/test_mrc_ner.py index f2a1132..73527b4 100644 --- a/mrc/tests/models/test_mrc_ner.py +++ b/mrc/tests/models/test_mrc_ner.py @@ -106,8 +106,6 @@ def test_mrc_ner(mrc_msra_ner_dataset, paper_mrc_msra_ner_dataset): ASSERT.assertTrue(is_tensor_equal(paper_match_labels, match_labels, epsilon=0)) - - logging.info(f"begin mrc ner") set_seed()