@@ -215,14 +215,15 @@ def process_queries(self):
215215 for output in outputs :
216216 request_id = int (output .request_id )
217217 vllm_text = output .outputs [0 ].text
218- results .append (vllm_text )
218+ results .append (( vllm_text , len ( output . outputs [ 0 ]. token_ids )) )
219219 query_ids .append (self .query_idx_mapping [request_id ])
220220 qid .append (self .qid_mapping [request_id ])
221221
222222 self .num_samples += len (results )
223223
224- for i , result in enumerate (results ):
224+ for i , result_tuple in enumerate (results ):
225225 # Whisper outputs space in the front and capitalizes things
226+ result , n_tokens = result_tuple
226227 result = result .lower ().strip ()
227228 transcript = []
228229 for s in result :
@@ -233,7 +234,7 @@ def process_queries(self):
233234 assert len (transcript ) == 1
234235 response_array = array .array ('q' , transcript [0 ])
235236
236- self .output_queue .put ((qid [i ], response_array ))
237+ self .output_queue .put ((qid [i ], n_tokens , response_array ))
237238 print (f"Finished { qid [i ]} " )
238239 return True
239240
@@ -330,14 +331,13 @@ def flush_queries(self):
def response_loadgen(self):
    """Drain ``self.output_queue`` and report each finished sample to loadgen.

    Each queue item is expected to be a ``(qid, n_tokens, response_array)``
    tuple. The loop ends when a shutdown sentinel is dequeued: either a bare
    ``None`` or a tuple whose first element is ``None``.
    """
    while True:
        item = self.output_queue.get()
        # Check for the sentinel before unpacking: a bare ``None`` cannot be
        # destructured into three names and would raise TypeError instead of
        # ending the loop.
        # NOTE(review): the producer of the sentinel is not visible here --
        # confirm which of the two forms it actually enqueues.
        if item is None or item[0] is None:
            break
        qid, n_tokens, response_array = item
        # buffer_info() yields (address, element count); the size handed to
        # loadgen is the element count scaled by the per-element byte width.
        address, length = response_array.buffer_info()
        response = lg.QuerySampleResponse(
            qid, address, length * response_array.itemsize, n_tokens)
        lg.QuerySamplesComplete([response])
342342
343343 def stop (self ):
0 commit comments