diff --git a/examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja b/examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja
new file mode 100644
index 00000000000..07a118c6e76
--- /dev/null
+++ b/examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja
@@ -0,0 +1,132 @@
+{{- bos_token }}
+{%- if custom_tools is defined %}
+    {%- set tools = custom_tools %}
+{%- endif %}
+{%- if not tools_in_user_message is defined %}
+    {%- set tools_in_user_message = false %}
+{%- endif %}
+{%- if not date_string is defined %}
+    {%- set date_string = "26 Jul 2024" %}
+{%- endif %}
+
+{#- This block extracts the system message, so we can slot it into the right place. #}
+{%- if messages[0]['role'] == 'system' %}
+    {%- set system_message = messages[0]['content']|trim %}
+    {%- set messages = messages[1:] %}
+{%- else %}
+    {%- set system_message = "" %}
+{%- endif %}
+
+{#- System message + builtin tools #}
+{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
+{%- if builtin_tools is defined or tools is not none %}
+    {{- "Environment: ipython\n" }}
+{%- endif %}
+{%- if builtin_tools is defined %}
+    {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
+{%- endif %}
+{{- "Cutting Knowledge Date: December 2023\n" }}
+{{- "Today Date: " + date_string + "\n\n" }}
+
+{%- if builtin_tools is defined %}
+    {{- "# Tool Instructions\n"}}
+    {{- "- Always execute python code in messages that you share.\n"}}
+    {{- "- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n"}}
+{%- endif %}
+
+{%- if tools is not none and not tools_in_user_message %}
+    {{- "You have access to the following functions:\n\n"}}
+
+    {%- for t in tools %}
+        {%- if t.function is defined %}
+            {%- set t = t.function %}
+        {%- endif -%}
+        {{- "Use the function '"+t.name+"' to: "+t.description+"\n"}}
+        {{- t | tojson(indent=4) }}
+        {{- "\n\n" }}
+    {%- endfor %}
+    {{- "If you choose to call a function, ONLY reply in the following format:\n"}}
+    {{- "<{start_tag}={function_name}>{parameters}{end_tag}\n" }}
+    {{- "where\n\n"}}
+    {{- "start_tag => `<function`\n"}}
+    {{- "parameters => a JSON dict with the function argument name as key and function argument value as value.\n"}}
+    {{- "end_tag => `</function>`" }}
+    {{- "\n\n" }}
+    {{- "Here is an example,\n"}}
+    {{- "<function=example_function_name>{\"example_name\": \"example_value\"}</function>"}}
+    {{- "\n\n" }}
+    {{- "Reminder:\n"}}
+    {{- "- Function calls MUST follow the specified format\n"}}
+    {{- "- Required parameters MUST be specified\n"}}
+    {{- "- Only call one function at a time\n"}}
+    {{- "- Put the entire function call reply on one line\n"}}
+    {{- "- Always use the information returned by the function to answer the user\n"}}
+    {{- "- If there is no relevant function available, do NOT call any function: respond directly to the user\n\n"}}
+
+{%- endif %}
+{{- system_message }}
+{{- "<|eot_id|>" }}
+
+{#- Custom tools are passed in a user message with some extra guidance #}
+{%- if tools_in_user_message and not tools is none %}
+    {#- Extract the first user message so we can plug it in here #}
+    {%- if messages | length != 0 %}
+        {%- set first_user_message = messages[0]['content']|trim %}
+        {%- set messages = messages[1:] %}
+    {%- else %}
+        {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
+    {%- endif %}
+    {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
+    {{- "Given the following functions, please respond with a JSON for a function call " }}
+    {{- "with its proper arguments that best answers the given prompt.\n\n" }}
+    {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
+    {{- "Do not use variables.\n\n" }}
+    {%- for t in tools %}
+        {{- t | tojson }}
+        {{- "\n\n" }}
+    {%- endfor %}
+    {{- first_user_message + "<|eot_id|>"}}
+{%- endif %}
+
+{%- for message in messages %}
+    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
+        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}
+    {%- elif 'tool_calls' in message %}
+        {%- if not message.tool_calls|length == 1 %}
+            {{- raise_exception("This model only supports single tool-calls at once!") }}
+        {%- endif %}
+        {%- set tool_call = message.tool_calls[0].function %}
+        {%- if builtin_tools is defined and tool_call.name in builtin_tools %}
+            {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
+            {{- "<|python_tag|>" + tool_call.name + ".call(" }}
+            {%- for arg_name, arg_val in tool_call.arguments | items %}
+                {{- arg_name + '="' + arg_val + '"' }}
+                {%- if not loop.last %}
+                    {{- ", " }}
+                {%- endif %}
+            {%- endfor %}
+            {{- ")" }}
+        {%- else %}
+            {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
+            {{- '<function=' + tool_call.name + '>' + tool_call.arguments + '</function>'}}
+        {%- endif %}
+        {%- if builtin_tools is defined or tools is not none %}
+            {#- This means we're in ipython mode #}
+            {{- "<|eom_id|>" }}
+        {%- else %}
+            {{- "<|eot_id|>" }}
+        {%- endif %}
+    {%- elif message.role == "tool" or message.role == "ipython" %}
+        {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
+        {%- if message.content is mapping or message.content is iterable %}
+            {{- message.content | tojson }}
+        {%- else %}
+            {{- message.content }}
+        {%- endif %}
+        {{- "<|eot_id|>" }}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
+{%- endif %}
\ No newline at end of file
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index b81dc4e7ad7..1d51fdcc4f1 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -7,6 +7,7 @@
 from .internlm2_tool_parser import Internlm2ToolParser
 from .jamba_tool_parser import JambaToolParser
 from .llama_tool_parser import Llama3JsonToolParser
+from .llama_usr_defined_tool_parser import Llama3UserDefinedCustomToolParser
 from .mistral_tool_parser import MistralToolParser
 from .phi4mini_tool_parser import Phi4MiniJsonToolParser
 from .pythonic_tool_parser import PythonicToolParser
@@ -15,5 +16,6 @@
     "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
     "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
     "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
-    "PythonicToolParser", "Phi4MiniJsonToolParser"
+    "PythonicToolParser", "Llama3UserDefinedCustomToolParser",
+    "Phi4MiniJsonToolParser"
 ]
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py
new file mode 100644
index 00000000000..677c3a67c53
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py
@@ -0,0 +1,268 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import re
+from collections.abc import Sequence
+from typing import Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+def _count_substring(string, substring):
+    """
+    Counts the number of non-overlapping occurrences of a substring in
+    a string.
+
+    Args:
+        string (str): The string to search in.
+        substring (str): The substring to search for.
+
+    Returns:
+        int: The number of non-overlapping occurrences of the substring in
+        the string.
+    """
+    count = 0
+    start = 0
+    while True:
+        start = string.find(substring, start)
+        if start == -1:
+            break
+        count += 1
+        start += len(substring)
+    return count
+
+
+@ToolParserManager.register_module("llama3_user_defined_custom")
+class Llama3UserDefinedCustomToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        if isinstance(self.model_tokenizer, MistralTokenizer):
+            logger.error("Detected Mistral tokenizer when using a Llama model")
+            self.model_tokenizer = self.model_tokenizer.tokenizer
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor.")
+
+        self.prev_tool_call_arr: list[dict] = []
+        self.streamed_args_for_tool: list[str] = []
+        self.is_parsing_toolcall = False
+
+        self.nb_tool_calls = 0
+        self.current_tool_name = ""
+        self.current_tool_call_uuid = ""
+        self.is_current_tool_name_sent = False
+
+        # Textual markers of a user-defined tool call, as produced by the
+        # accompanying chat template: <function=NAME>{ARGS}</function>
+        self.tool_call_start_token: str = "<function="
+        self.tool_call_preargs_token: str = ">"
+        self.tool_call_end_token: str = "</function>"
+
+        # Token ids of the markers, used to spot them in streamed deltas.
+        self.tool_call_start_token_id = self.model_tokenizer.encode(
+            self.tool_call_start_token, add_special_tokens=False)
+        self.tool_call_preargs_token_id = self.model_tokenizer.encode(
+            self.tool_call_preargs_token, add_special_tokens=False)
+        self.tool_call_end_token_id = self.model_tokenizer.encode(
+            self.tool_call_end_token, add_special_tokens=False)
+
+        # Captures the function name and the body of the JSON argument
+        # dict; the closing tag may be absent at end-of-string.
+        self.tool_call_regex = re.compile(
+            r"<function=([^>]+)>\{([^}]+)\}(?:</function>|>)?")
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+
+        # sanity check; avoid unnecessary processing
+        if self.tool_call_start_token not in model_output:
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
+        else:
+            try:
+                # Each match yields a (name, arguments) tuple: the first
+                # capture is the function name, the second is the body of
+                # the JSON argument dict (the closing tag may be missing
+                # when the call ends the output).
+                function_call_tuples = self.tool_call_regex.findall(
+                    model_output)
+
+                logger.debug("function_call_tuples: %s", function_call_tuples)
+
+                # load the JSON, and then use it to build the Function and
+                # Tool Call
+                raw_function_calls = [{
+                    "name":
+                    match[0],
+                    "arguments":
+                    json.loads("{" + match[1] + "}")
+                } for match in function_call_tuples]
+                tool_calls = [
+                    ToolCall(
+                        type="function",
+                        function=FunctionCall(
+                            name=function_call["name"],
+                            # function call args are JSON but as a string
+                            arguments=json.dumps(function_call["arguments"],
+                                                 ensure_ascii=False)))
+                    for function_call in raw_function_calls
+                ]
+
+                content = model_output[:model_output.
+                                       find(self.tool_call_start_token)]
+                return ExtractedToolCallInformation(
+                    tools_called=True,
+                    tool_calls=tool_calls,
+                    content=content if content else None)
+
+            except Exception:
+                logger.exception(
+                    "Error in extracting tool call from response.")
+                return ExtractedToolCallInformation(tools_called=False,
+                                                    tool_calls=[],
+                                                    content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        """
+        Extract tool calls from a streaming response.
+        Handles the format: <function=NAME>{ARGS}</function>
+        Returns a DeltaMessage with either tool_calls or content.
+        """
+        logger.debug("=" * 50)
+        logger.debug("STREAMING FUNCTION CALLED")
+        logger.debug("Tool call start token IDs: %s",
+                     self.tool_call_start_token_id)
+        logger.debug("Tool call preargs token IDs: %s",
+                     self.tool_call_preargs_token_id)
+        logger.debug("Tool call end token IDs: %s",
+                     self.tool_call_end_token_id)
+        logger.debug("Previous text: %s", previous_text)
+        logger.debug("Current text: %s", current_text)
+        logger.debug("Delta text: %s", delta_text)
+        logger.debug("Previous token IDs: %s", previous_token_ids)
+        logger.debug("Current token IDs: %s", current_token_ids)
+        logger.debug("Delta token IDs: %s", delta_token_ids)
+        logger.debug("Current tool name sent: %s",
+                     self.is_current_tool_name_sent)
+        logger.debug("-" * 50)
+
+        # Disallow partial strings until the tool name has been sent, so
+        # that half-streamed argument values are not surfaced early.
+        flags = Allow.ALL if self.is_current_tool_name_sent \
+            else Allow.ALL & ~Allow.STR
+
+        if not delta_token_ids:
+            # Nothing new to process in this delta.
+            return None
+
+        logger.debug("Delta opens a tool call: %s",
+                     delta_token_ids[0] in self.tool_call_start_token_id)
+        if delta_token_ids[0] in self.tool_call_start_token_id:
+            # We may have a tool call (not sure yet), so we don't stream
+
+            logger.debug(
+                "Start token occurrences: %s",
+                _count_substring(current_text, self.tool_call_start_token))
+            if _count_substring(
+                    current_text, self.tool_call_start_token) > self.nb_tool_calls \
+                    and not self.is_parsing_toolcall:
+
+                self.is_parsing_toolcall = True
+                self.nb_tool_calls += 1  # will serve as the tool call index
+                self.current_tool_call_uuid = random_uuid()
+                logger.debug("New tool call detected, id: %d",
+                             self.nb_tool_calls - 1)
+                return None  # going to the next iteration
+            else:
+                logger.debug("Tool call already parsed, id: %d",
+                             self.nb_tool_calls - 1)
+
+        if self.is_parsing_toolcall and not self.is_current_tool_name_sent:
+            logger.debug("Parsing tool call, id: %d", self.nb_tool_calls - 1)
+            # We are parsing a tool call, we need to parse the tool name
+            if delta_token_ids != self.tool_call_preargs_token_id:
+                self.current_tool_name += delta_text
+                logger.debug("current_tool_name: %s", self.current_tool_name)
+                return None  # moving on to the next iteration
+            else:
+                self.current_tool_name = self.current_tool_name.lstrip('=')
+                self.is_current_tool_name_sent = True
+                return DeltaMessage(tool_calls=[
+                    DeltaToolCall(
+                        index=self.nb_tool_calls - 1,
+                        type="function",
+                        id=f"chatcmpl-tool-{self.current_tool_call_uuid}",
+                        function=DeltaFunctionCall(
+                            name=self.current_tool_name))
+                ])
+
+        if self.is_current_tool_name_sent:
+            logger.debug("Parsed tool name: %s", self.current_tool_name)
+
+            if _count_substring(current_text,
+                                self.tool_call_end_token) < self.nb_tool_calls:
+                self.streamed_args_for_tool.append(delta_text)
+                return None  # moving on to the next iteration
+            else:
+                # adding back {" at the beginning for valid JSON
+                arguments = '{"' + ''.join(self.streamed_args_for_tool)
+                # removing a trailing end token, if present
+                arguments = arguments.removesuffix(self.tool_call_end_token)
+                logger.debug("Concatenated tool call arguments: %s",
+                             arguments)
+
+                current_tool_args = partial_json_parser.loads(
+                    arguments or "{}",
+                    flags) if self.streamed_args_for_tool else None
+
+                logger.debug("Parsed tool call arguments: %s",
+                             current_tool_args)
+
+                delta = DeltaMessage(tool_calls=[
+                    DeltaToolCall(
+                        index=self.nb_tool_calls - 1,
+                        type="function",
+                        id=f"chatcmpl-tool-{self.current_tool_call_uuid}",
+                        function=DeltaFunctionCall(name=self.current_tool_name,
+                                                   arguments=json.dumps(
+                                                       current_tool_args)))
+                ])
+
+                self.reset_state()
+
+                return delta
+        else:
+            logger.debug("No tool call detected, returning delta text: %s",
+                         delta_text)
+            return DeltaMessage(content=delta_text)
+
+    def reset_state(self):
+        self.current_tool_name = ''
+        self.is_parsing_toolcall = False
+        self.is_current_tool_name_sent = False
+        self.streamed_args_for_tool = []
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 4e74c20d366..6c92bd4073d 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -154,11 +154,25 @@ def get_computed_blocks(
             # we shouldn't modify it directly.
             block_hashes.append(last_block_hash)
 
-        # NOTE(woosuk): Since incomplete blocks are not eligible for
-        # sharing, `num_computed_tokens` is always a multiple of
-        # `block_size`.
-        num_computed_tokens = len(computed_blocks) * self.block_size
-        return computed_blocks, num_computed_tokens
+        if request.sampling_params.prompt_logprobs is None:
+            computed_blocks = (
+                self.specialized_manager.find_longest_cache_hit(block_hashes))
+
+            if last_block_hash is not None:
+                # Add back the last block hash if it was removed.
+                block_hashes.append(last_block_hash)
+
+            self.prefix_cache_stats.queries += len(block_hashes)
+            self.prefix_cache_stats.hits += len(computed_blocks)
+
+            # NOTE(woosuk): Since incomplete blocks are not eligible for
+            # sharing, `num_computed_tokens` is always a multiple of
+            # `block_size`.
+            num_computed_tokens = len(computed_blocks) * self.block_size
+            return computed_blocks, num_computed_tokens
+        else:
+            # Skip cache hits for prompt logprobs
+            return [], 0
 
     def allocate_slots(
         self,
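
---

Reviewer note (not part of the patch): a minimal sketch of how the parser's extraction regex handles the `<function=NAME>{ARGS}</function>` format emitted by the new chat template, assuming the pattern reconstructed above is the intended one. The tool name `get_weather` and its arguments are hypothetical.

```python
import json
import re

# Mirrors Llama3UserDefinedCustomToolParser.tool_call_regex: capture the
# function name, then the body of the JSON argument dict; the closing tag
# is optional so a call that ends the output still matches.
tool_call_regex = re.compile(r"<function=([^>]+)>\{([^}]+)\}(?:</function>|>)?")

model_output = ('Let me check that for you. '
                '<function=get_weather>{"city": "Paris"}</function>')

for name, args_body in tool_call_regex.findall(model_output):
    # The braces are stripped by the capture group, so add them back
    # before parsing, exactly as extract_tool_calls does.
    arguments = json.loads("{" + args_body + "}")
    print(name, arguments)  # get_weather {'city': 'Paris'}
```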
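The streaming path relies on `partial_json_parser` being tolerant of half-received argument dicts; the `flags` mask computed at the top of `extract_tool_calls_streaming` controls which partial values may surface. A small sketch of that behavior, with a made-up fragment:

```python
import partial_json_parser
from partial_json_parser.core.options import Allow

fragment = '{"city": "Par'  # hypothetical half-streamed argument dict

# With partial strings masked out (the pre-name-sent state), the
# incomplete value should be withheld rather than emitted early.
print(partial_json_parser.loads(fragment, Allow.ALL & ~Allow.STR))  # {}

# Once the tool name has been sent, Allow.ALL surfaces the partial value.
print(partial_json_parser.loads(fragment, Allow.ALL))  # {'city': 'Par'}
```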
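End to end, the parser is selected by the name registered via `ToolParserManager.register_module`. A sketch of exercising it through the OpenAI-compatible server; the model checkpoint and tool schema are placeholders, any Llama 3.1 model that follows the template should work:

```python
# Launch the server first (shell), e.g.:
#   vllm serve meta-llama/Llama-3.1-8B-Instruct \
#     --enable-auto-tool-choice \
#     --tool-call-parser llama3_user_defined_custom \
#     --chat-template examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
)
# If the model emitted <function=get_weather>{...}</function>, the parser
# converts it into a structured tool call on the response object.
print(response.choices[0].message.tool_calls)
```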