from __future__ import annotations

import base64, io, logging, contextlib, traceback, typing, uuid

import bentoml, pydantic, fastapi, PIL.Image, typing_extensions, annotated_types

logger = logging.getLogger(__name__)

if typing.TYPE_CHECKING:
    from vllm.engine.arg_utils import EngineArgs

    class Args(EngineArgs, pydantic.BaseModel):
        pass
else:
    Args = pydantic.BaseModel


class BentoArgs(Args):
    # The bentovllm_* fields configure this service itself; everything else
    # maps onto vLLM's EngineArgs.
    bentovllm_model_id: str = 'meta-llama/Llama-3.2-11B-Vision-Instruct'
    bentovllm_max_tokens: int = 8192

    disable_log_requests: bool = True
    max_log_len: int = 1000
    request_logger: typing.Any = None
    disable_log_stats: bool = True
    use_tqdm_on_load: bool = False
    enforce_eager: bool = True
    limit_mm_per_prompt: typing.Any = {'image': 1}
    max_model_len: int = 16384
    max_num_seqs: int = 16
    enable_auto_tool_choice: bool = True
    tool_call_parser: str = 'pythonic'
    tensor_parallel_size: int = 1

    @pydantic.model_serializer
    def serialize_model(self) -> dict[str, typing.Any]:
        # Drop the bentovllm_* fields so only genuine engine arguments are
        # forwarded to vLLM.
        return {k: getattr(self, k) for k in self.__class__.model_fields if not k.startswith('bentovllm_')}
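
# A sketch of how these defaults can be overridden at serve time, assuming a
# BentoML version that supports template arguments via `--arg`:
#
#   bentoml serve . --arg tensor_parallel_size=2 --arg max_model_len=8192
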
bento_args = bentoml.use_arguments(BentoArgs)
openai_api_app = fastapi.FastAPI()


@bentoml.asgi_app(openai_api_app, path='/v1')
@bentoml.service(
    name='bentovllm-llama3.2-11b-vision-instruct-service',
    traffic={'timeout': 300},
    resources={'gpu': bento_args.tensor_parallel_size, 'gpu_type': 'nvidia-a100-80gb'},
    envs=[
        {'name': 'HF_TOKEN'},
        {'name': 'UV_NO_PROGRESS', 'value': '1'},
        {'name': 'HF_HUB_DISABLE_PROGRESS_BARS', 'value': '1'},
        {'name': 'VLLM_ATTENTION_BACKEND', 'value': 'FLASH_ATTN'},
        {'name': 'VLLM_USE_V1', 'value': '1'},
    ],
    labels={'owner': 'bentoml-team', 'type': 'prebuilt'},
    image=bentoml.images.Image(python_version='3.11', lock_python_packages=False)
    .requirements_file('requirements.txt')
    .run('uv pip install --compile-bytecode flashinfer-python --find-links https://flashinfer.ai/whl/cu124/torch2.6'),
)
class VLLM:
    # Declare the Hugging Face weights as a model dependency; the exclude
    # patterns skip the original *.pt/*.pth checkpoints so they are never
    # downloaded.
    model = bentoml.models.HuggingFaceModel(
        bento_args.bentovllm_model_id, exclude=['original', '*.pth', '*.pt', 'original/**/*']
    )

    def __init__(self):
        from openai import AsyncOpenAI

        # Client for the OpenAI-compatible routes this service mounts at /v1
        # on its own HTTP port (3000 is BentoML's default).
        self.openai = AsyncOpenAI(base_url='http://127.0.0.1:3000/v1', api_key='dummy')
        self.exit_stack = contextlib.AsyncExitStack()

    @bentoml.on_startup
    async def init_engine(self) -> None:
        import vllm.entrypoints.openai.api_server as vllm_api_server
        from vllm.utils import FlexibleArgumentParser
        from vllm.entrypoints.openai.cli_args import make_arg_parser

        # Start from vLLM's default CLI arguments, then overlay the resolved
        # BentoArgs values (serialize_model excludes the bentovllm_* fields).
        args = make_arg_parser(FlexibleArgumentParser()).parse_args([])
        args.model = self.model
        args.served_model_name = [bento_args.bentovllm_model_id]
        for key, value in bento_args.model_dump().items():
            setattr(args, key, value)

        # Mount vLLM's OpenAI-compatible routes onto the FastAPI app that is
        # served under /v1.
        router = fastapi.APIRouter(lifespan=vllm_api_server.lifespan)
        OPENAI_ENDPOINTS = [
            ['/chat/completions', vllm_api_server.create_chat_completion, ['POST']],
            ['/models', vllm_api_server.show_available_models, ['GET']],
        ]
        for route, endpoint, methods in OPENAI_ENDPOINTS:
            router.add_api_route(path=route, endpoint=endpoint, methods=methods, include_in_schema=True)
        openai_api_app.include_router(router)

        # Enter the engine's async context through the exit stack so it is
        # closed cleanly in teardown_engine.
        self.engine = await self.exit_stack.enter_async_context(vllm_api_server.build_async_engine_client(args))
        self.model_config = await self.engine.get_model_config()
        self.tokenizer = await self.engine.get_tokenizer()
        await vllm_api_server.init_app_state(self.engine, self.model_config, openai_api_app.state, args)
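
    # Once the service is up, the mounted endpoints can be sanity-checked
    # with, e.g. (assuming the default local port):
    #
    #   curl http://localhost:3000/v1/models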

    @bentoml.on_shutdown
    async def teardown_engine(self):
        await self.exit_stack.aclose()

    @bentoml.api
    async def generate(
        self,
        prompt: str = 'Who are you? Please respond in pirate speak!',
        max_tokens: typing_extensions.Annotated[
            int, annotated_types.Ge(128), annotated_types.Le(bento_args.bentovllm_max_tokens)
        ] = bento_args.bentovllm_max_tokens,
    ) -> typing.AsyncGenerator[str, None]:
        from vllm import SamplingParams, TokensPrompt
        from vllm.entrypoints.chat_utils import parse_chat_messages, apply_hf_chat_template

        messages = [{'role': 'user', 'content': [{'type': 'text', 'text': prompt}]}]
        params = SamplingParams(max_tokens=max_tokens)
        conversation, _ = parse_chat_messages(messages, self.model_config, self.tokenizer, content_format='string')

        # Pre-tokenize the chat-formatted prompt and hand token IDs straight
        # to the engine.
        prompt = TokensPrompt(
            prompt_token_ids=apply_hf_chat_template(
                self.tokenizer,
                conversation=conversation,
                tools=None,
                add_generation_prompt=True,
                continue_final_message=False,
                chat_template=None,
                tokenize=True,
            )
        )
        stream = self.engine.generate(request_id=uuid.uuid4().hex, prompt=prompt, sampling_params=params)

        # vLLM yields the cumulative text on each step; emit only the new suffix.
        cursor = 0
        async for request_output in stream:
            text = request_output.outputs[0].text
            yield text[cursor:]
            cursor = len(text)
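
    # Example client call (a sketch; assumes the service runs locally on the
    # default port and that the installed BentoML client iterates streaming
    # generator endpoints):
    #
    #   with bentoml.SyncHTTPClient('http://localhost:3000') as client:
    #       for chunk in client.generate(prompt='Tell me a joke'):
    #           print(chunk, end='', flush=True)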

    @bentoml.api
    async def sights(
        self,
        prompt: str = 'Describe the content of the picture',
        image: typing.Optional['PIL.Image.Image'] = None,
        max_tokens: typing_extensions.Annotated[
            int, annotated_types.Ge(128), annotated_types.Le(bento_args.bentovllm_max_tokens)
        ] = bento_args.bentovllm_max_tokens,
    ) -> typing.AsyncGenerator[str, None]:
        if image:
            # Re-encode the uploaded image as a base64 data URL so it can be
            # passed through the OpenAI-compatible chat API.
            buffered = io.BytesIO()
            image.save(buffered, format='PNG')
            img_str = base64.b64encode(buffered.getvalue()).decode()
            buffered.close()
            image_url = f'data:image/png;base64,{img_str}'
            content = [dict(type='image_url', image_url=dict(url=image_url)), dict(type='text', text=prompt)]
        else:
            content = [dict(type='text', text=prompt)]
        messages = [{'role': 'user', 'content': content}]

        try:
            # Proxy the request through the service's own OpenAI-compatible
            # endpoint so vLLM applies the multimodal chat template.
            completion = await self.openai.chat.completions.create(
                model=bento_args.bentovllm_model_id, messages=messages, stream=True, max_tokens=max_tokens
            )
            async for chunk in completion:
                yield chunk.choices[0].delta.content or ''
        except Exception:
            logger.error(traceback.format_exc())
            yield 'Internal error found. Check server logs for more information.'
            return
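

# Because the FastAPI app is mounted at /v1, the service also works with any
# OpenAI SDK (a sketch, assuming the service is served locally on port 3000):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url='http://localhost:3000/v1', api_key='dummy')
#   resp = client.chat.completions.create(
#       model='meta-llama/Llama-3.2-11B-Vision-Instruct',
#       messages=[{'role': 'user', 'content': 'Hello!'}],
#   )
#   print(resp.choices[0].message.content)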