tstesco/benchmark-uplift #63
Changes from all commits: d0b78e6, fa7ab76, 870cfe2, c22daa0, 3429acf, 82515b4
@@ -1,3 +1,5 @@
+# SPDX-License-Identifier: Apache-2.0
+
 import json
 import os
 import sys
@@ -22,8 +24,10 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    model_name: Optional[str] = None
     best_of: int = 1
     logprobs: Optional[int] = None
+    extra_body: Optional[dict] = None
     multi_modal_content: Optional[dict] = None
     ignore_eos: bool = False
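For orientation, a small sketch (not part of the diff; values are placeholders) of how the new model_name and extra_body fields would be populated. The request functions below send model_name in the API payload when it is set, falling back to model otherwise.

```python
# Sketch only: field values are placeholders; RequestFuncInput is the dataclass
# from this diff (prompt/api_url come from the unchanged part of the class).
request_func_input = RequestFuncInput(
    prompt="Hello",
    api_url="http://localhost:8000/v1/completions",  # hypothetical server
    prompt_len=1,
    output_len=32,
    model="/path/to/local/weights",   # used for tokenization / weight lookup
    model_name="my-served-model",     # name actually sent in the API payload
    extra_body={"top_k": 20},         # merged into the request payload as-is
    ignore_eos=True,
)
```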
@@ -33,12 +37,20 @@ class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
     latency: float = 0.0
+    output_tokens: int = 0
     ttft: float = 0.0  # Time to first token
     itl: List[float] = field(
         default_factory=list)  # List of inter-token latencies
+    tpot: float = 0.0  # avg next-token latencies
     prompt_len: int = 0
     error: str = ""


+# TODO: remove before upstream: remove once we can drop support of Python 3.8
+# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
+# introduced in Python 3.9
+def backport_removeprefix(string: str, prefix: str) -> str:
+    return string[len(prefix):] if string.startswith(prefix) else string
+
+
 async def async_request_tgi(
     request_func_input: RequestFuncInput,
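As a side note, a minimal standalone sketch (not part of the diff) showing that the backported helper behaves like the str.removeprefix built-in available from Python 3.9:

```python
import sys

def backport_removeprefix(string: str, prefix: str) -> str:
    # Same body as in the diff above.
    return string[len(prefix):] if string.startswith(prefix) else string

assert backport_removeprefix("data: {}", "data: ") == "{}"
assert backport_removeprefix("no-prefix", "data: ") == "no-prefix"

if sys.version_info >= (3, 9):
    # The built-in gives identical results once Python 3.8 support is dropped.
    assert "data: {}".removeprefix("data: ") == "{}"
```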
@@ -47,13 +59,15 @@ async def async_request_tgi(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")

-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         params = {
             "best_of": request_func_input.best_of,
             "max_new_tokens": request_func_input.output_len,
             "do_sample": True,
             "temperature": 0.01,  # TGI does not accept 0.0 temperature.
             "top_p": 0.99,  # TGI does not accept 1.0 top_p.
             "truncate": request_func_input.prompt_len,
+            # TGI does not accept ignore_eos flag.
         }
         payload = {
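The trust_env=True argument makes aiohttp honor proxy settings from the environment (HTTP_PROXY/HTTPS_PROXY/NO_PROXY) and .netrc credentials, which the default session ignores. A minimal sketch under assumed values (the URL and the 6-hour timeout are placeholders mirroring the benchmark's AIOHTTP_TIMEOUT):

```python
import asyncio

import aiohttp

# Assumed to match the benchmark's module-level constant.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

async def check_health(url: str) -> int:
    # trust_env=True lets the session route through HTTP(S)_PROXY if set.
    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        async with session.get(url) as resp:
            return resp.status

# asyncio.run(check_health("http://localhost:8000/health"))  # hypothetical endpoint
```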
@@ -75,11 +89,11 @@ async def async_request_tgi(
                             continue
                         chunk_bytes = chunk_bytes.decode("utf-8")

-                        #NOTE: Sometimes TGI returns a ping response without
+                        # NOTE: Sometimes TGI returns a ping response without
                         # any data, we should skip it.
                         if chunk_bytes.startswith(":"):
                             continue
-                        chunk = remove_prefix(chunk_bytes, "data:")
+                        chunk = backport_removeprefix(chunk_bytes, "data:")

                         data = json.loads(chunk)
                         timestamp = time.perf_counter()
@@ -118,7 +132,8 @@ async def async_request_trt_llm(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")

-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         assert request_func_input.best_of == 1
         payload = {
             "accumulate_tokens": True,
@@ -144,15 +159,14 @@ async def async_request_trt_llm(
                         if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
-                                              "data:")
+                        chunk = backport_removeprefix(chunk_bytes.decode("utf-8"), "data:")

                         data = json.loads(chunk)
                         output.generated_text += data["text_output"]
                         timestamp = time.perf_counter()
                         # First token
                         if ttft == 0.0:
-                            ttft = time.perf_counter() - st
+                            ttft = timestamp - st
                             output.ttft = ttft

                         # Decoding phase
@@ -182,7 +196,8 @@ async def async_request_deepspeed_mii(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         assert request_func_input.best_of == 1

         payload = {
@@ -230,17 +245,26 @@ async def async_request_openai_completions(
         ("completions", "profile")
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
-            "model": request_func_input.model,
+            "model": request_func_input.model_name \
+                if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
-            "best_of": request_func_input.best_of,
+            # TODO: remove before upstream: best_of and logprobs not currently supported: https://github.com/tenstorrent/vllm/issues/44
+            # "best_of": request_func_input.best_of,
             "max_tokens": request_func_input.output_len,
-            "logprobs": request_func_input.logprobs,
+            # "logprobs": request_func_input.logprobs,
             "stream": True,
-            "ignore_eos": request_func_input.ignore_eos,
+            "stream_options": {
+                "include_usage": True,
+            },
         }
+        if request_func_input.ignore_eos:
+            payload["ignore_eos"] = request_func_input.ignore_eos
+        if request_func_input.extra_body:
+            payload.update(request_func_input.extra_body)
         headers = {
             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
         }
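The payload changes above replace the hard-coded ignore_eos key with a conditional one and let callers merge arbitrary extra sampling fields via extra_body. A hedged sketch of the resulting merge pattern (the extra_body keys are illustrative, not a documented server contract):

```python
# Illustrative only: reproduces the conditional-merge pattern from the diff.
payload = {
    "model": "my-served-model",   # placeholder
    "prompt": "Hello",
    "temperature": 0.0,
    "max_tokens": 128,
    "stream": True,
    "stream_options": {"include_usage": True},
}

ignore_eos = True
extra_body = {"top_k": 20, "repetition_penalty": 1.1}  # hypothetical extras

if ignore_eos:
    payload["ignore_eos"] = ignore_eos
if extra_body:
    payload.update(extra_body)
```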
@@ -249,32 +273,33 @@ async def async_request_openai_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
+                    first_chunk_received = False
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
-                                              "data: ")
-                        if chunk == "[DONE]":
-                            latency = time.perf_counter() - st
-                        else:
+                        chunk = backport_removeprefix(chunk_bytes.decode("utf-8"), "data: ")
+                        if chunk != "[DONE]":
                             data = json.loads(chunk)

                             # NOTE: Some completion API might have a last
                             # usage summary response without a token so we
                             # want to check a token was generated
-                            if data["choices"][0]["text"]:
+                            if choices := data.get("choices"):
+                                # Note that text could be empty here
+                                # e.g. for special tokens
+                                text = choices[0].get("text")
                                 timestamp = time.perf_counter()
                                 # First token
-                                if ttft == 0.0:
+                                if not first_chunk_received:
+                                    first_chunk_received = True
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft
@@ -284,11 +309,19 @@ async def async_request_openai_completions(
                                                       most_recent_timestamp)

                                 most_recent_timestamp = timestamp
-                                generated_text += data["choices"][0]["text"]
-
+                                generated_text += text or ""
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")
+                    if first_chunk_received:
+                        output.success = True
+                    else:
+                        output.success = False
+                        output.error = (
+                            "Never received a valid chunk to calculate TTFT."
+                            "This response will be marked as failed!")
                     output.generated_text = generated_text
-                    output.success = True
-                    output.latency = latency
+                    output.latency = most_recent_timestamp - st
                 else:
                     output.error = response.reason or ""
                     output.success = False
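With stream_options.include_usage enabled, OpenAI-compatible servers typically end the stream with a chunk whose choices list is empty and which carries a usage object; the elif usage branch above reads completion_tokens from it. A small parsing sketch (the JSON is a representative example, not captured server output):

```python
import json

# Representative final chunk when include_usage is enabled (illustrative).
final_chunk = ('{"id": "cmpl-1", "choices": [], '
               '"usage": {"prompt_tokens": 12, "completion_tokens": 64, "total_tokens": 76}}')

data = json.loads(final_chunk)
if choices := data.get("choices"):
    text = choices[0].get("text")                   # normal token chunk
elif usage := data.get("usage"):
    output_tokens = usage.get("completion_tokens")  # 64 for this chunk
```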
@@ -311,23 +344,31 @@ async def async_request_openai_chat_completions(
         "chat/completions"
     ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         content = [{"type": "text", "text": request_func_input.prompt}]
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
         payload = {
-            "model": request_func_input.model,
+            "model": request_func_input.model_name \
+                if request_func_input.model_name else request_func_input.model,
             "messages": [
                 {
                     "role": "user",
                     "content": content
                 },
             ],
             "temperature": 0.0,
-            "max_tokens": request_func_input.output_len,
+            "max_completion_tokens": request_func_input.output_len,
             "stream": True,
-            "ignore_eos": request_func_input.ignore_eos,
+            "stream_options": {
+                "include_usage": True,
+            },
         }
+        if request_func_input.ignore_eos:
+            payload["ignore_eos"] = request_func_input.ignore_eos
+        if request_func_input.extra_body:
+            payload.update(request_func_input.extra_body)
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
@@ -349,33 +390,33 @@ async def async_request_openai_chat_completions(
                         if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
-                                              "data: ")
-                        if chunk == "[DONE]":
-                            latency = time.perf_counter() - st
-                        else:
+                        chunk = backport_removeprefix(chunk_bytes.decode("utf-8"), "data: ")
+                        if chunk != "[DONE]":
                             timestamp = time.perf_counter()
                             data = json.loads(chunk)

-                            delta = data["choices"][0]["delta"]
-                            if delta.get("content", None):
+                            if choices := data.get("choices"):
+                                content = choices[0]["delta"].get("content")
                                 # First token
                                 if ttft == 0.0:
-                                    ttft = time.perf_counter() - st
+                                    ttft = timestamp - st
                                     output.ttft = ttft

                                 # Decoding phase
                                 else:
                                     output.itl.append(timestamp -
                                                       most_recent_timestamp)

-                                generated_text += delta["content"]
+                                generated_text += content or ""
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")

                             most_recent_timestamp = timestamp

                     output.generated_text = generated_text
                     output.success = True
-                    output.latency = latency
+                    output.latency = most_recent_timestamp - st
                 else:
                     output.error = response.reason or ""
                     output.success = False
@@ -389,14 +430,6 @@ async def async_request_openai_chat_completions(
     return output


-# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
-# introduced in Python 3.9
-def remove_prefix(text: str, prefix: str) -> str:
-    if text.startswith(prefix):
-        return text[len(prefix):]
-    return text
-
-
 def get_model(pretrained_model_name_or_path: str) -> str:
     if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
         from modelscope import snapshot_download

Comments on lines -392 to -397:

- They already had a function remove_prefix.
- This should only be here until we can move to python 3.9+, and that hopefully happens before we upstream. I can add e.g.:
- We are aiming to proceed with the rebase + integration of the dev branch onto upstream in the next week or two, so I'm hesitant to push this since we'll have to remove it again.
@@ -411,14 +444,35 @@ def get_model(pretrained_model_name_or_path: str) -> str:


 def get_tokenizer(
-    pretrained_model_name_or_path: str, trust_remote_code: bool
+    pretrained_model_name_or_path: str,
+    tokenizer_mode: str = "auto",
+    trust_remote_code: bool = False,
+    **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     if pretrained_model_name_or_path is not None and not os.path.exists(
             pretrained_model_name_or_path):
         pretrained_model_name_or_path = get_model(
             pretrained_model_name_or_path)
-    return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
-                                         trust_remote_code=trust_remote_code)
+    if tokenizer_mode == "slow":
+        if kwargs.get("use_fast", False):
+            raise ValueError(
+                "Cannot use the fast tokenizer in slow tokenizer mode.")
+        kwargs["use_fast"] = False
+    if tokenizer_mode == "mistral":
+        try:
+            from vllm.transformers_utils.tokenizer import MistralTokenizer
+        except ImportError as e:
+            raise ImportError("MistralTokenizer requires vllm package.\n"
+                              "Please install it with `pip install vllm` "
+                              "to use mistral tokenizer mode.") from e
+        return MistralTokenizer.from_pretrained(
+            str(pretrained_model_name_or_path))
+    else:
+        return AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )


 ASYNC_REQUEST_FUNCS = {
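A hedged usage sketch of the extended get_tokenizer signature (the model id is a placeholder): "slow" mode forces use_fast=False, "mistral" delegates to vllm's MistralTokenizer, and anything else falls through to AutoTokenizer with **kwargs passed along.

```python
# Assumes transformers is installed; the model id is a placeholder.
tokenizer = get_tokenizer(
    "meta-llama/Llama-3.1-8B-Instruct",
    tokenizer_mode="auto",        # "slow" would disable fast tokenizers
    trust_remote_code=False,
)
print(tokenizer.encode("The benchmark prompt"))
```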
Comment: Could you add a comment to this function specifying why it's needed (& the torch version it's supporting)? EDIT: Nevermind, see other comment about remove_prefix.