diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py
index f09d1b9636d..e5c14bf3c96 100644
--- a/torchrl/envs/transforms/llm.py
+++ b/torchrl/envs/transforms/llm.py
@@ -11,7 +11,7 @@
 import torch
 from tensordict import (
-    maybe_dense_stack,
+    lazy_stack,
     NestedKey,
     TensorDict,
     TensorDictBase,
@@ -386,7 +386,7 @@ def __init__(
         self.endless_dataloader = self._endless_iter(self.dataloader)
 
         if stack_method is None:
-            stack_method = maybe_dense_stack
+            stack_method = lazy_stack
         elif stack_method == "as_nested_tensor":
             stack_method = as_nested_tensor
         elif stack_method == "as_padded_tensor":
@@ -434,10 +434,14 @@ def _endless_iter(self, obj):
         while True:
             yield from obj
 
+    # def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
+    #     td = super()._reset_env_preprocess(tensordict)
+    #     return lazy_stack(list(td.unbind(0)))
+    #
     def _load_from_dataloader(self, reset: torch.Tensor | None = None):
         """Loads a single element from the dataloader, or alternatively from the buffer.
 
-        If `reset` is passed, the one element per reset will be loaded.
+        If `reset` is passed, then one element per reset will be loaded.
         """
         if reset is not None:
             if not reset.any():
diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py
index daab91c76d0..c73683838f0 100644
--- a/torchrl/modules/llm/vllm_policy.py
+++ b/torchrl/modules/llm/vllm_policy.py
@@ -10,6 +10,8 @@
 import torch
 from tensordict import (
     from_dataclass,
+    lazy_stack,
+    LazyStackedTensorDict,
     maybe_dense_stack,
     NestedKey,
     NonTensorData,
@@ -20,6 +22,7 @@
     TensorDictModule as Mod,
     TensorDictModuleBase,
     TensorDictSequential as Seq,
+    WrapModule,
 )
 from tensordict.utils import _zip_strict
 
@@ -61,6 +64,7 @@ def from_vllm(
     generate: bool = True,
     generate_kwargs: dict | None = None,
     tokenizer_kwargs: dict | None = None,
+    pad_output: bool = True,
 ) -> TensorDictModuleBase:
     """Creates a TensorDictModule from a vLLM model.
@@ -151,7 +155,7 @@ def from_vllm(
                 out_keys=["tokens_in"],
                 method_kwargs=tokenizer_kwargs,
                 strict=True,
-                inplace=False,
+                inplace="empty",
             )
         else:
             module_dict["encode"] = Mod(
@@ -164,7 +168,7 @@ def from_vllm(
                 in_keys=[text_key, "text_response"],
                 out_keys=["tokens_in", "tokens_response"],
                 strict=True,
-                inplace=False,
+                inplace="empty",
             )
 
     def select(x, y):
@@ -196,7 +200,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
                 ("tokens_in", "attention_mask"),
             ],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
     else:
         module_dict["move_inputs"] = Mod(
@@ -205,7 +209,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
             out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")],
             # It's ok if there's no mask
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
 
     def to_list(tokens, attention_mask):
@@ -261,13 +265,27 @@ def to_list(tokens, attention_mask):
         strict=True,
     )
 
-    def get_output_tokens_and_log_probs(td):
+    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
+
+    def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
+        if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict):
+            td = lazy_stack(list(td.unbind(0)))
         if generate:
             # When not generate, we don't want to overwrite this
-            td["tokens_response"] = td["tokens_out"].outputs.token_ids
+            tokens_response_td = td["tokens_out"].outputs._tensordict.select(
+                "token_ids", "logprobs", strict=False
+            )
+            if pad_output:
+                tokens_response_td = tokens_response_td.densify(
+                    layout=torch.strided
+                ).to_padded_tensor(padding=padding_value)
+            tokens_response_td.rename_key_("token_ids", "tokens_response")
+            # td["tokens_response"] = outputs.token_ids
             if return_log_probs:
-                td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1)
+                tokens_response_td.rename_key_("logprobs", "log_probs")
+                # td["log_probs"] = outputs.logprobs.unsqueeze(-1)
+            td.update(tokens_response_td)
         elif not generate:
             td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
         return td
@@ -296,32 +314,40 @@ def translate_lps(tokens_response, x):
     module_dict["to_source_device"] = _maybe_set_device
 
     if generate:
-        module_dict["format"] = Mod(
-            lambda *x: x,
-            in_keys=[
-                "log_probs",
-                "tokens_response",
-                ("tokens_in", "input_ids"),
-                ("tokens_in", "attention_mask"),
-                "text_response",
-            ],
-            out_keys=[
-                "log_probs",
-                "tokens_response",
-                token_key,
-                attention_mask_key,
-                "text_response",
-            ],
-            strict=False,
-            inplace=False,
+        in_keys = [
+            "log_probs",
+            "tokens_response",
+            ("tokens_in", "input_ids"),
+            ("tokens_in", "attention_mask"),
+            "text_response",
+        ]
+        out_keys = [
+            "log_probs",
+            "tokens_response",
+            token_key,
+            attention_mask_key,
+            "text_response",
+        ]
+
+        def format_td(td):
+            td = td.select(*in_keys, strict=False)
+            td.rename_key_(("tokens_in", "input_ids"), token_key)
+            td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key)
+            return td
+
+        module_dict["format"] = WrapModule(
+            format_td,
+            in_keys=in_keys,
+            out_keys=out_keys,
         )
+
     else:
         module_dict["format"] = Mod(
             lambda *x: x,
             in_keys=["log_probs", "tokens_response"],
             out_keys=["log_probs", "tokens_response"],
             strict=False,
-            inplace=False,
+            inplace="empty",
         )
 
     return Seq(module_dict, inplace=True)
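For reviewers, a minimal usage sketch of the reworked wrapper follows. This is not part of the patch: it assumes vLLM and transformers are installed, that `from_vllm` is importable from `torchrl.modules`, and that the surrounding keyword arguments (`from_text`, `return_log_probs`) keep their existing names; the model name and prompt are illustrative placeholders.

# Hedged usage sketch for the updated from_vllm wrapper; not part of the patch.
from tensordict import TensorDict
from transformers import AutoTokenizer
from vllm import LLM

from torchrl.modules import from_vllm

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# The wrapper derives padding_value from tokenizer(tokenizer.pad_token), so a
# pad token must be defined (GPT-2 has none by default).
tokenizer.pad_token = tokenizer.eos_token

policy = from_vllm(
    LLM(model="gpt2"),
    tokenizer=tokenizer,
    from_text=True,
    generate=True,
    return_log_probs=True,
    pad_output=True,  # new flag: densify ragged vLLM outputs into padded tensors
)

td = policy(TensorDict(text="The capital of France is", batch_size=()))
# With pad_output=True, "tokens_response" and "log_probs" come back as dense
# tensors padded with the pad token id rather than as nested/ragged tensors.
print(td["text_response"])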