Feature: add LlamaCppChatCompletionClient and llama-cpp #5326

Open · wants to merge 3 commits into main
5 changes: 5 additions & 0 deletions python/packages/autogen-ext/pyproject.toml
@@ -31,6 +31,11 @@ file-surfer = [
"autogen-agentchat==0.4.5",
"markitdown>=0.0.1a2",
]

llama-cpp = [
"llama-cpp-python>=0.1.9"
]

graphrag = ["graphrag>=1.0.1"]
web-surfer = [
"autogen-agentchat==0.4.5",
@@ -0,0 +1,8 @@
try:
from ._llama_cpp_completion_client import LlamaCppChatCompletionClient
except ImportError as e:
raise ImportError(
"Dependencies for Llama Cpp not found. " "Please install llama-cpp-python: " "pip install autogen-ext[llama-cpp]"
) from e

__all__ = ["LlamaCppChatCompletionClient"]
Collaborator:
Add unit tests in the python/packages/autogen-ext/tests directory

Author:
Will work on this tomorrow
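
A minimal sketch of what such a test might look like, assuming pytest-asyncio, a small local GGUF model, and the import path autogen_ext.models.llama_cpp (the final package layout is not confirmed by this diff):

# Hypothetical test sketch (not part of this PR). Assumes pytest-asyncio,
# a small local GGUF model, and that the client is importable from
# autogen_ext.models.llama_cpp -- adjust to the final package layout.
import pytest

from autogen_core.models import UserMessage
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient

MODEL_PATH = "models/tiny-model.gguf"  # placeholder; supply a real local model


@pytest.mark.asyncio
async def test_create_returns_text() -> None:
    client = LlamaCppChatCompletionClient(filename=MODEL_PATH, verbose=False)
    result = await client.create([UserMessage(content="Say hello.", source="user")])
    assert isinstance(result.content, str)
    assert result.content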

@@ -0,0 +1,293 @@
import json
import logging
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, cast

from autogen_core import CancellationToken
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
FunctionExecutionResultMessage,
ModelInfo,
RequestUsage,
SystemMessage,
UserMessage,
)
from autogen_core.tools import Tool, ToolSchema
from llama_cpp import (
ChatCompletionRequestAssistantMessage,
ChatCompletionRequestFunctionMessage,
ChatCompletionRequestSystemMessage,
ChatCompletionRequestToolMessage,
ChatCompletionRequestUserMessage,
CreateChatCompletionResponse,
Llama,
)


class LlamaCppChatCompletionClient(ChatCompletionClient):
def __init__(
self,
filename: str,
verbose: bool = True,
**kwargs: Any,
Collaborator:
nit: There is a way to allow typing for **kwargs by using Unpack[] on a TypedDict. See example here:

https://github.com/microsoft/autogen/blob/main/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py#L1195-L1196
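
For reference, a minimal sketch of that pattern (not part of this PR); the TypedDict name and fields below are illustrative examples of llama_cpp.Llama constructor options:

# Illustrative sketch of typed **kwargs via Unpack[TypedDict].
# The field names are examples of llama_cpp.Llama constructor options.
from typing_extensions import TypedDict, Unpack

from llama_cpp import Llama


class LlamaCppParams(TypedDict, total=False):
    n_ctx: int  # context window size
    n_gpu_layers: int  # layers to offload to the GPU
    seed: int  # RNG seed


class TypedKwargsExample:
    def __init__(self, filename: str, verbose: bool = True, **kwargs: Unpack[LlamaCppParams]) -> None:
        # Static type checkers now validate the keyword names and value types.
        self.verbose = verbose
        self.llm = Llama(model_path=filename, **kwargs)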

Author:
@ekzhu I will fix this. I'm currently sick, so expect a little delay, but I hope to take care of this, finish the tests, and merge as soon as I recover.

Collaborator:
Take care!

):
"""
Initialize the LlamaCpp client.
"""
self.logger = logging.getLogger(__name__) # initialize logger
self.logger.setLevel(logging.DEBUG if verbose else logging.INFO) # set level based on verbosity
self.llm = Llama(model_path=filename, **kwargs)
self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}

async def create(
self,
messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
tools: Optional[Sequence[Tool | ToolSchema]] = None,
**kwargs: Any,
) -> CreateResult:
tools = tools or []

# Convert LLMMessage objects to dictionaries with 'role' and 'content'
# converted_messages: List[Dict[str, str | Image | list[str | Image] | list[FunctionCall]]] = []
converted_messages: list[
ChatCompletionRequestSystemMessage
| ChatCompletionRequestUserMessage
| ChatCompletionRequestAssistantMessage
| ChatCompletionRequestToolMessage
| ChatCompletionRequestFunctionMessage
] = []
for msg in messages:
if isinstance(msg, SystemMessage):
converted_messages.append({"role": "system", "content": msg.content})
elif isinstance(msg, UserMessage) and isinstance(msg.content, str):
converted_messages.append({"role": "user", "content": msg.content})
elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str):
converted_messages.append({"role": "assistant", "content": msg.content})
else:
raise ValueError(f"Unsupported message type: {type(msg)}")

# Add tool descriptions to the system message
tool_descriptions = "\n".join(
[f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools) if isinstance(tool, Tool)]
)

few_shot_example = """
Example tool usage:
User: Validate this request: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"}
Assistant: Calling tool 'validate_request' with arguments: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"}
"""

system_message = (
"You are an assistant with access to tools. "
"If a user query matches a tool, explicitly invoke it with JSON arguments. "
"Here are the tools available:\n"
f"{tool_descriptions}\n"
f"{few_shot_example}"
)
converted_messages.insert(0, {"role": "system", "content": system_message})

# Debugging outputs
# print(f"DEBUG: System message: {system_message}")
# print(f"DEBUG: Converted messages: {converted_messages}")

# Generate the model response
response = cast(
CreateChatCompletionResponse, self.llm.create_chat_completion(messages=converted_messages, stream=False)
)
self._total_usage["prompt_tokens"] += response.get("usage", {}).get("prompt_tokens", 0)
self._total_usage["completion_tokens"] += response.get("usage", {}).get("completion_tokens", 0)

# Parse the response
response_text = response["choices"][0]["message"]["content"]
# print(f"DEBUG: Model response: {response_text}")

# Detect tool usage in the response
if not response_text:
self.logger.debug("DEBUG: No response text found. Returning empty response.")
return CreateResult(
content="", usage=RequestUsage(prompt_tokens=0, completion_tokens=0), finish_reason="stop", cached=False
)

tool_call = await self._detect_and_execute_tool(
response_text, [tool for tool in tools if isinstance(tool, Tool)]
)
if not tool_call:
self.logger.debug("DEBUG: No tool was invoked. Returning raw model response.")
else:
self.logger.debug(f"DEBUG: Tool executed successfully: {tool_call}")

# Create a CreateResult object
finish_reason = response["choices"][0].get("finish_reason")
if finish_reason not in ("stop", "length", "function_calls", "content_filter", "unknown"):
finish_reason = "unknown"
usage_data = response.get("usage", {})
usage = RequestUsage(
    prompt_tokens=usage_data.get("prompt_tokens", 0),
    completion_tokens=usage_data.get("completion_tokens", 0),
)
create_result = CreateResult(
content=tool_call if tool_call else response_text,
usage=usage,
finish_reason=finish_reason, # type: ignore
cached=False,
)
return create_result

async def _detect_and_execute_tool(self, response_text: str, tools: List[Tool]) -> Optional[str]:
"""
Detect if the model is requesting a tool and execute the tool.

:param response_text: The raw response text from the model.
:param tools: A list of available tools.
:return: The result of the tool execution or None if no tool is called.
"""
for tool in tools:
if tool.name.lower() in response_text.lower(): # Case-insensitive matching
self.logger.debug(f"DEBUG: Detected tool '{tool.name}' in response.")
# Extract arguments (if any) from the response
func_args = self._extract_tool_arguments(response_text)
if func_args:
self.logger.debug(f"DEBUG: Extracted arguments for tool '{tool.name}': {func_args}")
else:
self.logger.debug(f"DEBUG: No arguments found for tool '{tool.name}'.")
return f"Error: No valid arguments provided for tool '{tool.name}'."

# Ensure arguments match the tool's args_type
try:
args_model = tool.args_type()
if "request" in args_model.model_fields: # Handle nested arguments
func_args = {"request": func_args}
args_instance = args_model(**func_args)
except Exception as e:
return f"Error parsing arguments for tool '{tool.name}': {e}"

# Execute the tool
try:
if callable(getattr(tool, "run", None)):
result = await cast(Any, tool).run(args=args_instance, cancellation_token=CancellationToken())
if isinstance(result, dict):
return json.dumps(result)
elif callable(getattr(result, "model_dump", None)): # If it's a Pydantic model
return json.dumps(result.model_dump())
else:
return str(result)
except Exception as e:
return f"Error executing tool '{tool.name}': {e}"

return None

def _extract_tool_arguments(self, response_text: str) -> Dict[str, Any]:
"""
Extract tool arguments from the response text.

:param response_text: The raw response text.
:return: A dictionary of extracted arguments.
"""
try:
args_start = response_text.find("{")
args_end = response_text.find("}")
if args_start != -1 and args_end != -1:
args_str = response_text[args_start : args_end + 1]
args = json.loads(args_str)
if isinstance(args, dict):
return cast(Dict[str, Any], args)
else:
return {}
except json.JSONDecodeError as e:
self.logger.debug(f"DEBUG: Failed to parse arguments: {e}")
return {}

async def create_stream(
self,
messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
tools: Optional[Sequence[Tool | ToolSchema]] = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
tools = tools or []

# Convert LLMMessage objects to dictionaries with 'role' and 'content'
converted_messages: list[
ChatCompletionRequestSystemMessage
| ChatCompletionRequestUserMessage
| ChatCompletionRequestAssistantMessage
| ChatCompletionRequestToolMessage
| ChatCompletionRequestFunctionMessage
] = []
for msg in messages:
if isinstance(msg, SystemMessage):
converted_messages.append({"role": "system", "content": msg.content})
elif isinstance(msg, UserMessage) and isinstance(msg.content, str):
converted_messages.append({"role": "user", "content": msg.content})
elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str):
converted_messages.append({"role": "assistant", "content": msg.content})
else:
raise ValueError(f"Unsupported message type: {type(msg)}")

# Add tool descriptions to the system message
tool_descriptions = "\n".join(
[f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools) if isinstance(tool, Tool)]
)

few_shot_example = """
Example tool usage:
User: Validate this request: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"}
Assistant: Calling tool 'validate_request' with arguments: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"}
"""

system_message = (
"You are an assistant with access to tools. "
"If a user query matches a tool, explicitly invoke it with JSON arguments. "
"Here are the tools available:\n"
f"{tool_descriptions}\n"
f"{few_shot_example}"
)
converted_messages.insert(0, {"role": "system", "content": system_message})
# Convert messages into a plain string prompt
prompt = "\n".join(f"{msg['role']}: {msg.get('content', '')}" for msg in converted_messages)
# Call the model with streaming enabled
response_generator = self.llm(prompt=prompt, stream=True)

for token in response_generator:
if isinstance(token, dict):
yield token["choices"][0]["text"]
else:
yield token

# Implement abstract methods
def actual_usage(self) -> RequestUsage:
return RequestUsage(
prompt_tokens=self._total_usage.get("prompt_tokens", 0),
completion_tokens=self._total_usage.get("completion_tokens", 0),
)

@property
def capabilities(self) -> ModelInfo:
return self.model_info
def count_tokens(
self,
messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
**kwargs: Any,
) -> int:
total = 0
for msg in messages:
# Use the Llama model's tokenizer to encode the content
tokens = self.llm.tokenize(str(msg.content).encode("utf-8"))
total += len(tokens)
return total

@property
def model_info(self) -> ModelInfo:
return ModelInfo(vision=False, json_output=False, family="llama-cpp", function_calling=True)

def remaining_tokens(
self,
messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
**kwargs: Any,
) -> int:
used_tokens = self.count_tokens(messages)
return max(self.llm.n_ctx() - used_tokens, 0)

def total_usage(self) -> RequestUsage:
return RequestUsage(
prompt_tokens=self._total_usage.get("prompt_tokens", 0),
completion_tokens=self._total_usage.get("completion_tokens", 0),
)
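
For context, a hedged usage sketch of the client added in this diff; the import path autogen_ext.models.llama_cpp and the model path are assumptions, not confirmed by the diff:

# Hedged usage sketch (not part of this PR). The import path
# autogen_ext.models.llama_cpp and the model path are assumptions.
import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


async def main() -> None:
    client = LlamaCppChatCompletionClient(filename="path/to/model.gguf", verbose=False)
    result = await client.create([UserMessage(content="What is 2 + 2?", source="user")])
    print(result.content)
    print(client.total_usage())  # cumulative prompt/completion token usage


asyncio.run(main())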