@@ -54,6 +54,16 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
54
54
return_tensors = "pt" )
55
55
56
56
batch_token_ids = batch_inputs ["input_ids" ]
57
+
58
+ batch_tokens = list ()
59
+ for batch_token_id in batch_token_ids .tolist ():
60
+ tmp_tokens = list ()
61
+ batch_tokens .append (tmp_tokens )
62
+ for token_id in batch_token_id :
63
+ token = self ._tokenizer .decode (token_id )
64
+ tmp_tokens .append (token )
65
+
66
+
57
67
batch_token_type_ids = batch_inputs ["token_type_ids" ]
58
68
batch_max_len = max (batch_inputs ["length" ])
59
69
@@ -98,6 +108,9 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
98
108
origin_offset2token_idx_start = {}
99
109
origin_offset2token_idx_end = {}
100
110
111
+ last_token_start = 0
112
+ last_token_end = 0
113
+
101
114
for token_idx in range (len (token_ids )):
102
115
# query 的需要过滤
103
116
if token_type_ids [token_idx ] == 0 :
@@ -112,14 +125,20 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
112
125
if token_start == token_end == 0 :
113
126
continue
114
127
128
+ # 保存下最后的 start 和 end
129
+ last_token_start = token_start
130
+ last_token_end = token_end
131
+
115
132
# token_start 对应的就是 context 中的实际位置,与 start_position 与 end_position 是对应的
116
133
# token_idx 是 query 和 context 拼接在一起后的 index,所以 这就是 start_position 映射后的位置
117
134
origin_offset2token_idx_start [token_start ] = token_idx
118
135
origin_offset2token_idx_end [token_end ] = token_idx
119
136
120
137
# 将原始数据中的 start_positions 映射到 拼接 query context 之后的位置
121
- new_start_positions = [origin_offset2token_idx_start [start ] for start in start_positions ]
122
- new_end_positions = [origin_offset2token_idx_end [end ] for end in end_positions ]
138
+ new_start_positions = [origin_offset2token_idx_start [start ] for start in start_positions
139
+ if start <= last_token_start ]
140
+ new_end_positions = [origin_offset2token_idx_end [end ] for end in end_positions
141
+ if end <= last_token_end ]
123
142
124
143
metadata ["positions" ] = zip (start_positions , end_positions )
125
144
0 commit comments