Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate vllm for multimodal data #1098

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions src/distilabel/models/llms/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Literal,
Optional,
Tuple,
TypedDict,
Union,
)

Expand Down Expand Up @@ -56,6 +57,7 @@
LLMLogprobs,
LLMOutput,
)
from PIL import Image


LogitsProcessorFn = Union[
Expand All @@ -66,6 +68,20 @@
LogitsProcessors = List[LogitsProcessorFn]


class ImageType(TypedDict):
    """Multimodal payload holding the single image that accompanies a prompt."""

    # Image attached to the prompt; annotated as a PIL image.
    image: "Image.Image"


class MultiModalDict(TypedDict):
    """Prepared multimodal input: a text prompt together with its image payload.

    This is the dictionary shape `vLLM.prepare_input` returns when an image was
    extracted from the conversation.
    """

    # Prompt text; `prepare_input` fills it with the chat-templated conversation.
    prompt: str
    # Image payload that goes along with `prompt`.
    multi_modal_data: ImageType


# `str` is the text-only case (just the prompt); `MultiModalDict` is the case in
# which `prepare_input` extracted an image and returns prompt + image together.
PreparedInput = Union[str, MultiModalDict]
"""A type alias representing the prepared input for the LLM model, both text
and multimodal."""


class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
"""`vLLM` library LLM implementation.

Expand Down Expand Up @@ -256,7 +272,7 @@ def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model

def prepare_input(self, input: "StandardInput") -> str:
def prepare_input(self, input: "StandardInput") -> PreparedInput:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.

Expand All @@ -269,17 +285,55 @@ def prepare_input(self, input: "StandardInput") -> str:
if self._tokenizer.chat_template is None:
return [item["content"] for item in input if item["role"] == "user"][0]

image = None
if (input[-1]["role"] == "user") and isinstance(input[-1]["content"], list):
input_, image = self._prepare_for_multimodal(input)
else:
input_ = input

prompt: str = (
self._tokenizer.apply_chat_template(
input, # type: ignore
input_, # type: ignore
tokenize=False,
add_generation_prompt=True, # type: ignore
)
if input
if input_
else ""
)

if image:
return { # type: ignore
"prompt": prompt,
"multi_modal_data": {"image": image}, # type: ignore
}
return super().apply_magpie_pre_query_template(prompt, input)

def _prepare_for_multimodal(
    self, input: "StandardInput"
) -> Tuple["StandardInput", Optional["Image.Image"]]:
    """Splits a multimodal conversation into a text-only conversation and its image.

    User messages whose `content` is a list of parts (e.g. a `text` part and an
    `image_url` part) are flattened into a plain `{"role": "user", "content": <text>}`
    message so the chat template can be applied to them. Every other message is
    passed through untouched.

    Parts are located by the key they carry (`"text"` / `"image_url"`) rather than by
    their position in the list, so the image may come before the text and a list
    without an image part is accepted.

    Args:
        input: the conversation, as a list of `{"role": ..., "content": ...}` messages.

    Returns:
        A tuple with the text-only conversation and the extracted image. The image is
        `None` when no user message carried an `image_url` part. If several user
        messages carry an image, the one from the last such message is returned.
    """
    image = None
    input_ = []
    for item in input:
        if not (item["role"] == "user" and isinstance(item["content"], list)):
            input_.append(item)
            continue

        parts = item["content"]
        # First text part of the message; an image-only message yields an empty text.
        text = next((part["text"] for part in parts if "text" in part), "")
        image_part = next((part for part in parts if "image_url" in part), None)
        if image_part is not None:
            image = image_part["image_url"]["url"]
            if isinstance(image, str):
                # Imported lazily so the helper is only needed for string-encoded images.
                from distilabel.models.image_generation.utils import image_from_str

                image = image_from_str(image)

        # The image travels separately (see `MultiModalDict`); only the text is kept
        # in the message that goes through the chat template.
        input_.append({"role": "user", "content": text})
    return input_, image

def _prepare_batches(
self, inputs: List["StructuredInput"]
) -> Tuple[List[Tuple[List[str], "OutlinesStructuredOutputType"]], List[int]]:
Expand Down
130 changes: 130 additions & 0 deletions tests/unit/models/llms/test_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from typing import Any, Dict, List
from unittest import mock

import numpy as np
import pytest
from openai.pagination import SyncPage
from openai.types import Model
from openai.types.completion import Completion
from openai.types.completion_choice import CompletionChoice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from pydantic import BaseModel
from transformers import AutoTokenizer

Expand Down Expand Up @@ -102,6 +104,11 @@ class Animal(BaseModel):
]


img_str = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCABkAGQDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDw9whjujGGK7EOS3fv2HfJxz0/ixuDrgqv2jciofJjUKiZG7A7jAxgE55z1+b74jkfzBcMWZfkVRsQYbHZsdM4JzzkjJz94OMg23hIALxIACevKnPBGemed3rz98EU1Z+n/toSVtwupVZ7krEQsipyeMcA/rjPJPqdx+anTiZVuMNhfJi38bdwIBHpnse+cbvmxupJ3mfz2YhGaKMsB8u5cA9Mc9j7/e5+9SzFSt0QikGNCGckEZ5yPc+nPBz82N4UI2S+X/to7p6jZB5guGwqkRIdu7bxgdBgbucHuep55YOdVjS9VlCsYkOHbnJIOVPGQevfg5wcbwXEnNyvmAkxRqSp4bgE5wBnnnvkjPzffBJuj+2fMwV4EHQrnJVgCMjPTP8AFnrz98NO6VvL/wBsJd0guFmVrkSGNXMUZI4XKkAjA/i/hOec/e5+8ImQQpOrFWLImDg55w2ePYd8g57/AHg0fvBc7AmwIDk4U4BGMDPJ9ue57bhPdSNFJOiKcSQxAnGM/KrZ4AzkjPcd8scPRH7Kt2/9tDrYZcghrk4VwVX5mzkEnOQc8/rnJPON1LO/k/aEZXBkjRQTxkcNk465wD3Hfk4YJNcEtdBGwHVVbDY3Ac8468gHqeRnk/NS3BZmuHkVlLQpgMNpOcEHqOo57k5zz96iG135f+2lT313FddqXXlFoovLTcrH72ecc9s8gc9AecbhGw2LchDLGrRoGCtuDngkE8cZBYdfujr96pJyE+1hGbY6ISS2ck84JPqecc9P4sbgXAAM5VQo8tBwSwyQCRnj39emfm+/RFp2v5f+2hJakWprtvTwfmVW5HJyAc/jnPfPq33iUmpGM3f7oKEEaYCjA+6PYf1+rfeJQvhXovyFr1HSqI3mV42jYxhlXHY4Pr0IOQefx+9Trpjvm+980UYJVQA3yg88DrjOeckZ+b71E5K+cjRlWaNMBlwcYznj1GD75zz96iSIJHcAExnyo229mzg45wSOc8Z6DqPmD/lfp/7aLrqx7xLEt4AQFEaMu3ockEDk579t3TPI+cMnLYnADIAiBjlQG/Lrn73Gc4zz96lmMkbXQlRgXRcZXkg8g9ehHPfPB5+8JJpDKL0kBT5UY5KksQQCQRjOeT/ET1O4guFFtJddv/bP6/4cp7tlZ
yCbk9cjjAyMk5xnPpn16d/vCaYQr9pGN37mMRsq9+Cc4xg4B5+b/gX3ws6uFuAsiriGLftYKGGBx0G7nB4znG75vv0XOGa4fzMbo4yFVcbs4POcfU9ckZ+b79EW218v/bRO0nd7iTOyPdqJAQ8S5IGNwyDg88+vfJGefv0l1E/mXG/ch2I5BGd2Rnr6EHPfPB5HzUt15ckkxMQVvJjKg8Y+UcgYGc/jwSfm+/THLSJcuVVcovYjvkd/T6568/eDgtE/T/20E73aZNKFCXuPLKmKMAoNoHIwByMn1+9nBPzffEM2VWdVLKdqbg7glvUg45BOG4Pp97G4SSOVF2GwzPEgyhO0ZIYjtnp1OQcZ5++GGQf6YTnEiDBOSSSwPPP167v/AGYKC27af+2jva7X9LXoPv40SSUNlSsUW0CIfMSo74GARk5GcnHLffpJPMk+1tIqqxjVum3IyMdTk5BB756nP3gtzJGrXScx7o4wqgdeh7Y4PXvnj733w102R3IYKxMMbDdlWGQGyMgZ689c5zzjeFCXw38v/bRN293+v61ItRwbrIXb8i9gM8Dn8evvnq33iVHdtun6AYUDAxjge3+T6nqSn0XovyC1ieUxgzqkLhWRdu49OhyPr178ev3qU7hHcfvEBEKIVjOAw44wMA8gHvkjPP3gtwrJ9o8xOqpgsuDzyD+I56nOc8/eEcsiuZmlTLmNVUgZweOeMdgeTnPuTuFQtZfL/wBtCUetgl8orOYgEXahCk5Oe+D6Z7c9vvY3VJcqm6cLJjbFHjhRu4A9vrxnnn5vv0+7jiWW4DZV/JjaMYPOQCeuOxzn5v8AgWd9RvJs+1AzmTzEAyu7nJDYPPbHOcgkcZ4YTDo15f8AtoPVXW6/IddkLNO2XHmQocKOCSFODnHuc4OcdW+/TDII1ulVsCWFAR8wzyre2enfP44DB8zf8fO503NEnCdDyDj3x685Izz98I4DLdvGoCKijBI457c8+uOT1PONwIpWSfl/7aLlbGkGGO5T513RrkjO05IbB9u46jjv94OuJHL3DvECZI0BIUgDIBz2zwOpznk8n5qW4WWRrmQblXy037zgsDgg++SN2OT35wWpSSsd4QkiGSFAd7HnJDe2c4yM545wcbwR6S9P/bRsjuVkBkEiEErGRiMLkbflJwO45z368/eoeWKQXDPFtcxIqYXhSMemOoB5Oe+ck7wk5Iln3xuHaNcbhjIIBz75HOefXn71EiCMzq2Y90alVC43A4Izz0xg988dfvBws0reX/tvYTa+4SVFiMyyqDKUTZgcDIBz27d+c9ec7hPO7RC5HQyQxA4yAQQrdMDPQHnOevzffEckZ2XAE0bBUTJTjd7e5B64zkjPI+YNmj8nzkEuRsXJTo2ecH+fGRxkZHzUoxvbXt/7b9w7EF0rLOQxJOAcnvkZz+v/ANc9aKffBVnXZ90xocemVBPYf57t94lGtlfsvyC99SxIUl+2Nt4WNACVUEsMDPBHUZPG4nqc8uC4VnFw8igNsQrmPaSD0P4rz3z15+8FkQbbvzV2usUZH3eTx9M5BzxnPXn74Jnmf7W7ps3xoW+XZkHBX3ORg9843HP3hNO1l8v/AG0aa6fd9/4ELSMEuQCRvRc5G0kZBHGec8Hv68/eDn3wi6KHfHJGoZiWX7xDDr1PHQ56ZGcBqddkrJOWiYEoi5kPOSAdwIwDuxkZzwc8n5qUMXhvSZAT5a5OfvHcCe4z69+mcHG8ONnZry/9tB/3thbgSMblxLuxFGJGBChgccYwNxyAe+SCfm5an3XzLdMgXBiiLEnBPAPoMknnHPr82N4jcu8dyVYQr5KExqMbxwQOcEjv3JIB5wWEc6+Z58iMGUBGYkgnJHOCR6knHJ7/ADY3URitL+X/ALaEbD3XfHcsFgZRFHkj5dpwOnAyeCCOc8nnG8SOyyR3zFSpMaYBI9R05Gc9f4j3wfvhk4ljW4wzorQxeYrHBfIDDsMgnDY5zwfmxuolC
zfa5FbywiICqsMMeMjPfkZ7njPPLgglovT/ANtEr8um3/DiHe6Xsmcfu1Dcj5vmHvz0z3PGcHG4LLIifahCWMbxKhGWOTwx6YGMqeDn8cb6hYvtnwDgqFJDcYznHHXJGe/rz1Fi4heL7UqoI08qMlSexwRjpkHqBzkc/NjeHHRr5f8Ato2rt3RFOhLT+ZF5TiNHClgMggcjuc5B4zkc8/eC+ZF5N0Akg3RKoJbcNwIJ5BHXBI6/Qn5wtxIy/aSCCskaKdoKDBwwGO54HXOeTz96mu8aJPsLfPEinDZGeCQencZ79O/3gR2Sfl/7aS09mRXylbgZUqTGhORjOVBz0HXrnvnqepKbeYFwQIzGAB8pIPbqD6HqOvXqepKFsvRfkNK2jJ59xM7AkAxoOm3cMA8gYz0B7+vP3qdOjkzGRgHEEbjK7SwIHY4ycHORnPXn71SXkSiS4LblxDFs+XAOVB54HXk55z1y2d9Muv8AWXB3lB5SDCLgNwCAfyznnJGct96lTa0a8v8A20Vno0EzjfeFVkTeiqfmyG5BOeeQcbh16d/vBJSMTmf7xiQoDEQSTgg+3y5Oec5zz98LKix/ahHuAESLkEbT0yO2c4yOvTPP3wyRpnS5Z5OSqq2xR8+PUjjtnvkgHn7wdPZW8v8A23+mU022xHIk89mIjxEoRUUAEccY47DPcnqc5LCSVN4uS8TRlYUYByM545B4yCCWHXjnnlxG7F47hn2SMQvzkYOfUe/r1zknB+8HXChXmSUMsgiQrkg54HPQcHOcjOffO4OO6Xp/7b+AmreQyVWQzKyr/q1IyoU44wR+H1znPP3qklkj3XSgAb4xxncdwIJII7dfXt1++Gyq7NOcGMCFTjaE3LxtyO+Rhu5OM88tT5MTx3MnlgERxk7mGc9yDxnPXHJwcnOC4ILZvy/9tEno1f7iM7IFuYzuO6JVDZOM5DdiM5x7j68MFaI+XctISHCq43Dlt3156NnjOcZwR8wGuiY7hUVB5kaodvyAKCOw6nheue5OT8wdNNHIbpiisXRNrHsRjJ4xyffPcnJ+cKPMmvl/7aNe7ewsgaL7ZkH95EuSSe7K3qM9M/xevP3wSSlVuwn3ZI0XhSvHDe3pnnOcZ5OGBcwFWuMHGI42fLZyxAJwSBkZ57+vzAb6JYoVjuticCOMpkngnBPp78c8f3vviY2aT9P/AG0N3fuV74g3TEDAIB785Gf89fqepKZdFjMN6hTtXAC44xx+nfv1yc5JVdF6L8gvfUtMUiW8WN1KsiqAhbGCQxHvgj3HGRn7wbMXj+0Isi7SiK21Qu8cEA+vY98kZ5+9T5lIa7KloV8lAVBHzn5ep4yDjcOp4B55emyuyfagNzCWNdxyW5JDHnI44J5yPrgNUxTaXfT/ANtDvpqOnhRGuYyCNsaMmV5JODnORgEEnjdn3++ImfCTKcfMibcrg4xnsP8A9fXn7wmbYsd55bAhok7EdSGx29Pf15xvC3K83J3YYwxsRnGQQDjkDPOD39fm++Kg3dX8v/bQvqRkmNbxUKlWjUMVfjqDjnG7ntz0zzjcCUtH542OokjTrxkY3Z6d8A859efvBd8ckV2zMGby12HHJOefx656/Q/fV1wgie4XlB5EYUEY3AhTnAwOevf1+b79ELJq/l/7aJ6PQSZuLqR0kRnjQDd3zg5PTrjcM5P1+8HTRqgu8jIEUeM+pIPByPc/xZ68/fEMyhDNhtxZFJJ3fxDceo5/H8M/eqbywkF6EkkVfKjJHA8zJBwc44/iwM/dHUDeEla1n2/9tKdnqNuUSJ7hQxBMaFFUcMCAec9u+eeg+998RSW7qs7OHBUIx3HltwznJHOev055HzCQEvHeuspQNGpYZyZDuHBJI4745PAODgsGjYYbx4htXaoO5iOCc/jyBxk/jjcCN1a77f8Ato1u7f1uFwFd7iRF3DC/MT0J6/U9fXv1+9Sygj7Qdu3EaBsEYPT884z36Z5+8GuBG10sqksYw
Izs6HIIPBxyuTn5s5/4EJphJGbxRKCjQpkjjIJVgOoz6/xZIzzjeHDpby/9tFJ6u6Kt+E+1EoSVZVbJzkkgE5z7/X6t94lO1IMLw7sZKIeFwMFQfx69ec9ctncSkvhXovyEWLlFSGViNzFIBlh03Rlyfz4/HJyearGdtkxCgb1VMAkAD73rz0HXPr15ooqruz+X/tgb0035fqKHzZzuVXJ8uPgYwME547/KP59eaex+0RzzygGT5FBAxj5Sc8dT8vU9cknJ5oooiv3n3f8AtpSXu/15iXyLBOUQYV4o5MHnBZAxAPpkn9Op5p8qho5myRlY+B05Qvj8wP65PNFFFLVxv5f+2lLr/XRi3LmBrgLyJ4oi2WPG5Q5788jvn16gEJeILe5eNCxWW3jc5Y8FkWQ/UZ9c/nzRRWNFtyin/XwmM3rL1H3Ci3inCE4kjhzkn+JPMP6jofr1ANMv/luinUPBE5OBnJjDfzP49Tk80UVvT+Nei/KA2yO7fbKQFX5oY+gxj5VPb+vXqcnmpLqT7O8saKu2aCInPUZVX4x7+ufU5IBooqdvuX/tpD0Wncr3pzc7j1ZEY/UqD/X6+uTRRRSWy9Eay3Z//9k="
# Deterministic random RGB fixture image (100x100).
np.random.seed(42)
# `Image.fromarray` does not convert the data when a mode is given, it only
# reinterprets the buffer, so the array must already be 8-bit. `high` is exclusive,
# hence 256 to cover the full 0-255 channel range.
img_pil = Image.fromarray(
    np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8), "RGB"
)


class TestvLLM:
@pytest.mark.parametrize(
"multi_structured_output",
Expand Down Expand Up @@ -235,6 +242,129 @@ def test_generate(
result = llm.generate(inputs=formatted_inputs, num_generations=num_generations)
assert result == expected_result

@pytest.mark.parametrize(
    "num_generations, expected_result",
    [
        (
            1,
            [
                {
                    "generations": ["I'm fine thank you"],
                    "statistics": {"input_tokens": [21], "output_tokens": [6]},
                    "logprobs": [
                        [
                            [
                                {"token": "I'm", "logprob": -1},
                                {"token": "Hello", "logprob": -3},
                            ],
                            [
                                {"token": "I'm", "logprob": -1},
                                {"token": "Hello", "logprob": -3},
                            ],
                        ]
                    ],
                }
            ],
        ),
        # TODO: add the `num_generations=2` case (same payload repeated twice) once
        # `generate` is exercised with multimodal inputs.
    ],
)
def test_generate_with_images(
    self,
    num_generations: int,
    expected_result: List[Dict[str, Any]],
) -> None:
    """Checks that `prepare_input` turns a text + image user message into a
    dictionary with the chat-templated prompt and a PIL image.

    The image is supplied as an encoded string (`img_str`), so the string-to-image
    conversion path of the LLM is exercised too.
    """
    import sys

    llm = vLLM(model="dummy")
    tokenizer = AutoTokenizer.from_pretrained(
        "distilabel-internal-testing/tiny-random-mistral"
    )
    llm._tokenizer = tokenizer
    vllm_mock = mock.MagicMock()
    vllm_mock.get_tokenizer = mock.MagicMock(return_value=tokenizer)

    # mock the import by hacking sys.modules
    # https://stackoverflow.com/questions/60919705/how-to-mock-in-a-python-unittest-a-library-not-installed-locally
    # The stub is only installed when no `vllm` module is registered yet, and it is
    # removed again afterwards so it cannot leak into other tests.
    stub_installed = "vllm" not in sys.modules
    if stub_installed:
        sys.modules["vllm"] = vllm_mock
    try:
        llm._model = vllm_mock

        # Canned engine output, kept ready for the `generate` assertion below.
        mocked_requests_output = [
            mock.Mock(  # RequestOutput
                outputs=[
                    mock.Mock(  # CompletionOutput
                        text="I'm fine thank you",
                        token_ids=[1, 2, 3, 4, 5, 7],
                        logprobs=[
                            {
                                1: mock.Mock(decoded_token="I'm", logprob=-1),
                                2: mock.Mock(decoded_token="Hello", logprob=-3),
                            },
                            {
                                1: mock.Mock(decoded_token="I'm", logprob=-1),
                                2: mock.Mock(decoded_token="Hello", logprob=-3),
                            },
                        ],
                    )
                ]
                * num_generations,
            )
        ]

        llm._model.generate = mock.MagicMock(return_value=mocked_requests_output)
        formatted_inputs = [
            [
                {"role": "system", "content": "sysprompt"},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "I'm fine thank you",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": img_str,
                            },
                        },
                    ],
                },
            ]
        ]
        # TODO: also assert
        # `llm.generate(inputs=formatted_inputs, num_generations=num_generations)`
        # against `expected_result` once multimodal generation is wired in.
        prepared_input = llm.prepare_input(formatted_inputs[0])
    finally:
        if stub_installed:
            sys.modules.pop("vllm", None)

    prompt = "<s> [INST] sysprompt\n\nI'm fine thank you [/INST]"
    assert prepared_input["prompt"] == prompt
    # `Image` is the PIL module imported at the top of the file.
    assert isinstance(prepared_input["multi_modal_data"]["image"], Image.Image)


@mock.patch("openai.OpenAI")
@mock.patch("openai.AsyncOpenAI")
Expand Down
Loading