Commit 5952bd1

[Distributed] Fix new token's shape (#1254)
1 parent 8fcb3ba commit 5952bd1

File tree

1 file changed: +14 -10 lines changed

dist_run.py

Lines changed: 14 additions & 10 deletions
@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
+        # `pos` is not provided, so we can use the first token
         next_token_logits = output[:, 0, :]
     else:
         # get the logits for each prompt at the specified positions
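For readers tracing the shape fix: `output` is unpacked above as (batch_size, seq_len, vocab_size), so `output[:, 0, :]` selects the first position's logits for every prompt in the batch. A minimal sketch with made-up sizes (not taken from the repository):

import torch

# Made-up sizes, for illustration only.
batch_size, seq_len, vocab_size = 2, 5, 16
output = torch.randn(batch_size, seq_len, vocab_size)

# When step != -1 (no per-prompt positions given), take the logits at
# the first position for every prompt.
next_token_logits = output[:, 0, :]
print(next_token_logits.shape)  # torch.Size([2, 16]) -> (batch_size, vocab_size)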
@@ -228,9 +229,9 @@ def _batch_decode_next_tokens(
         ).squeeze(-1)
     else:
         # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
    return next_tokens
 
 
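The keepdim=True change is the heart of this fix: without it, argmax over the vocab dimension drops that dimension, so each decode step yields a 1-D tensor instead of one column per prompt. A minimal sketch of the difference (sizes made up for illustration):

import torch

batch_size, vocab_size = 2, 16
next_token_logits = torch.randn(batch_size, vocab_size)

# Without keepdim, the vocab dimension is dropped entirely.
flat = torch.argmax(next_token_logits, dim=-1)
print(flat.shape)  # torch.Size([2]) -> (batch_size,)

# With keepdim=True, each prompt keeps a length-1 token dimension,
# which is the shape the downstream decode/concat code relies on.
kept = torch.argmax(next_token_logits, dim=-1, keepdim=True)
print(kept.shape)  # torch.Size([2, 1]) -> (batch_size, 1)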
@@ -247,6 +248,11 @@ def _update_padded_sequence(
 # Decode token id into string and print it
 def _decode_in_flight(token, tokenizer, tp_rank):
     """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
     token_str = tokenizer.decode(token.tolist())
     # print the token string on tp rank 0
     if tp_rank == 0:
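As the added comments state, `token` now arrives as (batch_size, 1); the diff squeezes it to 1-D only for TiktokenTokenizer, since SentencePieceProcessor accepts the nested list as-is. A small sketch of what the squeeze does to the list passed to decode (token ids are arbitrary examples):

import torch

# `token` arrives with shape (batch_size, 1) after the keepdim fix above.
token = torch.tensor([[42], [7]])

# torch.squeeze(token, dim=1) removes only the size-1 token dimension.
flat = torch.squeeze(token, dim=1)
print(flat.tolist())   # [42, 7]     -- flat list of ints, one per prompt
print(token.tolist())  # [[42], [7]] -- nested list, one sublist per prompt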
@@ -530,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
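The stack-to-cat switch follows from the keepdim change: each element of `res` is now (batch_size, 1), so torch.cat along dim 1 yields (batch_size, num_generated_tokens), whereas torch.stack would keep an extra axis. A minimal sketch with invented token ids:

import torch

# After the keepdim fix, each decode step contributes a (batch_size, 1) tensor.
res = [torch.tensor([[11], [21]]),
       torch.tensor([[12], [22]]),
       torch.tensor([[13], [23]])]

# cat along dim 1 joins the per-step columns into one row of token ids
# per prompt: shape (batch_size, num_generated_tokens).
seqs = torch.cat(res, dim=1)
print(seqs.tolist())  # [[11, 12, 13], [21, 22, 23]]

# torch.stack(res, dim=1) would instead keep the extra size-1 axis:
print(torch.stack(res, dim=1).shape)  # torch.Size([2, 3, 1])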
