Skip to content

Commit

Permalink
Inital supertools (TM) Moving all of the tool related support functio…
Browse files Browse the repository at this point in the history
…ns to their own modules
  • Loading branch information
gittb committed Sep 5, 2024
1 parent b3caf7b commit f903054
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 52 deletions.
114 changes: 62 additions & 52 deletions endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import pathlib
from asyncio import CancelledError
from copy import deepcopy
from typing import List, Optional, Type
from typing import List, Optional
import json

from fastapi import HTTPException, Request
from jinja2 import TemplateError
from loguru import logger
from pydantic import BaseModel, create_model

from common import model
from common.networking import (
Expand All @@ -32,7 +31,11 @@
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _stream_collector
from endpoints.OAI.types.tools import ToolCall

from endpoints.OAI.utils.tools import(
postprocess_tool_call,
generate_strict_schemas
)


def _create_response(
Expand Down Expand Up @@ -434,9 +437,12 @@ async def generate_tool_calls(

# Copy to make sure the parent JSON schema doesn't get modified
# FIXME: May not be necessary depending on how the codebase evolves
create_tool_call_model(data)
if data.tools:
strict_schema = generate_strict_schemas(data)
print(strict_schema)
tool_data = deepcopy(data)
tool_data.json_schema = tool_data.tool_call_schema
#tool_data.json_schema = tool_data.tool_call_schema
tool_data.json_schema = strict_schema # needs strict flag
gen_params = tool_data.to_gen_params()

for idx, gen in enumerate(generations):
Expand Down Expand Up @@ -467,50 +473,54 @@ async def generate_tool_calls(

return generations

def create_tool_call_model(data: ChatCompletionRequest):
"""Create a tool call model to guide model based on the tools spec provided"""
dtypes = {
"integer": int,
"string": str,
"boolean": bool,
"object": dict,
"array": list
}

tool_response_models = []
for tool in data.tools:

name = tool.function.name
params = tool.function.parameters.get('properties', {})
required_params = tool.function.parameters.get('required', [])

model_fields = {}
if params:
for arg_key, arg_val in params.items():
arg_name = arg_key
arg_dtype = dtypes[arg_val['type']]
required = arg_name in required_params
model_fields["name"] = name # this need to be a string with a strict value of name
model_fields["arguments"] = {}

# Use Field to set whether the argument is required or not
if required:
model_fields["arguments"][arg_name] = (arg_dtype, ...)
else:
model_fields["arguments"][arg_name] = (arg_dtype, None)

# Create the Pydantic model for the tool
tool_response_model = create_model(name, **model_fields)
tool_response_models.append(tool_response_model)

print(tool_response_models) # these tool_response_model will go into the tool_call as a union of them, need to format correctly



def postprocess_tool_call(call_str: str) -> List[ToolCall]:
tool_calls = json.loads(call_str)
for tool_call in tool_calls:
tool_call["function"]["arguments"] = json.dumps(
tool_call["function"]["arguments"]
)
return [ToolCall(**tool_call) for tool_call in tool_calls]
# def create_tool_call_model(data: ChatCompletionRequest):
# """Create a tool call model to guide model based on the tools spec provided"""
# dtypes = {
# "integer": int,
# "string": str,
# "boolean": bool,
# "object": dict,
# "array": list
# }

# function_models = []
# for tool in data.tools:

# tool_name = tool.function.name
# raw_params = tool.function.parameters.get('properties', {})
# required_params = tool.function.parameters.get('required', [])

# fields = {}
# if raw_params:
# for arg_key, val_dict in raw_params.items():

# arg_name = arg_key
# arg_dtype = dtypes[val_dict['type']]
# required = arg_name in required_params
# fields[arg_name] = (arg_dtype, ... if required else None)
# if not required:
# arg_dtype = Optional[arg_dtype]

# fields[arg_name] = (arg_dtype, ... if required else None)

# arguments_model = create_model(f"{tool_name}Arguments", **fields)

# function_model = create_model(
# f"{tool_name}Model",
# name=(str, tool_name),
# arguments=(arguments_model, ...)
# )

# function_models.append(function_model)

# fucntion_union = Union[tuple(function_models)]

# tool_response_model = create_model(
# "tools_call_response_model",
# id=(str, ...),
# function=(fucntion_union, ...)
# )

# tool_response_model.model_rebuild()

# return tool_response_model
150 changes: 150 additions & 0 deletions endpoints/OAI/utils/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Support functions to enable tool calling"""

from typing import List, Dict
from copy import deepcopy
import json

from endpoints.OAI.types.tools import ToolCall
from endpoints.OAI.types.chat_completion import ChatCompletionRequest

def postprocess_tool_call(call_str: str) -> List[ToolCall]:
print(call_str)
tool_calls = json.loads(call_str)
for tool_call in tool_calls:
tool_call["function"]["arguments"] = json.dumps(
tool_call["function"]["arguments"]
)
return [ToolCall(**tool_call) for tool_call in tool_calls]


def generate_strict_schemas(data: ChatCompletionRequest):
base_schema = {
"$defs": {},
"properties": {
"id": {"title": "Id", "type": "string"},
"function": {"title": "Function"},
"type": {"$ref": "#/$defs/Type"}
},
"required": ["id", "function", "type"],
"title": "ModelItem",
"type": "object"
}

function_schemas = []
argument_schemas = {}

for i, tool in enumerate(data.tools):
function_name = f"Function{i+1}" if i > 0 else "Function"
argument_name = f"Arguments{i+1}" if i > 0 else "Arguments"
name_def = f"Name{i+1}" if i > 0 else "Name"

# Create Name definition
base_schema["$defs"][name_def] = {
"const": tool.function.name,
"enum": [tool.function.name],
"title": name_def,
"type": "string"
}

# Create Arguments definition
arg_properties = {}
required_params = tool.function.parameters.get('required', [])
for arg_name, arg_info in tool.function.parameters.get('properties', {}).items():
arg_properties[arg_name] = {
"title": arg_name.capitalize(),
"type": arg_info['type']
}

argument_schemas[argument_name] = {
"properties": arg_properties,
"required": required_params,
"title": argument_name,
"type": "object"
}

# Create Function definition
function_schema = {
"properties": {
"name": {"$ref": f"#/$defs/{name_def}"},
"arguments": {"$ref": f"#/$defs/{argument_name}"}
},
"required": ["name", "arguments"],
"title": function_name,
"type": "object"
}

function_schemas.append({"$ref": f"#/$defs/{function_name}"})
base_schema["$defs"][function_name] = function_schema

# Add argument schemas to $defs
base_schema["$defs"].update(argument_schemas)

# Add Type definition
base_schema["$defs"]["Type"] = {
"const": "function",
"enum": ["function"],
"title": "Type",
"type": "string"
}

# Set up the function property
base_schema["properties"]["function"]["anyOf"] = function_schemas

return base_schema


# def generate_strict_schemas(data: ChatCompletionRequest):
# schema = {
# "type": "object",
# "properties": {
# "name": {"type": "string"},
# "arguments": {
# "type": "object",
# "properties": {},
# "required": []
# }
# },
# "required": ["name", "arguments"]
# }

# function_schemas = []

# for tool in data.tools:
# func_schema = deepcopy(schema)
# func_schema["properties"]["name"]["enum"] = [tool.function.name]
# raw_params = tool.function.parameters.get('properties', {})
# required_params = tool.function.parameters.get('required', [])

# # Add argument properties and required fields
# arg_properties = {}
# for arg_name, arg_type in raw_params.items():
# arg_properties[arg_name] = {"type": arg_type['type']}

# func_schema["properties"]["arguments"]["properties"] = arg_properties
# func_schema["properties"]["arguments"]["required"] = required_params

# function_schemas.append(func_schema)

# return _create_full_schema(function_schemas)

# def _create_full_schema(function_schemas: List) -> Dict:
# # Define the master schema structure with placeholders for function schemas
# tool_call_schema = {
# "$schema": "http://json-schema.org/draft-07/schema#",
# "type": "array",
# "items": {
# "type": "object",
# "properties": {
# "id": {"type": "string"},
# "function": {
# "type": "object", # Add this line
# "oneOf": function_schemas
# },
# "type": {"type": "string", "enum": ["function"]}
# },
# "required": ["id", "function", "type"]
# }
# }

# print(json.dumps(tool_call_schema, indent=2))
# return tool_call_schema

0 comments on commit f903054

Please sign in to comment.