@@ -51,7 +51,7 @@ class ExtractiveReader:
     ```
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
         device: Optional[ComponentDevice] = None,
@@ -192,8 +192,9 @@ def warm_up(self):
             )
             self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))
 
+    @staticmethod
     def _flatten_documents(
-        self, queries: List[str], documents: List[List[Document]]
+        queries: List[str], documents: List[List[Document]]
     ) -> Tuple[List[str], List[Document], List[int]]:
         """
         Flattens queries and Documents so all query-document pairs are arranged along one batch axis.
@@ -203,8 +204,8 @@ def _flatten_documents(
         query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_]
         return flattened_queries, flattened_documents, query_ids
 
-    def _preprocess(
-        self, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int
+    def _preprocess(  # pylint: disable=too-many-positional-arguments
+        self, *, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", List["Encoding"], List[int], List[int]]:
         """
         Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs.
@@ -256,6 +257,7 @@ def _preprocess(
 
     def _postprocess(
         self,
+        *,
         start: "torch.Tensor",
         end: "torch.Tensor",
         sequence_ids: "torch.Tensor",
@@ -285,9 +287,9 @@ def _postprocess(
         masked_logits = torch.where(mask, logits, -torch.inf)
         probabilities = torch.sigmoid(masked_logits * self.calibration_factor)
 
-        flat_probabilities = probabilities.flatten(-2, -1)  # necessary for topk
+        flat_probabilities = probabilities.flatten(-2, -1)  # necessary for top-k
 
-        # topk can return invalid candidates as well if answers_per_seq > num_valid_candidates
+        # top-k can return invalid candidates as well if answers_per_seq > num_valid_candidates
         # We only keep probability > 0 candidates later on
         candidates = torch.topk(flat_probabilities, answers_per_seq)
         seq_length = logits.shape[-1]
@@ -343,6 +345,7 @@ def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer:
 
     def _nest_answers(
         self,
+        *,
         start: List[List[int]],
         end: List[List[int]],
         probabilities: "torch.Tensor",
@@ -526,7 +529,7 @@ def deduplicate_by_overlap(
         return deduplicated_answers
 
     @component.output_types(answers=List[ExtractedAnswer])
-    def run(
+    def run(  # pylint: disable=too-many-positional-arguments
         self,
         query: str,
         documents: List[Document],
@@ -594,9 +597,15 @@ def run(
         no_answer = no_answer if no_answer is not None else self.no_answer
         overlap_threshold = overlap_threshold or self.overlap_threshold
 
-        flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents)
+        flattened_queries, flattened_documents, query_ids = ExtractiveReader._flatten_documents(
+            queries, nested_documents
+        )
         input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
-            flattened_queries, flattened_documents, max_seq_length, query_ids, stride
+            queries=flattened_queries,
+            documents=flattened_documents,
+            max_seq_length=max_seq_length,
+            query_ids=query_ids,
+            stride=stride,
         )
 
         num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1
@@ -625,7 +634,12 @@ def run(
         end_logits = torch.cat(end_logits_list)
 
         start, end, probabilities = self._postprocess(
-            start_logits, end_logits, sequence_ids, attention_mask, answers_per_seq, encodings
+            start=start_logits,
+            end=end_logits,
+            sequence_ids=sequence_ids,
+            attention_mask=attention_mask,
+            answers_per_seq=answers_per_seq,
+            encodings=encodings,
        )
 
         answers = self._nest_answers(
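For context, the public `run` API is untouched by this refactor; only the private helpers (`_flatten_documents`, `_preprocess`, `_postprocess`, `_nest_answers`) switch to keyword-only or static calls internally. A minimal usage sketch, assuming Haystack 2.x import paths and that the default model is downloadable:

```python
# Minimal sketch, not part of this diff; assumes Haystack 2.x import paths.
from haystack import Document
from haystack.components.readers import ExtractiveReader

reader = ExtractiveReader()  # defaults to deepset/roberta-base-squad2-distilled
reader.warm_up()  # loads the model and tokenizer

# run() is a keyword-friendly entry point; internally it now forwards to the
# keyword-only helpers shown in the diff above.
result = reader.run(
    query="Who lives in Berlin?",
    documents=[Document(content="Angela lives in Berlin. She likes the city.")],
    top_k=2,
)
print(result["answers"])
```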