From 3b0d7fcdd1525ef0b3b893cc4a34e437550ff311 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 14 Mar 2025 16:34:33 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchrl/modules/llm/vllm_policy.py | 30 +++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py
index daab91c76d0..06b3d099856 100644
--- a/torchrl/modules/llm/vllm_policy.py
+++ b/torchrl/modules/llm/vllm_policy.py
@@ -10,6 +10,8 @@
 import torch
 from tensordict import (
     from_dataclass,
+    lazy_stack,
+    LazyStackedTensorDict,
     maybe_dense_stack,
     NestedKey,
     NonTensorData,
@@ -151,7 +153,7 @@ def from_vllm(
             out_keys=["tokens_in"],
             method_kwargs=tokenizer_kwargs,
             strict=True,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["encode"] = Mod(
@@ -164,7 +166,7 @@ def from_vllm(
             in_keys=[text_key, "text_response"],
             out_keys=["tokens_in", "tokens_response"],
             strict=True,
-            inplace=False,
+            inplace="empty",
         )
 
     def select(x, y):
@@ -196,7 +198,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
                 ("tokens_in", "attention_mask"),
             ],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["move_inputs"] = Mod(
@@ -205,7 +207,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
             out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")],
             # It's ok if there's no mask
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
 
     def to_list(tokens, attention_mask):
@@ -261,13 +263,23 @@ def to_list(tokens, attention_mask):
         strict=True,
     )
 
-    def get_output_tokens_and_log_probs(td):
+    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+
+    def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
+        if td.ndim and not isinstance(td, LazyStackedTensorDict):
+            td = lazy_stack(list(td.unbind(0)))
         if generate:
             # When not generate, we don't want to overwrite this
-            td["tokens_response"] = td["tokens_out"].outputs.token_ids
+            tokens_response_td = td["tokens_out"].outputs._tensordict.select(
+                "token_ids", "logprobs", strict=False
+            )
+            tokens_response_td.rename_key_("token_ids", "tokens_response")
+            # td["tokens_response"] = outputs.token_ids
             if return_log_probs:
-                td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1)
+                tokens_response_td.rename_key_("logprobs", "log_probs")
+                # td["log_probs"] = outputs.logprobs.unsqueeze(-1)
+            td.update(tokens_response_td)
         elif not generate:
             td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
         return td
@@ -313,7 +325,7 @@ def translate_lps(tokens_response, x):
                 "text_response",
             ],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["format"] = Mod(
@@ -321,7 +333,7 @@ def translate_lps(tokens_response, x):
             in_keys=["log_probs", "tokens_response"],
             out_keys=["log_probs", "tokens_response"],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
 
     return Seq(module_dict, inplace=True)
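
Note (reviewer sketch, not part of the patch): the new
lazy_stack(list(td.unbind(0))) conversion appears to be there because a
LazyStackedTensorDict keeps its elements as separate tensordicts and so
tolerates per-element shapes that differ, which variable-length generated
responses require. A minimal standalone illustration using only the public
tensordict API:

    import torch
    from tensordict import TensorDict, lazy_stack, LazyStackedTensorDict

    # Two per-sample tensordicts whose response lengths differ, as
    # generated token sequences typically do.
    a = TensorDict({"tokens_response": torch.zeros(4, dtype=torch.long)}, batch_size=[])
    b = TensorDict({"tokens_response": torch.zeros(7, dtype=torch.long)}, batch_size=[])

    stacked = lazy_stack([a, b])  # keeps references; no dense rectangular copy
    assert isinstance(stacked, LazyStackedTensorDict)

    # Each element keeps its own shape.
    for t in stacked.unbind(0):
        print(t["tokens_response"].shape)  # torch.Size([4]), then torch.Size([7])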
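
A second sketch, of the select/rename/update idiom the patch now uses to
route the fields of td["tokens_out"].outputs into the parent tensordict.
The data below is fabricated; only the key names come from the patch:

    import torch
    from tensordict import TensorDict

    # Stand-in for the vLLM output tensordict in the patch.
    outputs = TensorDict(
        {"token_ids": torch.zeros(5, dtype=torch.long), "logprobs": torch.zeros(5)},
        batch_size=[],
    )

    # strict=False skips "logprobs" silently when log-probs were not requested.
    resp = outputs.select("token_ids", "logprobs", strict=False)
    resp.rename_key_("token_ids", "tokens_response")
    if "logprobs" in resp.keys():  # mirrors the return_log_probs branch
        resp.rename_key_("logprobs", "log_probs")

    td = TensorDict({}, batch_size=[])
    td.update(resp)  # merge the renamed keys into the destination
    assert set(td.keys()) == {"tokens_response", "log_probs"}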
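
Finally, the padding_value expression is, for tokenizers that do not prepend
special tokens, equivalent to the tokenizer's pad_token_id attribute; a
Hugging Face example (the model name here is chosen arbitrarily):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
    # Encoding the pad token string yields a single id: the pad token id.
    assert tok(tok.pad_token)["input_ids"][0] == tok.pad_token_id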