Skip to content

Commit 8ccf8dd

Browse files
committed
rewrite of gen strict schema leveraging pydantic style json schema
1 parent f903054 commit 8ccf8dd

File tree

2 files changed

+63
-166
lines changed

2 files changed

+63
-166
lines changed

endpoints/OAI/utils/chat_completion.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,6 @@ async def generate_tool_calls(
439439
# FIXME: May not be necessary depending on how the codebase evolves
440440
if data.tools:
441441
strict_schema = generate_strict_schemas(data)
442-
print(strict_schema)
443442
tool_data = deepcopy(data)
444443
#tool_data.json_schema = tool_data.tool_call_schema
445444
tool_data.json_schema = strict_schema # needs strict flag
@@ -472,55 +471,3 @@ async def generate_tool_calls(
472471
generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
473472

474473
return generations
475-
476-
# def create_tool_call_model(data: ChatCompletionRequest):
477-
# """Create a tool call model to guide model based on the tools spec provided"""
478-
# dtypes = {
479-
# "integer": int,
480-
# "string": str,
481-
# "boolean": bool,
482-
# "object": dict,
483-
# "array": list
484-
# }
485-
486-
# function_models = []
487-
# for tool in data.tools:
488-
489-
# tool_name = tool.function.name
490-
# raw_params = tool.function.parameters.get('properties', {})
491-
# required_params = tool.function.parameters.get('required', [])
492-
493-
# fields = {}
494-
# if raw_params:
495-
# for arg_key, val_dict in raw_params.items():
496-
497-
# arg_name = arg_key
498-
# arg_dtype = dtypes[val_dict['type']]
499-
# required = arg_name in required_params
500-
# fields[arg_name] = (arg_dtype, ... if required else None)
501-
# if not required:
502-
# arg_dtype = Optional[arg_dtype]
503-
504-
# fields[arg_name] = (arg_dtype, ... if required else None)
505-
506-
# arguments_model = create_model(f"{tool_name}Arguments", **fields)
507-
508-
# function_model = create_model(
509-
# f"{tool_name}Model",
510-
# name=(str, tool_name),
511-
# arguments=(arguments_model, ...)
512-
# )
513-
514-
# function_models.append(function_model)
515-
516-
# fucntion_union = Union[tuple(function_models)]
517-
518-
# tool_response_model = create_model(
519-
# "tools_call_response_model",
520-
# id=(str, ...),
521-
# function=(fucntion_union, ...)
522-
# )
523-
524-
# tool_response_model.model_rebuild()
525-
526-
# return tool_response_model

endpoints/OAI/utils/tools.py

Lines changed: 63 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Support functions to enable tool calling"""
22

33
from typing import List, Dict
4-
from copy import deepcopy
54
import json
65

76
from endpoints.OAI.types.tools import ToolCall
@@ -10,141 +9,92 @@
109
def postprocess_tool_call(call_str: str) -> List[ToolCall]:
    """Parse the model's raw tool-call JSON into ``ToolCall`` objects.

    ``call_str`` is decoded with ``json.loads`` and iterated as a sequence of
    tool-call objects, each carrying a ``function`` object with an
    ``arguments`` member. Every ``arguments`` value is re-serialized to a
    JSON string before the ``ToolCall`` is constructed.

    Args:
        call_str: JSON text holding the generated tool calls.

    Returns:
        One ``ToolCall`` per entry in the decoded JSON.

    Raises:
        json.JSONDecodeError: If ``call_str`` is not valid JSON.
    """
    # Function-scope import so the block does not depend on a file-level
    # import that may not exist; debug output goes through logging instead
    # of being printed unconditionally to stdout.
    import logging

    logger = logging.getLogger(__name__)
    logger.debug("Raw tool call string: %s", call_str)

    tool_calls = json.loads(call_str)
    logger.debug("Parsed tool calls: %s", tool_calls)

    for tool_call in tool_calls:
        # The arguments object is stored as a JSON string on the ToolCall.
        tool_call["function"]["arguments"] = json.dumps(
            tool_call["function"]["arguments"]
        )
    return [ToolCall(**tool_call) for tool_call in tool_calls]
1818

19-
20-
def generate_strict_schemas(data: ChatCompletionRequest):
21-
base_schema = {
22-
"$defs": {},
23-
"properties": {
24-
"id": {"title": "Id", "type": "string"},
25-
"function": {"title": "Function"},
26-
"type": {"$ref": "#/$defs/Type"}
27-
},
28-
"required": ["id", "function", "type"],
29-
"title": "ModelItem",
30-
"type": "object"
19+
def generate_strict_schemas(data: ChatCompletionRequest) -> Dict:
    """Build the JSON schema for a list of tool calls from ``data.tools``.

    The root of the returned schema is an array whose items reference the
    ``ModelItem`` definition. All definitions are produced by
    ``generate_defs`` and stored under ``$defs``.
    """
    return {
        "$defs": generate_defs(data.tools),
        # Root structure: an array of ModelItem entries.
        "type": "array",
        "items": {"$ref": "#/$defs/ModelItem"},
    }
36+
37+
def generate_defs(tools: List) -> Dict:
    """Build the ``$defs`` mapping for the given tools.

    For each tool, three definitions are added: an arguments schema, a
    constant name schema and a function schema that references the other two.
    The first tool uses the bare keys ``Arguments``, ``Name`` and
    ``Function``; every later tool gets its 1-based position appended (the
    second tool uses ``Arguments2``, ``Name2``, ``Function2``, and so on).
    ``ModelItem`` and ``Type`` definitions are added after the per-tool ones.
    """
    schema_defs = {}

    for index, tool in enumerate(tools):
        # Empty suffix for the first tool, 1-based position for the rest.
        suffix = str(index + 1) if index > 0 else ""
        args_key = f"Arguments{suffix}"
        name_key = f"Name{suffix}"
        function_key = f"Function{suffix}"

        # Arguments schema derived from the tool's declared parameters.
        schema_defs[args_key] = generate_arguments_schema(
            tool.function.parameters
        )

        # Name schema: pins the function name to a single constant value.
        schema_defs[name_key] = {
            "const": tool.function.name,
            "title": name_key,
            "type": "string",
        }

        # Function schema: ties the name and arguments definitions together.
        schema_defs[function_key] = {
            "type": "object",
            "properties": {
                "name": {"$ref": "#/$defs/" + name_key},
                "arguments": {"$ref": "#/$defs/" + args_key},
            },
            "required": ["name", "arguments"],
        }

    # Definitions shared by every tool call entry.
    schema_defs["ModelItem"] = generate_model_item_schema(len(tools))
    schema_defs["Type"] = {"const": "function", "type": "string"}

    return schema_defs
10973

110-
# function_schemas = []
74+
def generate_arguments_schema(parameters: Dict) -> Dict:
    """Build the object schema describing one tool's arguments.

    Args:
        parameters: The tool's parameter specification. Its ``properties``
            mapping and ``required`` list are read. ``None`` or an empty
            mapping is treated as a tool with no arguments.

    Returns:
        A dict with ``"type": "object"``, a ``properties`` mapping in which
        each property is reduced to its ``type`` keyword, and a ``required``
        list (a copy, so later edits to the schema do not touch the input).
    """
    # A tool may declare no parameters at all; treat that as an empty spec
    # instead of failing on ``None.get``.
    parameters = parameters or {}

    properties = {}
    for name, details in (parameters.get("properties") or {}).items():
        if isinstance(details, dict) and "type" in details:
            properties[name] = {"type": details["type"]}
        else:
            # No "type" keyword on this property: leave it unconstrained
            # rather than raising KeyError.
            properties[name] = {}

    return {
        "type": "object",
        "properties": properties,
        "required": list(parameters.get("required") or []),
    }
12986

130-
# def _create_full_schema(function_schemas: List) -> Dict:
131-
# # Define the master schema structure with placeholders for function schemas
132-
# tool_call_schema = {
133-
# "$schema": "http://json-schema.org/draft-07/schema#",
134-
# "type": "array",
135-
# "items": {
136-
# "type": "object",
137-
# "properties": {
138-
# "id": {"type": "string"},
139-
# "function": {
140-
# "type": "object", # Add this line
141-
# "oneOf": function_schemas
142-
# },
143-
# "type": {"type": "string", "enum": ["function"]}
144-
# },
145-
# "required": ["id", "function", "type"]
146-
# }
147-
# }
148-
149-
# print(json.dumps(tool_call_schema, indent=2))
150-
# return tool_call_schema
87+
def generate_model_item_schema(num_functions: int) -> Dict:
    """Build the schema for a single tool-call entry (``ModelItem``).

    The ``function`` property accepts any of the per-tool function
    definitions: ``Function`` for the first tool, then ``Function2``,
    ``Function3`` and so on up to ``num_functions``.
    """
    candidates = []
    for position in range(1, num_functions + 1):
        # The first definition has no numeric suffix.
        suffix = "" if position == 1 else str(position)
        candidates.append({"$ref": "#/$defs/Function" + suffix})

    item_properties = {
        "id": {"type": "string"},
        "function": {"anyOf": candidates},
        "type": {"$ref": "#/$defs/Type"},
    }

    return {
        "type": "object",
        "properties": item_properties,
        "required": ["id", "function", "type"],
    }

0 commit comments

Comments
 (0)