From 7f0db72eac2fe42daf9ae0aaedb2e7cc096c9543 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Mar 2025 16:03:08 -0700 Subject: [PATCH 1/5] Update [ghstack-poisoned] --- test/test_actors.py | 94 ++++++++- torchrl/envs/utils.py | 6 +- torchrl/modules/llm/vllm_policy.py | 293 +++++++++++++++++++---------- 3 files changed, 289 insertions(+), 104 deletions(-) diff --git a/test/test_actors.py b/test/test_actors.py index d4d0a3303dc..2d8620a9844 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -10,7 +10,7 @@ import pytest import torch -from tensordict import NonTensorStack, TensorDict +from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict from tensordict.nn import CompositeDistribution, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor @@ -1122,6 +1122,8 @@ def _run_check( # If from text and not generating, the tokens are not returned for now if not (from_text and not generate): + assert td.tokens_response is not None + assert td.tokens is not None assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1] # The convention is that the response only has new tokens assert ( @@ -1166,26 +1168,34 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask): ) @pytest.mark.parametrize( - "from_text, tokens, attention_mask", + "pad_output, from_text, tokens, attention_mask", [ - (True, None, None), + (True, True, None, None), + (False, True, None, None), ( + True, False, torch.randint(1024, (1, 10)), torch.ones(1, 10, dtype=torch.int64), ), - (False, torch.randint(1024, (1, 10)), None), + (True, False, torch.randint(1024, (1, 10)), None), ], ) - def test_from_vllm_logprobs(self, from_text, tokens, attention_mask): + def test_from_vllm_logprobs(self, from_text, tokens, attention_mask, pad_output): torch.manual_seed(0) from vllm import LLM model = LLM(model="facebook/opt-125m") m_generate = from_vllm( - model, from_text=from_text, generate=True, return_log_probs=True + model, + from_text=from_text, + generate=True, + return_log_probs=True, + pad_output=pad_output, + ) + m_logprobs = from_vllm( + model, from_text=from_text, generate=False, pad_output=pad_output ) - m_logprobs = from_vllm(model, from_text=from_text, generate=False) self._check_lps( m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False ) @@ -1221,6 +1231,76 @@ def _check_lps( td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2 ) + @pytest.fixture(scope="module") + def llm_model(self): + import vllm + + llm_model = vllm.LLM("gpt2") + tokenizer = llm_model.get_tokenizer() + tokenizer.pad_token = tokenizer.eos_token + return llm_model + + @pytest.mark.parametrize("pad", [True, False]) + @pytest.mark.parametrize("generate", [True, False]) + def test_vllm_batch_run(self, pad, generate, llm_model): + # Test generate - padding combinations + policy = from_vllm( + llm_model, + from_text=True, + generate=generate, + return_log_probs=True, + pad_output=pad, + generate_kwargs={"max_tokens": 10000}, + ) + if generate: + data = LazyStackedTensorDict( + *TensorDict( + text=NonTensorStack("a string", "another very long string"), + batch_size=[2], + ).unbind(0) + ) + else: + data = LazyStackedTensorDict( + *TensorDict( + text=NonTensorStack("a string", "another very long string"), + text_response=NonTensorStack( + " is a string", " is still a very long string" + ), + batch_size=[2], + ).unbind(0) + ) + output = policy(data) + try: + log_probs = output.get("log_probs") + except Exception: + log_probs = output.get("log_probs", as_list=True) 
+ if pad: + assert isinstance(log_probs, torch.Tensor) + else: + assert isinstance(log_probs, list) + text = output.get("text", as_list=True) + assert isinstance(text, NonTensorStack) + text_response = output.get("text_response", as_list=True) + assert isinstance(text_response, NonTensorStack) + try: + tokens_response = output.get("tokens_response") + except Exception: + tokens_response = output.get("tokens_response", as_list=True) + if pad: + assert isinstance(tokens_response, torch.Tensor) + else: + assert isinstance(tokens_response, list) + try: + tokens = output.get("tokens") + except Exception: + tokens = output.get("tokens", as_list=True) + if not generate: + assert tokens is None + elif pad: + assert isinstance(tokens, torch.Tensor), tokens + else: + assert isinstance(tokens, list) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index eb236e56c4b..0af7937445d 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -508,6 +508,7 @@ def _set_single_key( if isinstance(key, str): key = (key,) for k in key: + # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature try: val = source._get_str(k, None) if is_tensor_collection(val): @@ -528,7 +529,7 @@ def _set_single_key( # This is a temporary solution to understand if a key is heterogeneous # while not having performance impact when the exception is not raised except RuntimeError as err: - if re.match(r"Found more than one unique shape in the tensors", str(err)): + if re.match(r"Failed to stack tensors within a tensordict", str(err)): # this is a het key for s_td, d_td in zip(source.tensordicts, dest.tensordicts): _set_single_key(s_td, d_td, k, clone=clone, device=device) @@ -541,6 +542,7 @@ def _set(source, dest, key, total_key, excluded): total_key = total_key + (key,) non_empty = False if unravel_key(total_key) not in excluded: + # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature try: val = source.get(key) if is_tensor_collection(val) and not isinstance( @@ -571,7 +573,7 @@ def _set(source, dest, key, total_key, excluded): # This is a temporary solution to understand if a key is heterogeneous # while not having performance impact when the exception is not raised except RuntimeError as err: - if re.match(r"Found more than one unique shape in the tensors", str(err)): + if re.match(r"Failed to stack tensors within a tensordict", str(err)): # this is a het key non_empty_local = False for s_td, d_td in zip(source.tensordicts, dest.tensordicts): diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index cac19a57dea..6754d3cd383 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -6,28 +6,24 @@ import collections import importlib.util +from typing import Literal import torch from tensordict import ( from_dataclass, lazy_stack, LazyStackedTensorDict, - maybe_dense_stack, NestedKey, NonTensorData, NonTensorStack, TensorClass, TensorDict, ) -from tensordict.nn import ( - TensorDictModule as Mod, - TensorDictModuleBase, - TensorDictSequential as Seq, - WrapModule, -) +from tensordict.nn import TensorDictModule as Mod, TensorDictModuleBase, WrapModule from tensordict.utils import _zip_strict, expand_as_right from torchrl.data import LLMData +from torchrl.modules.llm.common import CategoricalSequential _has_vllm = importlib.util.find_spec("vllm") @@ -61,6 +57,7 @@ def from_vllm( 
generate_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, pad_output: bool = True, + inplace: Literal[True, False, "empty"] | None = True, ) -> TensorDictModuleBase: """Creates a TensorDictModule from a vLLM model. @@ -121,8 +118,13 @@ def from_vllm( if tokenizer is None: tokenizer = model.get_tokenizer() - # retrieve the padding value - we use this to make the log-probs of pad token = 1 - padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0] + if pad_output: + # retrieve the padding value - we use this to make the log-probs of pad token = 1 + padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0] + elif not from_text: + raise TypeError("passing tokens without padding isn't supported at the moment.") + else: + padding_value = None if from_text: if generate: @@ -144,7 +146,7 @@ def from_vllm( return_log_probs=return_log_probs, pad_output=pad_output, ) - return Seq(module_dict, inplace=True) + return CategoricalSequential(module_dict, inplace=inplace) def to_list(tokens, attention_mask): @@ -172,6 +174,11 @@ def to_list(tokens, attention_mask): return NonTensorStack(*tokens) +def _prepare(td): + out = TensorDict(batch_size=td.batch_size, device=td.device) + return out.update(td) + + def _from_vllm_generate_text( *, tokenizer, @@ -195,52 +202,61 @@ def _from_vllm_generate_text( module_dict = {} if device: module_dict["clear_device"] = _maybe_clear_device - if not tokenizer_kwargs: - tokenizer_kwargs = {} - if not tokenizer_kwargs.setdefault("return_attention_mask", True): - raise RuntimeError - if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": - raise RuntimeError - if tokenizer_kwargs.setdefault("padding", True) not in (True,): - raise RuntimeError - if tokenizer_kwargs.setdefault("padding_side", "left") != "left": - raise RuntimeError - def tokenize(td): - out = TensorDict(batch_size=td.batch_size, device=td.device) - tokens_in = TensorDict.from_dict( - tokenizer(td.get(text_key), **tokenizer_kwargs) + if pad_output: + if not tokenizer_kwargs: + tokenizer_kwargs = {} + if not tokenizer_kwargs.setdefault("return_attention_mask", True): + raise RuntimeError + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + if tokenizer_kwargs.setdefault("padding", True) not in (True,): + raise RuntimeError + if tokenizer_kwargs.setdefault("padding_side", "left") != "left": + raise RuntimeError + + def tokenize(td): + out = TensorDict(batch_size=td.batch_size, device=td.device) + text = td[text_key] + tokens_in = TensorDict.from_dict(tokenizer(text, **tokenizer_kwargs)) + out.set("tokens_in", tokens_in) + return out + + module_dict["encode"] = WrapModule( + tokenize, + in_keys=[text_key], + out_keys=["tokens_in"], ) - out.set("tokens_in", tokens_in) - return out - - module_dict["encode"] = WrapModule( - tokenize, - in_keys=[text_key], - out_keys=["tokens_in"], - ) - - module_dict["to_list"] = Mod( - to_list, - in_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], - out_keys=[("tokens_in", "input_ids_list")], - strict=False, - ) + module_dict["to_list"] = Mod( + to_list, + in_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + out_keys=[("tokens_in", "input_ids_list")], + strict=False, + ) + else: + module_dict["prepare"] = WrapModule( + _prepare, + ) if generate_kwargs is None: generate_kwargs = {} - generate_kwargs.setdefault("detokenize", False) + generate_kwargs.setdefault("detokenize", not pad_output) generate_kwargs.setdefault("prompt_logprobs", False) 
generate_kwargs.setdefault("logprobs", return_log_probs) sampling_params = SamplingParams(**generate_kwargs) + if pad_output: + in_keys = { + "prompt_token_ids": ("tokens_in", "input_ids_list"), + } + else: + in_keys = [text_key] + module_dict["generate"] = Mod( model, method="generate", method_kwargs={"sampling_params": sampling_params}, - in_keys={ - "prompt_token_ids": ("tokens_in", "input_ids_list"), - }, + in_keys=in_keys, out_keys=["tokens_out"], out_to_in_map=True, strict=True, @@ -248,34 +264,45 @@ def tokenize(td): def get_output_tokens_and_log_probs(td, padding_value=padding_value): td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) - if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict): + if td.ndim and not isinstance(td, LazyStackedTensorDict): td = lazy_stack(list(td.unbind(0))) # When not generate, we don't want to overwrite this tokens_response_td = td["tokens_out"].outputs._tensordict.select( - "token_ids", "logprobs", strict=False + "text", "token_ids", "logprobs", strict=False ) if pad_output: tokens_response_td = tokens_response_td.densify( layout=torch.strided ).to_padded_tensor(padding=padding_value) tokens_response_td.rename_key_("token_ids", "tokens_response") + tokens_response_td.rename_key_("text", "text_response") + if not pad_output: + # Then we can safely move the input tokens, but otherwise they + # may need padding + tokens_response_td.update( + td["tokens_out"].select("prompt_token_ids") + ).rename_key_("prompt_token_ids", token_key) + if return_log_probs: - padded_values = tokens_response_td["tokens_response"] == padding_value tokens_response_td.rename_key_("logprobs", "log_probs") - if padded_values.any(): - lps = tokens_response_td["log_probs"] - lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) - tokens_response_td["log_probs"] = lps + if pad_output: + padded_values = tokens_response_td["tokens_response"] == padding_value + if padded_values.any(): + lps = tokens_response_td["log_probs"] + lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) + tokens_response_td["log_probs"] = lps + td.update(tokens_response_td) return td module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs - module_dict["decode"] = Mod( - tokenizer.batch_decode, - in_keys=["tokens_response"], - out_keys=["text_response"], - ) + if pad_output: + module_dict["decode"] = Mod( + tokenizer.batch_decode, + in_keys=["tokens_response"], + out_keys=["text_response"], + ) if device: module_dict["to_source_device"] = _maybe_set_device @@ -283,23 +310,27 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value): in_keys = [ "log_probs", "tokens_response", - ("tokens_in", "input_ids"), - ("tokens_in", "attention_mask"), "text_response", + token_key, + "tokens_in", ] out_keys = [ "log_probs", "tokens_response", + "text_response", token_key, attention_mask_key, - "text_response", ] def format_td(td): td = td.select(*in_keys, strict=False) - td.rename_key_(("tokens_in", "input_ids"), token_key) - td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key) - del td["tokens_in"] + # We might already have the tokens + if ("tokens_in", "input_ids") in td: + td.rename_key_(("tokens_in", "input_ids"), token_key) + if "tokens_in" in td: + if ("tokens_in", "attention_mask") in td: + td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key) + del td["tokens_in"] return td module_dict["format"] = WrapModule( @@ -387,12 +418,13 @@ def get_output_tokens_and_log_probs(td, 
padding_value=padding_value): ).to_padded_tensor(padding=padding_value) tokens_response_td.rename_key_("token_ids", "tokens_response") if return_log_probs: - padded_values = tokens_response_td["tokens_response"] == padding_value tokens_response_td.rename_key_("logprobs", "log_probs") - if padded_values.any(): - lps = tokens_response_td["log_probs"] - lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) - tokens_response_td["log_probs"] = lps + if pad_output: + padded_values = tokens_response_td["tokens_response"] == padding_value + if padded_values.any(): + lps = tokens_response_td["log_probs"] + lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) + tokens_response_td["log_probs"] = lps td.update(tokens_response_td) return td @@ -457,27 +489,66 @@ def _from_vllm_logprobs_text( tokenizer_kwargs = {} if not tokenizer_kwargs.setdefault("return_attention_mask", True): raise RuntimeError - if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": - raise RuntimeError - if tokenizer_kwargs.setdefault("padding", True) not in (True,): + if pad_output: + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + if tokenizer_kwargs.setdefault("padding", pad_output) not in (pad_output,): raise RuntimeError if tokenizer_kwargs.setdefault("padding_side", "left") != "left": raise RuntimeError + # Contrary to the case with generate, we always need the tokenizer here to understand what length is the response + # To do this, we tokenize the prompt+response as well as the prompt, and then recover the response by taking + # the last slice of the tokenized prompt+response (ie, removing the tokens of the prompt). + # We need to do this rather than tokenizing the response because we want to ensure that there is no + # additional tokens, but there is defo room for improvement. 
def tokenize(td): out = TensorDict(batch_size=td.batch_size, device=td.device) text_prompt = td.get(text_key) + if not isinstance(text_prompt, list): + text_prompt = text_prompt.tolist() text_response = td.get("text_response") - tokens_in = tokenizer( - [_x + _y for _x, _y in zip(text_prompt, text_response)], **tokenizer_kwargs - ) + if not isinstance(text_response, list): + text_response = text_response.tolist() + text = [_x + _y for _x, _y in zip(text_prompt, text_response)] + tokens_in = tokenizer(text, **tokenizer_kwargs) tokens_prompt = tokenizer(text_prompt, **tokenizer_kwargs) - tokens_in = TensorDict.from_dict(tokens_in) + if not pad_output: + tokens_in = TensorDict( + input_ids=NonTensorStack(*tokens_in["input_ids"]), + attention_mask=NonTensorStack(*tokens_in["attention_mask"]), + batch_size=td.batch_size, + ) + prompt_input_ids = tokens_prompt["input_ids"] + prompt_attention_mask = tokens_prompt["attention_mask"] + response_input_ids = [] + for token_total, token_prompt in zip( + tokens_in["input_ids"], prompt_input_ids + ): + response_input_ids.append(token_total[len(token_prompt) :]) + response_input_ids = NonTensorStack(*response_input_ids) + response_attention_mask = [] + for mask, mask_prompt in zip( + tokens_in["attention_mask"], prompt_attention_mask + ): + response_attention_mask.append(mask[len(mask_prompt) :]) + response_attention_mask = NonTensorStack(*response_attention_mask) + tokens_response = TensorDict( + input_ids=response_input_ids, + attention_mask=response_attention_mask, + batch_size=td.batch_size, + ) + else: + tokens_in = TensorDict.from_dict(tokens_in) + tokens_prompt = TensorDict.from_dict(tokens_prompt) + tokens_response = tokens_in.apply( + lambda total_tokens, input_tokens: total_tokens[ + :, input_tokens.shape[1] : + ], + tokens_prompt, + ) + out["tokens_in"] = tokens_in - tokens_response = tokens_in.apply( - lambda total_tokens, input_tokens: total_tokens[:, input_tokens.shape[1] :], - TensorDict.from_dict(tokens_prompt), - ) out["tokens_response"] = tokens_response return out @@ -495,9 +566,17 @@ def tokenize(td): strict=False, ) + if tokenizer is not None: + in_keys = { + "prompt_token_ids": ("tokens_in", "input_ids_list"), + } + else: + in_keys = [text_key] + if generate_kwargs is None: generate_kwargs = {} - generate_kwargs.setdefault("detokenize", False) + # We use the tokens when we pad + generate_kwargs.setdefault("detokenize", not pad_output) generate_kwargs.setdefault("prompt_logprobs", True) generate_kwargs.setdefault("logprobs", return_log_probs) generate_kwargs["max_tokens"] = 1 @@ -507,9 +586,7 @@ def tokenize(td): model, method="generate", method_kwargs={"sampling_params": sampling_params}, - in_keys={ - "prompt_token_ids": ("tokens_in", "input_ids_list"), - }, + in_keys=in_keys, out_keys=["tokens_out"], out_to_in_map=True, strict=True, @@ -517,35 +594,61 @@ def tokenize(td): def get_output_tokens_and_log_probs(td, padding_value=padding_value): td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) - if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict): + if td.ndim and not isinstance(td, LazyStackedTensorDict): td = lazy_stack(list(td.unbind(0))) - td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs + td.update( + td["tokens_out"].select("prompt_token_ids", "prompt_logprobs", strict=False) + ) + del td["tokens_out"] return td module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs - def translate_lps(tokens_response, x): + def translate_lps(tokens_response, lps): # 
we disregard the tokens from the prompt to focus on those of the response - padded = tokens_response == padding_value - lps = x[..., -tokens_response.shape[-1] :] - lps = torch.where(~padded, lps, 0.0) + if isinstance(lps, torch.Tensor): + lps = lps[..., -tokens_response.shape[-1] :] + else: + # We use a nested tensor as it will be unbound during writing + lps = torch.nested.nested_tensor( + [lp[..., -len(tr) :] for lp, tr in zip(lps, tokens_response)] + ) + if pad_output: + padded = tokens_response == padding_value + lps = torch.where(~padded, lps, 0.0) return lps module_dict["translate_lps"] = Mod( translate_lps, in_keys=[("tokens_response", "input_ids"), "prompt_logprobs"], out_keys=["log_probs"], + get_kwargs={ + "as_list": not pad_output, + "as_padded_tensor": pad_output, + "padding_side": "left", + }, ) if device: module_dict["to_source_device"] = _maybe_set_device - module_dict["format"] = Mod( - lambda *x: x, - in_keys=["log_probs", ("tokens_response", "input_ids")], - out_keys=["log_probs", "tokens_response"], - strict=False, - inplace="empty", + in_keys = ["log_probs", ("tokens_response", "input_ids")] + out_keys = ["log_probs", "tokens_response"] + + def format_td(td): + td = td.select(*in_keys, strict=False) + td.rename_key_(("tokens_response", "input_ids"), "tokens_response") + if not pad_output: + # Turn the list of tokens in a tensor + td["tokens_response"] = torch.nested.nested_tensor( + [torch.tensor(val) for val in td["tokens_response"]] + ) + return td + + module_dict["format"] = WrapModule( + format_td, + in_keys=in_keys, + out_keys=out_keys, ) return module_dict @@ -636,8 +739,8 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value): def translate_lps(tokens_response, lps): # we disregard the tokens from the prompt to focus on those of the response - padded = tokens_response == padding_value lps = lps[..., -tokens_response.shape[-1] :] + padded = tokens_response == padding_value lps = torch.where(~padded, lps, 0.0) return lps @@ -718,7 +821,7 @@ def get_logprob(output): @classmethod def from_request_output(cls, requests): - out = maybe_dense_stack( + out = lazy_stack( [ cls( request_id=request.request_id, From 26dfdc189c7772553800670dc078e319080cebd5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Mar 2025 16:32:35 -0700 Subject: [PATCH 2/5] Update [ghstack-poisoned] --- test/test_actors.py | 16 +++++++-- torchrl/modules/llm/vllm_policy.py | 53 ++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/test/test_actors.py b/test/test_actors.py index 2d8620a9844..8ea869f0000 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -1242,7 +1242,8 @@ def llm_model(self): @pytest.mark.parametrize("pad", [True, False]) @pytest.mark.parametrize("generate", [True, False]) - def test_vllm_batch_run(self, pad, generate, llm_model): + @pytest.mark.parametrize("use_tensorclass", [True, False]) + def test_vllm_batch_run(self, pad, generate, use_tensorclass, llm_model): # Test generate - padding combinations policy = from_vllm( llm_model, @@ -1269,6 +1270,8 @@ def test_vllm_batch_run(self, pad, generate, llm_model): batch_size=[2], ).unbind(0) ) + if use_tensorclass: + data = LLMData.from_tensordict(data) output = policy(data) try: log_probs = output.get("log_probs") @@ -1279,9 +1282,16 @@ def test_vllm_batch_run(self, pad, generate, llm_model): else: assert isinstance(log_probs, list) text = output.get("text", as_list=True) - assert isinstance(text, NonTensorStack) + # TODO: this is not ideal... 
+ if use_tensorclass: + assert isinstance(text, list) + else: + assert isinstance(text, NonTensorStack) text_response = output.get("text_response", as_list=True) - assert isinstance(text_response, NonTensorStack) + if use_tensorclass: + assert isinstance(text_response, list) + else: + assert isinstance(text_response, NonTensorStack) try: tokens_response = output.get("tokens_response") except Exception: diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index 6754d3cd383..eabe1116ac8 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -61,22 +61,37 @@ def from_vllm( ) -> TensorDictModuleBase: """Creates a TensorDictModule from a vLLM model. - This function provides a consistent interface across various LLM engines. - - It supports text generation and log probability computation, similar to the Hugging Face Transformers interface. + This function provides a consistent interface across various LLM engines, allowing for text generation and + log probability computation, similar to the Hugging Face Transformers interface. Args: - model (LLM): The vLLM model to wrap. - return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `False`. - tokenizer (transformers.tokenization_utils.PreTrainedTokenizer, optional): The tokenizer to use. Defaults to `None`. - from_text (bool, optional): Whether the input is text. Defaults to `False`. - device (torch.device, optional): The device to use for computation. Defaults to `None`. - generate (bool, optional): Whether to generate text. Defaults to `True`. - generate_kwargs (dict, optional): Additional arguments for the model's generate method. Defaults to `None`. - tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer. Defaults to `None`. + model (vllm.LLM): The vLLM model to wrap. + return_log_probs (bool, optional): Whether to return log probabilities of the generated tokens. + Defaults to `False`. + tokenizer (transformers.tokenization_utils.PreTrainedTokenizer, optional): The tokenizer to use for encoding + and decoding text. If `None`, the tokenizer associated with the model will be used. Defaults to `None`. + from_text (bool, optional): Indicates whether the input is in text format. If `True`, the input is expected to + be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `False`. + device (torch.device, optional): The device to use for computation. If `None`, the default device will be used. + Defaults to `None`. + generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on + the input. If `False`, only log probabilities will be computed. Defaults to `True`. + generate_kwargs (dict, optional): Additional arguments to pass to the model's generate method. These + arguments can control aspects of the generation process, such as temperature and top-k sampling. + Defaults to `None`. + tokenizer_kwargs (dict, optional): Additional arguments to pass to the tokenizer. These arguments can control + aspects of the tokenization process, such as padding and truncation. Defaults to `None`. + pad_output (bool, optional): Whether to pad the output sequences to a uniform length. If `True`, the output + sequences will be padded. If `False`, lists of tokens will be used without padding. Defaults to `True`. + inplace (Literal[True, False, "empty"], optional): Determines how the module should handle in-place + operations. 
If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be + created. + If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will conserve + type, batch-size and device). Defaults to `True`. Returns: - TensorDictModuleBase: A configured TensorDictModule for the specified model. + TensorDictModuleBase: A configured TensorDictModule for the specified model, capable of handling text or + token inputs and producing generated text or log probabilities. Input Keys: @@ -92,21 +107,21 @@ def from_vllm( Output Keys: - "tokens_response": The generated token sequences. - - "log_probs": The log probabilities of the generated tokens (if `return_log_probs` is True). - - "text_response": The generated text (if `from_text` is True and `generate` is True). + - "log_probs": The log probabilities of the generated tokens (if `return_log_probs` is `True`). + - "text_response": The generated text (if `from_text` is `True` and `generate` is `True`). Example: >>> from vllm import LLM >>> from transformers import AutoTokenizer - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = LLM(model="facebook/opt-125m") + >>> model = LLM("gpt2") + >>> tokenizer = model.get_tokenizer() >>> module = from_vllm( ... model, ... tokenizer=tokenizer, ... from_text=True, ... generate=True ... ) - >>> input_data = LLMData(text=NonTensorStack("Hello, world!"), batch_size=1) + >>> input_data = LLMData(text=NonTensorStack("Hello, world!", "This is another text"), batch_size=1) >>> output_data = module(input_data) >>> print(output_data.text_response) @@ -217,7 +232,9 @@ def _from_vllm_generate_text( def tokenize(td): out = TensorDict(batch_size=td.batch_size, device=td.device) - text = td[text_key] + text = td.get(text_key) + if not isinstance(text, (list, str)): + text = text.tolist() tokens_in = TensorDict.from_dict(tokenizer(text, **tokenizer_kwargs)) out.set("tokens_in", tokens_in) return out From 224920d1fd7493c780f6f98dc7b69fdfbf17811b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 19 Mar 2025 19:13:28 -0700 Subject: [PATCH 3/5] Update [ghstack-poisoned] --- test/test_actors.py | 45 ++++++++++++++++++++++++++++-- torchrl/envs/custom/llm.py | 10 ++++++- torchrl/envs/transforms/llm.py | 3 ++ torchrl/modules/llm/vllm_policy.py | 9 +++++- 4 files changed, 63 insertions(+), 4 deletions(-) diff --git a/test/test_actors.py b/test/test_actors.py index 8ea869f0000..14e8b90a92b 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -15,9 +15,12 @@ from tensordict.nn.distributions import NormalParamExtractor from torch import distributions as dist, nn + +from torchrl.collectors import SyncDataCollector from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot from torchrl.data.llm import LLMData from torchrl.data.llm.dataset import _has_transformers +from torchrl.envs import LLMEnv from torchrl.modules import ( from_hf_transformers, from_vllm, @@ -42,10 +45,10 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import get_default_devices - from pytorch.rl.test.mocking_classes import NestedCountingEnv + from pytorch.rl.test.mocking_classes import DummyStrDataLoader, NestedCountingEnv else: from _utils_internal import get_default_devices - from mocking_classes import NestedCountingEnv + from mocking_classes import DummyStrDataLoader, NestedCountingEnv _has_vllm = importlib.util.find_spec("vllm") is not None @@ -1311,6 +1314,44 @@ def test_vllm_batch_run(self, pad, generate, 
use_tensorclass, llm_model): else: assert isinstance(tokens, list) + def test_vllm_collection(self): + from vllm import LLM + + llm = LLM("gpt2") + policy = from_vllm( + llm, + from_text=True, + generate=True, + return_log_probs=True, + pad_output=False, + generate_kwargs={"max_tokens": 10}, + ) + self._run_check_collector(policy) + + def test_transformers_collection(self): + ... + + @classmethod + def env_constructor(cls): + dl = DummyStrDataLoader(batch_size=32) + env = LLMEnv.from_dataloader( + dl, batch_size=16, repeats=4, str2str=True, group_repeats=True + ) + assert env.batch_size == (64,) + return env + + def _run_check_collector(self, policy): + collector = SyncDataCollector( + self.env_constructor, + policy=policy, + frames_per_batch=128, + total_frames=512, + use_buffers=False, + ) + for data in collector: + assert isinstance(data, LazyStackedTensorDict) + assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 2745b9b0851..7673d9d82be 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -10,6 +10,7 @@ from tensordict import ( is_leaf_nontensor, + LazyStackedTensorDict, NestedKey, TensorDict, TensorDictBase, @@ -266,6 +267,7 @@ def from_dataloader( stack_method: Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"] = None, repeats: int | None = None, + group_repeats: bool = False, ) -> LLMEnv: """Creates an LLMEnv instance from a dataloader. @@ -331,6 +333,8 @@ def from_dataloader( repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo samples (rather than an advantage module). + group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that + all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``. Returns: LLMEnv: The created LLMEnv instance. @@ -398,6 +402,8 @@ def from_dataloader( stack_method=stack_method, repeats=repeats, device=device, + group_repeats=group_repeats, + batch_size=batch_size, ) env = LLMEnv( str2str=str2str, @@ -411,7 +417,7 @@ def from_dataloader( no_stack=no_stack, assign_reward=assign_reward, assign_done=assign_done, - batch_size=batch_size if batch_size is not None else primer.batch_size, + batch_size=primer.batch_size, has_attention=has_attention, as_llm_data=as_llm_data, ) @@ -565,6 +571,8 @@ def check_str(): f"{list(tensordict.keys(True, True, is_leaf=is_leaf_nontensor))}. Make sure a TensorDictPrimer (eg, " f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms." ) + if not isinstance(tensordict, LazyStackedTensorDict): + tensordict = LazyStackedTensorDict(*tensordict.unbind(0)) td_reset = tensordict.copy() if td_reset.device != self.device: if self.device is None: diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py index 7f572836704..da0673ccb41 100644 --- a/torchrl/envs/transforms/llm.py +++ b/torchrl/envs/transforms/llm.py @@ -123,6 +123,9 @@ class DataLoadingPrimer(TensorDictPrimer): .. note:: The batch-size of the Primer must match the batch-size of the parent environment (typically a wrapper around :class:`~torchrl.envs.LLMEnv`). 
+ group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that + all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``. + Attributes: dataloader (Iterable[Any]): The dataloader to load data from. endless_dataloader (Iterable[Any]): An endless iterator over the dataloader. diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index eabe1116ac8..a810fe98c1e 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -190,7 +190,12 @@ def to_list(tokens, attention_mask): def _prepare(td): - out = TensorDict(batch_size=td.batch_size, device=td.device) + out = LazyStackedTensorDict( + *[ + TensorDict(batch_size=td.batch_size[1:], device=td.device) + for _ in range(td.batch_size[0]) + ] + ) return out.update(td) @@ -383,6 +388,7 @@ def _from_vllm_generate_tokens( module_dict["clear_device"] = _maybe_clear_device def move_input(td): + # TODO: Use a lazy stack? result = TensorDict(batch_size=td.batch_size, device=td.device) result["tokens_in"] = result.new_empty() result["tokens_in", "input_ids"] = td.get("tokens") @@ -520,6 +526,7 @@ def _from_vllm_logprobs_text( # We need to do this rather than tokenizing the response because we want to ensure that there is no # additional tokens, but there is defo room for improvement. def tokenize(td): + # TODO: Use a lazy stack? out = TensorDict(batch_size=td.batch_size, device=td.device) text_prompt = td.get(text_key) if not isinstance(text_prompt, list): From d91a6013becde5d4aec8e1a36eadc89c6805dfeb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 20 Mar 2025 08:07:00 -0700 Subject: [PATCH 4/5] Update [ghstack-poisoned] --- test/test_actors.py | 50 +++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/test/test_actors.py b/test/test_actors.py index 14e8b90a92b..d0cf13861e7 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -925,6 +925,18 @@ def test_lmhead_actorvalueoperator(device): @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies") @pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies") class TestLLMActor: + @pytest.fixture(scope="module") + def vllm_instance(self): + try: + import vllm + except ImportError: + pytest.skip(reason="missing vllm") + + llm_model = vllm.LLM("gpt2") + tokenizer = llm_model.get_tokenizer() + tokenizer.pad_token = tokenizer.eos_token + return llm_model + @pytest.mark.parametrize( "from_text, generate, return_log_probs, tokens, attention_mask", [ @@ -1008,12 +1020,17 @@ def test_from_hf_transformers( ], ) def test_from_vllm( - self, from_text, generate, return_log_probs, tokens, attention_mask + self, + from_text, + generate, + return_log_probs, + tokens, + attention_mask, + vllm_instance, ): torch.manual_seed(0) - from vllm import LLM - model = LLM(model="facebook/opt-125m") + model = vllm_instance m = from_vllm( model, from_text=from_text, @@ -1184,11 +1201,12 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask): (True, False, torch.randint(1024, (1, 10)), None), ], ) - def test_from_vllm_logprobs(self, from_text, tokens, attention_mask, pad_output): + def test_from_vllm_logprobs( + self, from_text, tokens, attention_mask, pad_output, vllm_instance + ): torch.manual_seed(0) - from vllm import LLM - model = LLM(model="facebook/opt-125m") + model = vllm_instance m_generate = from_vllm( model, from_text=from_text, @@ -1234,22 +1252,13 @@ def 
_check_lps( td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2 ) - @pytest.fixture(scope="module") - def llm_model(self): - import vllm - - llm_model = vllm.LLM("gpt2") - tokenizer = llm_model.get_tokenizer() - tokenizer.pad_token = tokenizer.eos_token - return llm_model - @pytest.mark.parametrize("pad", [True, False]) @pytest.mark.parametrize("generate", [True, False]) @pytest.mark.parametrize("use_tensorclass", [True, False]) - def test_vllm_batch_run(self, pad, generate, use_tensorclass, llm_model): + def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance): # Test generate - padding combinations policy = from_vllm( - llm_model, + vllm_instance, from_text=True, generate=generate, return_log_probs=True, @@ -1314,12 +1323,9 @@ def test_vllm_batch_run(self, pad, generate, use_tensorclass, llm_model): else: assert isinstance(tokens, list) - def test_vllm_collection(self): - from vllm import LLM - - llm = LLM("gpt2") + def test_vllm_collection(self, vllm_instance): policy = from_vllm( - llm, + vllm_instance, from_text=True, generate=True, return_log_probs=True, From 7c5306cc33a8314037ef49cfbe24404b82059551 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 20 Mar 2025 09:08:57 -0700 Subject: [PATCH 5/5] Update [ghstack-poisoned] --- test/test_actors.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/test_actors.py b/test/test_actors.py index d0cf13861e7..de2e660d5e4 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -1218,7 +1218,13 @@ def test_from_vllm_logprobs( model, from_text=from_text, generate=False, pad_output=pad_output ) self._check_lps( - m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False + m_generate, + m_logprobs, + tokens, + attention_mask, + from_text, + has_logits=False, + tol=1e-1, ) def _check_lps( @@ -1229,6 +1235,7 @@ def _check_lps( attention_mask, from_text, has_logits, + tol=1e-2, ): # Checks that the log-probs gathered with generate=False equate those with generate=True tdin_genetate = self._make_data( @@ -1249,7 +1256,7 @@ def _check_lps( assert td_generate.log_probs.shape == td_generate.tokens_response.shape assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape torch.testing.assert_close( - td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2 + td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol ) @pytest.mark.parametrize("pad", [True, False])
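
Note (illustrative, not part of the diff): the sketch below recombines the pieces exercised by test_vllm_batch_run and test_from_vllm_logprobs above — building a from_vllm policy with pad_output=False, generating from lazy-stacked text, then scoring an existing response with generate=False. The model name ("gpt2") and generate_kwargs follow the test fixture; treat this as a minimal sketch under those assumptions rather than a reference recipe.

    import vllm
    from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
    from torchrl.modules import from_vllm

    model = vllm.LLM("gpt2")
    tokenizer = model.get_tokenizer()
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    # Generation path: with pad_output=False, outputs come back as lists /
    # nested tensors instead of padded tensors.
    policy = from_vllm(
        model,
        from_text=True,
        generate=True,
        return_log_probs=True,
        pad_output=False,
        generate_kwargs={"max_tokens": 10},
    )
    data = LazyStackedTensorDict(
        *TensorDict(
            text=NonTensorStack("a string", "another very long string"),
            batch_size=[2],
        ).unbind(0)
    )
    out = policy(data)
    print(out.get("text_response", as_list=True))
    print(out.get("log_probs", as_list=True))  # one tensor per sample

    # Log-prob path: generate=False scores a given (prompt, response) pair.
    scorer = from_vllm(model, from_text=True, generate=False, pad_output=False)
    data = LazyStackedTensorDict(
        *TensorDict(
            text=NonTensorStack("a string", "another very long string"),
            text_response=NonTensorStack(" is a string", " is still a very long string"),
            batch_size=[2],
        ).unbind(0)
    )
    print(scorer(data).get("log_probs", as_list=True))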
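
For the collection path added in patch 3/5, a similarly hedged sketch follows. DummyStrDataLoader is the test-only dataloader imported from the test suite's mocking_classes; any string dataloader with the same interface would work the same way, and `policy` refers to the generating from_vllm policy built in the previous sketch.

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import LLMEnv
    from mocking_classes import DummyStrDataLoader  # test-only helper

    def env_constructor():
        dl = DummyStrDataLoader(batch_size=32)
        # group_repeats=True folds the `repeats` copies of each prompt into the
        # same batch, so the env batch-size becomes 16 * 4 = (64,).
        env = LLMEnv.from_dataloader(
            dl, batch_size=16, repeats=4, str2str=True, group_repeats=True
        )
        assert env.batch_size == (64,)
        return env

    collector = SyncDataCollector(
        env_constructor,
        policy=policy,
        frames_per_batch=128,
        total_frames=512,
        use_buffers=False,
    )
    for batch in collector:
        # Responses are stored as a NonTensorStack of strings.
        print(batch.reshape(-1).get("text_response"))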