@@ -54,6 +54,16 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
                                        return_tensors="pt")

         batch_token_ids = batch_inputs["input_ids"]
+
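+        # decode each token id back to its token text, building a per-instance list of tokens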
+        batch_tokens = list()
+        for batch_token_id in batch_token_ids.tolist():
+            tmp_tokens = list()
+            batch_tokens.append(tmp_tokens)
+            for token_id in batch_token_id:
+                token = self._tokenizer.decode(token_id)
+                tmp_tokens.append(token)
+
+
         batch_token_type_ids = batch_inputs["token_type_ids"]
         batch_max_len = max(batch_inputs["length"])

@@ -98,6 +108,9 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
             origin_offset2token_idx_start = {}
             origin_offset2token_idx_end = {}

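+            # remember the offsets of the last context token that gets mapped,
+            # so spans beyond the (possibly truncated) context can be dropped below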
+            last_token_start = 0
+            last_token_end = 0
+
             for token_idx in range(len(token_ids)):
                 # tokens belonging to the query need to be filtered out
                 if token_type_ids[token_idx] == 0:
@@ -112,14 +125,20 @@ def __call__(self, instances: List[Instance]) -> MRCModelInputs:
                 if token_start == token_end == 0:
                     continue

+                # keep track of the last start and end offsets
+                last_token_start = token_start
+                last_token_end = token_end
+
                 # token_start is the actual position in the context, matching start_position and end_position
                 # token_idx is the index after query and context are concatenated, so this is the position start_position maps to
                 origin_offset2token_idx_start[token_start] = token_idx
                 origin_offset2token_idx_end[token_end] = token_idx

             # map the start_positions from the original data to their positions after query and context are concatenated
-            new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions]
-            new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions]
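+            # only keep gold spans whose offsets were actually mapped, i.e. not cut off by truncation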
+            new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions
+                                   if start <= last_token_start]
+            new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions
+                                 if end <= last_token_end]

             metadata["positions"] = zip(start_positions, end_positions)
