diff --git a/examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja b/examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja
new file mode 100644
index 00000000000..07a118c6e76
--- /dev/null
+++ b/examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja
@@ -0,0 +1,131 @@
+{{- bos_token }}
+{%- if custom_tools is defined %}
+ {%- set tools = custom_tools %}
+{%- endif %}
+{%- if not tools_in_user_message is defined %}
+ {%- set tools_in_user_message = false %}
+{%- endif %}
+{%- if not date_string is defined %}
+ {%- set date_string = "26 Jul 2024" %}
+{%- endif %}
+
+{#- This block extracts the system message, so we can slot it into the right place. #}
+{%- if messages[0]['role'] == 'system' %}
+ {%- set system_message = messages[0]['content']|trim %}
+ {%- set messages = messages[1:] %}
+{%- else %}
+ {%- set system_message = "" %}
+{%- endif %}
+
+{#- System message + builtin tools #}
+{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
+{%- if builtin_tools is defined or tools is not none %}
+ {{- "Environment: ipython\n" }}
+{%- endif %}
+{%- if builtin_tools is defined %}
+ {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
+{%- endif %}
+{{- "Cutting Knowledge Date: December 2023\n" }}
+{{- "Today Date: " + date_string + "\n\n" }}
+
+{%- if builtin_tools is defined %}
+ {{- "# Tool Instructions\n"}}
+ {{- "- Always execute python code in messages that you share.\n"}}
+ {{- "- When looking for real time information use relevant functions if available else fallback to brave_search\n\n\n"}}
+{%- endif %}
+
+{%- if tools is not none and not tools_in_user_message %}
+ {{- "You have access to the following functions:\n\n"}}
+
+ {%- for t in tools %}
+ {%- if t.function is defined %}
+ {%- set t = t.function %}
+ {%- endif -%}
+ {{- "Use the function '"+t.name+"' to: "+t.description+"\n"}}
+ {{- t | tojson(indent=4) }}
+ {{- "\n\n" }}
+ {%- endfor %}
+ {{- "If a you choose to call a function ONLY reply in the following format:\n"}}
+ {{- "<{start_tag}={function_name}>{parameters}{end_tag}\n" }}
+ {{- "where\n\n"}}
+ {{- "start_tag => ` a JSON dict with the function argument name as key and function argument value as value.\n"}}
+ {{- "end_tag => ``" }}
+ {{- "\n\n" }}
+ {{- "Here is an example,\n"}}
+ {{- "{\"example_name\": \"example_value\"}"}}
+ {{- "\n\n" }}
+ {{- "Reminder:\n"}}
+ {{- "- Function calls MUST follow the specified format\n"}}
+ {{- "- Required parameters MUST be specified\n"}}
+ {{- "- Only call one function at a time\n"}}
+ {{- "- Put the entire function call reply on one line\n"}}
+ {{- "- Always use the information returned by the function to answer to the user\n"}}
+ {{- "- If there is no relevant function available, do NOT call any function: respond directly to the user\n\n"}}
+
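+    {#- Illustrative only: for a hypothetical get_weather tool, a well-formed call would be <function=get_weather>{"city": "Paris"}</function> #}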
+{%- endif %}
+{{- system_message }}
+{{- "<|eot_id|>" }}
+
+{#- Custom tools are passed in a user message with some extra guidance #}
+{%- if tools_in_user_message and not tools is none %}
+ {#- Extract the first user message so we can plug it in here #}
+ {%- if messages | length != 0 %}
+ {%- set first_user_message = messages[0]['content']|trim %}
+ {%- set messages = messages[1:] %}
+ {%- else %}
+ {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
+    {%- endif %}
+ {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
+ {{- "Given the following functions, please respond with a JSON for a function call " }}
+ {{- "with its proper arguments that best answers the given prompt.\n\n" }}
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
+ {{- "Do not use variables.\n\n" }}
+ {%- for t in tools %}
+ {{- t | tojson }}
+ {{- "\n\n" }}
+ {%- endfor %}
+ {{- first_user_message + "<|eot_id|>"}}
+{%- endif %}
+
+{%- for message in messages %}
+ {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
+ {%- elif 'tool_calls' in message %}
+ {%- if not message.tool_calls|length == 1 %}
+ {{- raise_exception("This model only supports single tool-calls at once!") }}
+ {%- endif %}
+ {%- set tool_call = message.tool_calls[0].function %}
+ {%- if builtin_tools is defined and tool_call.name in builtin_tools %}
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
+ {{- "<|python_tag|>" + tool_call.name + ".call(" }}
+ {%- for arg_name, arg_val in tool_call.arguments | items %}
+ {{- arg_name + '="' + arg_val + '"' }}
+ {%- if not loop.last %}
+ {{- ", " }}
+ {%- endif %}
+ {%- endfor %}
+ {{- ")" }}
+ {%- else %}
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
+            {{- '<function=' + tool_call.name + '>' + tool_call.arguments + '</function>' }}
+ {%- endif %}
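+        {#- e.g. a builtin call above renders as: <|python_tag|>brave_search.call(query="latest news") #}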
+        {%- if builtin_tools is defined or tools is not none %}
+ {#- This means we're in ipython mode #}
+ {{- "<|eom_id|>" }}
+ {%- else %}
+ {{- "<|eot_id|>" }}
+ {%- endif %}
+ {%- elif message.role == "tool" or message.role == "ipython" %}
+ {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
+ {%- if message.content is mapping or message.content is iterable %}
+ {{- message.content | tojson }}
+ {%- else %}
+ {{- message.content }}
+ {%- endif %}
+ {{- "<|eot_id|>" }}
+ {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
+{%- endif %}
\ No newline at end of file
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index b81dc4e7ad7..1d51fdcc4f1 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -7,6 +7,7 @@
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
+from .llama_usr_defined_tool_parser import Llama3UserDefinedCustomToolParser
from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
@@ -15,5 +16,6 @@
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
- "PythonicToolParser", "Phi4MiniJsonToolParser"
+ "PythonicToolParser", "Llama3UserDefinedCustomToolParser",
+ "Phi4MiniJsonToolParser"
]
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py
new file mode 100644
index 00000000000..677c3a67c53
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/llama_usr_defined_tool_parser.py
@@ -0,0 +1,268 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import re
+from collections.abc import Sequence
+from typing import Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+def _count_substring(string, substring):
+ """
+ Counts the number of non-overlapping occurrences of a substring in
+ a string.
+
+ Args:
+ string (str): The string to search in.
+ substring (str): The substring to search for.
+
+ Returns:
+ int: The number of non-overlapping occurrences of the substring in
+ the string.
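+
+    Example:
+        >>> _count_substring("abcabcab", "abc")
+        2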
+ """
+ count = 0
+ start = 0
+ while True:
+ start = string.find(substring, start)
+ if start == -1:
+ break
+ count += 1
+ start += len(substring)
+ return count
+
+
+@ToolParserManager.register_module("llama3_user_defined_custom")
+class Llama3UserDefinedCustomToolParser(ToolParser):
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+
+ if isinstance(self.model_tokenizer, MistralTokenizer):
+ logger.error("Detected Mistral tokenizer when using a Llama model")
+ self.model_tokenizer = self.model_tokenizer.tokenizer
+
+ self.prev_tool_call_arr: list[dict] = []
+ self.streamed_args_for_tool: list[str] = []
+ self.is_parsing_toolcall = False
+
+ self.nb_tool_calls = 0
+ self.current_tool_name = ""
+ self.current_tool_call_uuid = ""
+ self.is_current_tool_name_sent = False
+        # Markers follow the template's <function=name>{args}</function>
+        # tool-call format.
+        self.tool_call_start_token: str = "<function="
+        self.tool_call_end_token: str = "</function>"
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor during construction.")
+
+        # Token-id sequences for the markers, used while streaming to spot
+        # the start of a call, the name/arguments boundary ('>{"') and the
+        # end of a call.
+        self.tool_call_start_token_id = self.model_tokenizer.encode(
+            "<function", add_special_tokens=False)
+        self.tool_call_preargs_token_id = self.model_tokenizer.encode(
+            '>{"', add_special_tokens=False)
+        self.tool_call_end_token_id = self.model_tokenizer.encode(
+            self.tool_call_end_token, add_special_tokens=False)
+
+        self.tool_call_regex = re.compile(
+            r"<function=([^>]+)>\{([^}]+)\}(?:</function>|>)?")
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+
+ # sanity check; avoid unnecessary processing
+ if self.tool_call_start_token not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ else:
+ try:
+                # findall returns one (function_name, arguments_body) tuple
+                # per tool-call block found in the model output
+                function_call_tuples = self.tool_call_regex.findall(
+                    model_output)
+
+                logger.debug("function_call_tuples: %s",
+                             function_call_tuples)
+
+ # load the JSON, and then use it to build the Function and
+ # Tool Call
+ raw_function_calls = [{
+ "name":
+ match[0],
+ "arguments":
+ json.loads("{" + match[1] + "}")
+ } for match in function_call_tuples]
+ tool_calls = [
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=function_call["name"],
+ # function call args are JSON but as a string
+ arguments=json.dumps(function_call["arguments"],
+ ensure_ascii=False)))
+ for function_call in raw_function_calls
+ ]
+
+ content = model_output[:model_output.
+ find(self.tool_call_start_token)]
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=content if content else None)
+
+ except Exception:
+ logger.exception(
+ "Error in extracting tool call from response.")
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
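+    # Illustrative (assuming the template above and a hypothetical
+    # get_weather tool):
+    #   model_output = 'Sure!<function=get_weather>{"city": "Paris"}</function>'
+    #   -> tools_called=True, content='Sure!', and a single ToolCall with
+    #      name='get_weather' and arguments='{"city": "Paris"}'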
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+ """
+ Extract tool calls from a streaming response.
+ Handles format:
+ Returns DeltaMessage with either tool_calls or content.
+ """
+ logger.debug("\n", "=" * 50)
+ logger.debug("STREAMING FUNCTION CALLED")
+ logger.debug("Tool call start token id IDs:",
+ self.tool_call_start_token_id)
+ logger.debug("Tool call precall token id IDs:",
+ self.tool_call_preargs_token_id)
+ logger.debug("Tool call end token id IDs:",
+ self.tool_call_end_token_id)
+ logger.debug("Previous text:", previous_text)
+ logger.debug("Current text:", current_text)
+ logger.debug("Delta text:", delta_text)
+ logger.debug("Previous token IDs:", previous_token_ids)
+ logger.debug("Current token IDs:", current_token_ids)
+ logger.debug("Delta token IDs:", delta_token_ids)
+ logger.debug("Current tool name sent:", self.is_current_tool_name_sent)
+ logger.debug("-" * 50)
+ logger.debug("\n")
+ flags = Allow.ALL if self.is_current_tool_name_sent \
+ else Allow.ALL & ~Allow.STR
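+        # NOTE: masking out Allow.STR above keeps partial_json_parser from
+        # returning half-streamed string values before the tool name is sent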
+
+ logger.debug("%s=", delta_token_ids[0]
+ in self.tool_call_start_token_id)
+ if delta_token_ids[0] in self.tool_call_start_token_id:
+            # We may have a tool call (not sure yet), so don't stream this yet
+
+            logger.debug("Start token occurrences so far: %s",
+                         _count_substring(current_text,
+                                          self.tool_call_start_token))
+            if _count_substring(current_text, self.tool_call_start_token) \
+                    > self.nb_tool_calls and not self.is_parsing_toolcall:
+
+                self.is_parsing_toolcall = True
+                self.nb_tool_calls += 1  # serves as the tool-call index
+                self.current_tool_call_uuid = random_uuid()
+                logger.debug("New tool call detected, id: %s",
+                             self.nb_tool_calls - 1)
+                return None  # going to the next iter
+            else:
+                logger.debug("Tool call already parsed, id: %s",
+                             self.nb_tool_calls - 1)
+
+ if self.is_parsing_toolcall and not self.is_current_tool_name_sent:
+ logger.debug("Parsing tool call, id:", self.nb_tool_calls - 1)
+ # We are parsing a tool call, we need to parse the tool name
+ if delta_token_ids != self.tool_call_preargs_token_id:
+ self.current_tool_name += delta_text
+ logger.debug("self.current_tool_name=", self.current_tool_name)
+ return None # moving on to the next iteration
+ else:
+ self.current_tool_name = self.current_tool_name.lstrip('=')
+ self.is_current_tool_name_sent = True
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.nb_tool_calls - 1,
+ type="function",
+ id=f"chatcmpl-tool-{self.current_tool_call_uuid}",
+ function=DeltaFunctionCall(
+ name=self.current_tool_name))
+ ])
+
+ if self.is_current_tool_name_sent:
+ logger.debug("Parsed tool name : ", self.current_tool_name)
+
+ if _count_substring(current_text,
+ self.tool_call_end_token) < self.nb_tool_calls:
+ self.streamed_args_for_tool.append(delta_text)
+ return None # moving on to the next iteration
+ else:
+ # adding back {" at the beginning for valid JSON
+ arguments = '{"' + ''.join(self.streamed_args_for_tool)
+                # removing the end token; NOTE: str.rstrip treats it as a
+                # character set, which also trims a partially streamed
+                # '</function>' tail
+                arguments = arguments.rstrip(self.tool_call_end_token)
+                logger.debug("Concatenated tool call arguments: %s",
+                             arguments)
+
+ current_tool_args = partial_json_parser.loads(
+ arguments or "{}",
+ flags) if self.streamed_args_for_tool else None
+
+ logger.debug("Parsed tool call arguments : ",
+ current_tool_args)
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.nb_tool_calls - 1,
+ type="function",
+ id=f"chatcmpl-tool-{self.current_tool_call_uuid}",
+ function=DeltaFunctionCall(name=self.current_tool_name,
+ arguments=json.dumps(
+ current_tool_args)))
+ ])
+
+ self.reset_state()
+
+ return delta
+ else:
+ logger.debug("No tool call detected, returning just text : ",
+ delta_text)
+ return DeltaMessage(content=delta_text)
+
+ def reset_state(self):
+ self.current_tool_name = ''
+ self.is_parsing_toolcall = False
+ self.is_current_tool_name_sent = False
+ self.streamed_args_for_tool = []
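+
+
+# Illustrative launch wiring up this parser (adjust the model and template
+# path to your setup):
+#   vllm serve meta-llama/Llama-3.1-8B-Instruct \
+#     --enable-auto-tool-choice \
+#     --tool-call-parser llama3_user_defined_custom \
+#     --chat-template examples/tool_chat_template_llama3.1_usr_def_tool_call.jinja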
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 4e74c20d366..6c92bd4073d 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -154,11 +154,24 @@ def get_computed_blocks(
# we shouldn't modify it directly.
block_hashes.append(last_block_hash)
- # NOTE(woosuk): Since incomplete blocks are not eligible for
- # sharing, `num_computed_tokens` is always a multiple of
- # `block_size`.
- num_computed_tokens = len(computed_blocks) * self.block_size
- return computed_blocks, num_computed_tokens
+        computed_blocks = (
+            self.specialized_manager.find_longest_cache_hit(block_hashes))
+
+        if last_block_hash is not None:
+            # Add back the last block hash if it was removed.
+            block_hashes.append(last_block_hash)
+
+        self.prefix_cache_stats.queries += len(block_hashes)
+        self.prefix_cache_stats.hits += len(computed_blocks)
+
+        if request.sampling_params.prompt_logprobs is None:
+            # NOTE(woosuk): Since incomplete blocks are not eligible for
+            # sharing, `num_computed_tokens` is always a multiple of
+            # `block_size`.
+            num_computed_tokens = len(computed_blocks) * self.block_size
+            return computed_blocks, num_computed_tokens
+        else:
+            # Skip cache hits for prompt logprobs
+            return [], 0
def allocate_slots(
self,