Skip to content

Commit

Permalink
#52: 修复start end positon 越界问题
Browse files Browse the repository at this point in the history
  • Loading branch information
cjopengler committed Nov 22, 2021
1 parent a006c77 commit 2ddd82a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
23 changes: 10 additions & 13 deletions data/dataset/mrc_msra_ner/sample.json
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
[

{
"context": "因 有 关 日 寇 在 京 掠 夺 文 物 详 情 , 藏 界 较 为 重 视 , 也 是 我 们 收 藏 北 京 史 料 中 的 要 件 之 一",
"context": "根 据 选 举 新 闻 中 心 公 布 的 初 步 统 计 结 果 , 至 夜 晚 2 2 时 3 0 分 投 票 结 束 , 全 港 共 有 1 4 8 9 7 0 5 地 区 直 选 选 民 投 票 , 投 票 率 达 5 3 . 2 9 % , 比 1 9 9 5 年 分 别 增 加 近 5 7 万 人 和 1 7 个 多 百 分 点 , 为 香 港 历 史 最 高 纪 录 ; 共 有 7 7 8 1 3 功 能 界 别 选 举 选 民 投 票 , 投 票 率 达 6 3 . 5 % ; 选 举 委 员 会 选 举 投 票 选 民 达 到 7 9 0 人 , 投 票 率 高 达 9 8 . 7 5 %",
"end_position": [
3,
6,
28
7,
129
],
"entity_label": "NS",
"entity_label": "NT",
"impossible": false,
"qas_id": "2.1",
"query": "按照地理位置划分的国家,城市,乡镇,大洲",
"qas_id": "13563.3",
"query": "组织包括公司,政府党派,学校,政府,新闻机构",
"span_position": [
"3;3",
"6;6",
"27;28"
"2;7",
"125;129"
],
"start_position": [
3,
6,
27
2,
125
]
}
]
23 changes: 21 additions & 2 deletions mrc/data/bert_model_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
return_tensors="pt")

batch_token_ids = batch_inputs["input_ids"]

batch_tokens = list()
for batch_token_id in batch_token_ids.tolist():
tmp_tokens = list()
batch_tokens.append(tmp_tokens)
for token_id in batch_token_id:
token = self._tokenizer.decode(token_id)
tmp_tokens.append(token)


batch_token_type_ids = batch_inputs["token_type_ids"]
batch_max_len = max(batch_inputs["length"])

Expand Down Expand Up @@ -98,6 +108,9 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
origin_offset2token_idx_start = {}
origin_offset2token_idx_end = {}

last_token_start = 0
last_token_end = 0

for token_idx in range(len(token_ids)):
# query 的需要过滤
if token_type_ids[token_idx] == 0:
Expand All @@ -112,14 +125,20 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
if token_start == token_end == 0:
continue

# 保存下最后的 start 和 end
last_token_start = token_start
last_token_end = token_end

# token_start 对应的就是 context 中的实际位置,与 start_position 与 end_position 是对应的
# token_idx 是 query 和 context 拼接在一起后的 index,所以 这就是 start_position 映射后的位置
origin_offset2token_idx_start[token_start] = token_idx
origin_offset2token_idx_end[token_end] = token_idx

# 将原始数据中的 start_positions 映射到 拼接 query context 之后的位置
new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions]
new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions]
new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions
if start <= last_token_start]
new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions
if end <= last_token_end]

metadata["positions"] = zip(start_positions, end_positions)

Expand Down

0 comments on commit 2ddd82a

Please sign in to comment.