Skip to content

Commit 2ddd82a

Browse files
committed
#52: 修复start end positon 越界问题
1 parent a006c77 commit 2ddd82a

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

data/dataset/mrc_msra_ner/sample.json

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
11
[
22

33
{
4-
"context": "因 有 关 日 寇 在 京 掠 夺 文 物 详 情 , 藏 界 较 为 重 视 , 也 是 我 们 收 藏 北 京 史 料 中 的 要 件 之 一",
4+
"context": "根 据 选 举 新 闻 中 心 公 布 的 初 步 统 计 结 果 , 至 夜 晚 2 2 时 3 0 分 投 票 结 束 , 全 港 共 有 1 4 8 9 7 0 5 地 区 直 选 选 民 投 票 , 投 票 率 达 5 3 . 2 9 % , 比 1 9 9 5 年 分 别 增 加 近 5 7 万 人 和 1 7 个 多 百 分 点 , 为 香 港 历 史 最 高 纪 录 ; 共 有 7 7 8 1 3 功 能 界 别 选 举 选 民 投 票 , 投 票 率 达 6 3 . 5 % ; 选 举 委 员 会 选 举 投 票 选 民 达 到 7 9 0 人 , 投 票 率 高 达 9 8 . 7 5 %",
55
"end_position": [
6-
3,
7-
6,
8-
28
6+
7,
7+
129
98
],
10-
"entity_label": "NS",
9+
"entity_label": "NT",
1110
"impossible": false,
12-
"qas_id": "2.1",
13-
"query": "按照地理位置划分的国家,城市,乡镇,大洲",
11+
"qas_id": "13563.3",
12+
"query": "组织包括公司,政府党派,学校,政府,新闻机构",
1413
"span_position": [
15-
"3;3",
16-
"6;6",
17-
"27;28"
14+
"2;7",
15+
"125;129"
1816
],
1917
"start_position": [
20-
3,
21-
6,
22-
27
18+
2,
19+
125
2320
]
2421
}
2522
]

mrc/data/bert_model_collate.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
5454
return_tensors="pt")
5555

5656
batch_token_ids = batch_inputs["input_ids"]
57+
58+
batch_tokens = list()
59+
for batch_token_id in batch_token_ids.tolist():
60+
tmp_tokens = list()
61+
batch_tokens.append(tmp_tokens)
62+
for token_id in batch_token_id:
63+
token = self._tokenizer.decode(token_id)
64+
tmp_tokens.append(token)
65+
66+
5767
batch_token_type_ids = batch_inputs["token_type_ids"]
5868
batch_max_len = max(batch_inputs["length"])
5969

@@ -98,6 +108,9 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
98108
origin_offset2token_idx_start = {}
99109
origin_offset2token_idx_end = {}
100110

111+
last_token_start = 0
112+
last_token_end = 0
113+
101114
for token_idx in range(len(token_ids)):
102115
# query 的需要过滤
103116
if token_type_ids[token_idx] == 0:
@@ -112,14 +125,20 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
112125
if token_start == token_end == 0:
113126
continue
114127

128+
# 保存下最后的 start 和 end
129+
last_token_start = token_start
130+
last_token_end = token_end
131+
115132
# token_start 对应的就是 context 中的实际位置,与 start_position 与 end_position 是对应的
116133
# token_idx 是 query 和 context 拼接在一起后的 index,所以 这就是 start_position 映射后的位置
117134
origin_offset2token_idx_start[token_start] = token_idx
118135
origin_offset2token_idx_end[token_end] = token_idx
119136

120137
# 将原始数据中的 start_positions 映射到 拼接 query context 之后的位置
121-
new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions]
122-
new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions]
138+
new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions
139+
if start <= last_token_start]
140+
new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions
141+
if end <= last_token_end]
123142

124143
metadata["positions"] = zip(start_positions, end_positions)
125144

0 commit comments

Comments
 (0)