Skip to content

[BugFix] Better handling of batches in vllm wrapper #2853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions torchrl/envs/transforms/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import torch
from tensordict import (
maybe_dense_stack,
lazy_stack,
NestedKey,
TensorDict,
TensorDictBase,
Expand Down Expand Up @@ -386,7 +386,7 @@ def __init__(
self.endless_dataloader = self._endless_iter(self.dataloader)

if stack_method is None:
stack_method = maybe_dense_stack
stack_method = lazy_stack
elif stack_method == "as_nested_tensor":
stack_method = as_nested_tensor
elif stack_method == "as_padded_tensor":
Expand Down Expand Up @@ -434,10 +434,14 @@ def _endless_iter(self, obj):
while True:
yield from obj

# def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
# td = super()._reset_env_preprocess(tensordict)
# return lazy_stack(list(td.unbind(0)))
#
def _load_from_dataloader(self, reset: torch.Tensor | None = None):
"""Loads a single element from the dataloader, or alternatively from the buffer.

If `reset` is passed, the one element per reset will be loaded.
If `reset` is passed, then one element per reset will be loaded.
"""
if reset is not None:
if not reset.any():
Expand Down
78 changes: 52 additions & 26 deletions torchrl/modules/llm/vllm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch
from tensordict import (
from_dataclass,
lazy_stack,
LazyStackedTensorDict,
maybe_dense_stack,
NestedKey,
NonTensorData,
Expand All @@ -20,6 +22,7 @@
TensorDictModule as Mod,
TensorDictModuleBase,
TensorDictSequential as Seq,
WrapModule,
)
from tensordict.utils import _zip_strict

Expand Down Expand Up @@ -61,6 +64,7 @@ def from_vllm(
generate: bool = True,
generate_kwargs: dict | None = None,
tokenizer_kwargs: dict | None = None,
pad_output: bool = True,
) -> TensorDictModuleBase:
"""Creates a TensorDictModule from a vLLM model.

Expand Down Expand Up @@ -151,7 +155,7 @@ def from_vllm(
out_keys=["tokens_in"],
method_kwargs=tokenizer_kwargs,
strict=True,
inplace=False,
inplace="empty",
)
else:
module_dict["encode"] = Mod(
Expand All @@ -164,7 +168,7 @@ def from_vllm(
in_keys=[text_key, "text_response"],
out_keys=["tokens_in", "tokens_response"],
strict=True,
inplace=False,
inplace="empty",
)

def select(x, y):
Expand Down Expand Up @@ -196,7 +200,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
("tokens_in", "attention_mask"),
],
strict=False,
inplace=False,
inplace="empty",
)
else:
module_dict["move_inputs"] = Mod(
Expand All @@ -205,7 +209,7 @@ def stack_for_logprobs(tokens, tokens_response, attention_mask=None):
out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")],
# It's ok if there's no mask
strict=False,
inplace=False,
inplace="empty",
)

def to_list(tokens, attention_mask):
Expand Down Expand Up @@ -261,13 +265,27 @@ def to_list(tokens, attention_mask):
strict=True,
)

def get_output_tokens_and_log_probs(td):
padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]

def get_output_tokens_and_log_probs(td, padding_value=padding_value):
td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict):
td = lazy_stack(list(td.unbind(0)))
if generate:
# When not generate, we don't want to overwrite this
td["tokens_response"] = td["tokens_out"].outputs.token_ids
tokens_response_td = td["tokens_out"].outputs._tensordict.select(
"token_ids", "logprobs", strict=False
)
if pad_output:
tokens_response_td = tokens_response_td.densify(
layout=torch.strided
).to_padded_tensor(padding=padding_value)
tokens_response_td.rename_key_("token_ids", "tokens_response")
# td["tokens_response"] = outputs.token_ids
if return_log_probs:
td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1)
tokens_response_td.rename_key_("logprobs", "log_probs")
# td["log_probs"] = outputs.logprobs.unsqueeze(-1)
td.update(tokens_response_td)
elif not generate:
td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
return td
Expand Down Expand Up @@ -296,32 +314,40 @@ def translate_lps(tokens_response, x):
module_dict["to_source_device"] = _maybe_set_device

if generate:
module_dict["format"] = Mod(
lambda *x: x,
in_keys=[
"log_probs",
"tokens_response",
("tokens_in", "input_ids"),
("tokens_in", "attention_mask"),
"text_response",
],
out_keys=[
"log_probs",
"tokens_response",
token_key,
attention_mask_key,
"text_response",
],
strict=False,
inplace=False,
in_keys = [
"log_probs",
"tokens_response",
("tokens_in", "input_ids"),
("tokens_in", "attention_mask"),
"text_response",
]
out_keys = [
"log_probs",
"tokens_response",
token_key,
attention_mask_key,
"text_response",
]

def format_td(td):
td = td.select(*in_keys, strict=False)
td.rename_key_(("tokens_in", "input_ids"), token_key)
td.rename_key_(("tokens_in", "attention_mask"), attention_mask_key)
return td

module_dict["format"] = WrapModule(
format_td,
in_keys=in_keys,
out_keys=out_keys,
)

else:
module_dict["format"] = Mod(
lambda *x: x,
in_keys=["log_probs", "tokens_response"],
out_keys=["log_probs", "tokens_response"],
strict=False,
inplace=False,
inplace="empty",
)

return Seq(module_dict, inplace=True)
Expand Down
Loading