@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
    batch_size, seq_len, vocab_size = output.shape

    if step != -1:
+        # `pos` is not provided, so we can use the first token
        next_token_logits = output[:, 0, :]
    else:
        # get the logits for each prompt at the specified positions
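A quick sketch of the two selection paths this hunk distinguishes, assuming `output` is the model's logits of shape `(batch_size, seq_len, vocab_size)`; `prompt_lengths` below is a hypothetical stand-in for the per-prompt positions the `else` branch gathers at:

```python
import torch

batch_size, seq_len, vocab_size = 2, 8, 32
output = torch.randn(batch_size, seq_len, vocab_size)

# Decode step (step != -1): per the added comment, no positions are
# provided, so the fresh logits sit at sequence index 0 for every prompt.
decode_logits = output[:, 0, :]  # (batch_size, vocab_size)

# Prefill step (step == -1): gather logits at each prompt's last position.
# `prompt_lengths` is hypothetical; the real code derives these positions.
prompt_lengths = torch.tensor([3, 5])
idx = (prompt_lengths - 1).view(-1, 1, 1).expand(-1, 1, vocab_size)
prefill_logits = output.gather(1, idx).squeeze(1)  # (batch_size, vocab_size)
```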
@@ -228,9 +229,9 @@ def _batch_decode_next_tokens(
        ).squeeze(-1)
    else:
        # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)

-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
    return next_tokens

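The `keepdim=True` switch changes the argmax result from a 1-D to a 2-D tensor. A minimal sketch of the shape difference, with sizes chosen arbitrarily:

```python
import torch

next_token_logits = torch.randn(4, 32)  # (batch_size, vocab_size)

flat = torch.argmax(next_token_logits, dim=-1)                # shape (4,)
kept = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # shape (4, 1)
```

The `(batch_size, 1)` form matches the shape that `_decode_in_flight` below documents for its `token` argument.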
@@ -247,6 +248,11 @@ def _update_padded_sequence(
# Decode token id into string and print it
def _decode_in_flight(token, tokenizer, tp_rank):
    """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
    token_str = tokenizer.decode(token.tolist())
    # print the token string on tp rank 0
    if tp_rank == 0:
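A toy illustration of why the squeeze matters, assuming `token` arrives as `(batch_size, 1)` as the added comment states; the tokenizer behavior described in the comments is taken at face value here:

```python
import torch

token = torch.tensor([[101], [2054]])  # (batch_size, 1)

nested = token.tolist()                      # [[101], [2054]]
flat = torch.squeeze(token, dim=1).tolist()  # [101, 2054]
```

SentencePiece-style decoders accept the nested list (returning one string per prompt), while tiktoken-style `decode` expects a flat list of ids, hence the `isinstance` check on `TiktokenTokenizer`.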
@@ -530,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

    # output formatted response via last pp group and tp rank 0
    if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
        # Show prompts and responses
        for prompt_text, response_text in zip(prompt, responses):
            logger.info(f"Prompt: {color.green}{prompt_text}{color.reset}")
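A minimal sketch of the `stack` to `cat` change, assuming each element of `res` is one decode step's output of shape `(batch_size, 1)` (consistent with the `keepdim=True` change above):

```python
import torch

# Three decode steps for a batch of two prompts, each step (batch_size, 1).
res = [torch.tensor([[1], [4]]),
       torch.tensor([[2], [5]]),
       torch.tensor([[3], [6]])]

stacked = torch.stack(res, dim=1)  # (2, 3, 1): extra axis, nested tolist()
catted = torch.cat(res, dim=1)     # (2, 3): one flat id sequence per prompt

print(catted.tolist())  # [[1, 2, 3], [4, 5, 6]]
```

With `cat`, `res_list` is one id list per prompt, which a batched `tokenizer.decode` can consume directly instead of decoding in a per-sequence comprehension.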