forked from openai/openai-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_completions.py
264 lines (208 loc) · 8.91 KB
/
_completions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Iterable, cast
from typing_extensions import TypeVar, TypeGuard, assert_never
import pydantic
from .._tools import PydanticFunctionTool
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import is_dict, is_given
from ..._compat import PYDANTIC_V2, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
from ...types.chat import (
ParsedChoice,
ChatCompletion,
ParsedFunction,
ParsedChatCompletion,
ChatCompletionMessage,
ParsedFunctionToolCall,
ChatCompletionToolParam,
ParsedChatCompletionMessage,
completion_create_params,
)
from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ...types.shared_params import FunctionDefinition
from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
from ...types.chat.chat_completion_message_tool_call import Function
# Type variable for the parsed `message.parsed` value. Its default is `None`,
# which is what it resolves to when no response format type is given, in which
# case no content parsing happens.
ResponseFormatT = TypeVar("ResponseFormatT", default=None)

# Runtime value returned by `solve_response_format_t` when the response format
# is not a parseable type; always `None`.
_default_response_format: None = None
def validate_input_tools(
    tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> None:
    """Check that every supplied tool is eligible for auto-parsing.

    When `tools` is not given this is a no-op. Otherwise each tool must have
    `type == "function"` and its function definition must have `strict`
    set to exactly `True`.

    Raises:
        ValueError: for the first tool that is not a `function` tool, or whose
            function definition is not strict.
    """
    if is_given(tools):
        for candidate in tools:
            kind = candidate["type"]
            if kind != "function":
                raise ValueError(
                    f"Currently only `function` tool types support auto-parsing; Received `{kind}`",
                )

            definition = candidate["function"]
            # Only an explicit `True` counts; a missing/None/False flag is rejected.
            if definition.get("strict") is not True:
                raise ValueError(
                    f"`{definition['name']}` is not strict. Only `strict` function tools can be auto-parsed"
                )
def parse_chat_completion(
    *,
    response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | NotGiven,
    input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
    chat_completion: ChatCompletion | ParsedChatCompletion[object],
) -> ParsedChatCompletion[ResponseFormatT]:
    """Build a `ParsedChatCompletion` from a raw chat completion.

    For every choice, the message content is parsed against `response_format`
    (via `maybe_parse_content`). The arguments of each `function` tool call are
    parsed against the matching entry in `input_tools` (via
    `parse_function_tool_arguments`). The result is assembled with
    `construct_type_unchecked` from the original fields plus the parsed data.

    Args:
        response_format: a type to parse message content into, a raw
            response-format param dict, or `NOT_GIVEN`.
        input_tools: the tools sent with the request, used to look up how each
            returned tool call should be parsed; `NOT_GIVEN` means no tools.
        chat_completion: the completion returned by the API.

    Returns:
        The parsed completion. A message with no tool calls keeps
        `tool_calls=None`, as in the unparsed message.

    Raises:
        LengthFinishReasonError: if any choice has `finish_reason == "length"`.
        ContentFilterFinishReasonError: if any choice has
            `finish_reason == "content_filter"`.
    """
    # Materialize once: the tools are searched for every tool call, and an
    # arbitrary iterable (e.g. a generator) could only be consumed once.
    tools: list[ChatCompletionToolParam] = list(input_tools) if is_given(input_tools) else []

    # The runtime generic argument depends only on `response_format`, so
    # resolve it once rather than once per choice plus once at the end.
    parsed_type = solve_response_format_t(response_format)

    choices: list[ParsedChoice[ResponseFormatT]] = []
    for choice in chat_completion.choices:
        if choice.finish_reason == "length":
            raise LengthFinishReasonError(completion=chat_completion)

        if choice.finish_reason == "content_filter":
            raise ContentFilterFinishReasonError()

        message = choice.message

        tool_calls: list[ParsedFunctionToolCall] = []
        for tool_call in message.tool_calls or []:
            if tool_call.type == "function":
                tool_call_dict = tool_call.to_dict()
                tool_calls.append(
                    construct_type_unchecked(
                        value={
                            **tool_call_dict,
                            "function": {
                                **cast(Any, tool_call_dict["function"]),
                                "parsed_arguments": parse_function_tool_arguments(
                                    input_tools=tools, function=tool_call.function
                                ),
                            },
                        },
                        type_=ParsedFunctionToolCall,
                    )
                )
            elif TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(tool_call)
            else:
                # Tool call types not known at type-check time are passed through unparsed.
                tool_calls.append(tool_call)

        choices.append(
            construct_type_unchecked(
                type_=cast(Any, ParsedChoice)[parsed_type],
                value={
                    **choice.to_dict(),
                    "message": {
                        **message.to_dict(),
                        "parsed": maybe_parse_content(
                            response_format=response_format,
                            message=message,
                        ),
                        # Keep "no tool calls" as `None` (as in the unparsed
                        # message) rather than an empty list; an empty
                        # `tool_calls` array is rejected by the API if the
                        # message is sent back in a later request.
                        "tool_calls": tool_calls or None,
                    },
                },
            )
        )

    return cast(
        ParsedChatCompletion[ResponseFormatT],
        construct_type_unchecked(
            type_=cast(Any, ParsedChatCompletion)[parsed_type],
            value={
                **chat_completion.to_dict(),
                "choices": choices,
            },
        ),
    )
def get_input_tool_by_name(*, input_tools: list[ChatCompletionToolParam], name: str) -> ChatCompletionToolParam | None:
    """Return the first tool whose `function.name` equals `name`, else `None`.

    Tools without a `function` entry never match.
    """
    for candidate in input_tools:
        if candidate.get("function", {}).get("name") == name:
            return candidate
    return None
def parse_function_tool_arguments(
    *, input_tools: list[ChatCompletionToolParam], function: Function | ParsedFunction
) -> object:
    """Parse the raw JSON arguments of a returned function tool call.

    The call is matched to an input tool by function name. The result is:

    * `None` if no input tool has that name;
    * `model_parse_json(tool.model, arguments)` if the tool's function is a
      `PydanticFunctionTool`;
    * `None` if the function definition is not marked `strict`;
    * otherwise `json.loads(arguments)`.
    """
    matched = get_input_tool_by_name(input_tools=input_tools, name=function.name)
    if not matched:
        return None

    definition = cast(object, matched.get("function"))
    if isinstance(definition, PydanticFunctionTool):
        return model_parse_json(definition.model, function.arguments)

    if cast(FunctionDefinition, definition).get("strict"):
        return json.loads(function.arguments)
    return None
def maybe_parse_content(
    *,
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
    message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
    """Parse `message.content` into `response_format` when that is possible.

    Returns `None` if the response format is not a parseable type, if the
    message has no (or empty) content, or if the message carries a refusal.
    """
    if not has_rich_response_format(response_format):
        return None

    content = message.content
    if not content or message.refusal:
        return None

    return _parse_content(response_format, content)
def solve_response_format_t(
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> type[ResponseFormatT]:
    """Return the runtime type used to parameterize the parsed generics.

    This is `response_format` itself when it is a parseable type. Otherwise
    (not given, or a raw response-format param) it is `None`, meaning no
    auto-parsing.
    """
    if not has_rich_response_format(response_format):
        return cast("type[ResponseFormatT]", _default_response_format)
    return response_format
def has_parseable_input(
    *,
    response_format: type | ResponseFormatParam | NotGiven,
    input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> bool:
    """Return True if the request contains anything that can be auto-parsed.

    That is the case when `response_format` is a parseable type, or when at
    least one input tool passes `is_parseable_tool`.
    """
    if has_rich_response_format(response_format):
        return True
    # The `or []` fallback assumes `NOT_GIVEN` is falsy, so "not given" means no tools.
    return any(is_parseable_tool(tool) for tool in input_tools or [])
def has_rich_response_format(
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> TypeGuard[type[ResponseFormatT]]:
    """Return True when `response_format` is a type to parse into.

    `NOT_GIVEN` and raw response-format param dicts both yield False.
    """
    return bool(is_given(response_format)) and not is_response_format_param(response_format)
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
    """Return whether `response_format` is the raw dict form of the request param.

    Anything accepted by `is_dict` is treated as a response-format param rather
    than as a type to parse into.
    """
    looks_like_param = is_dict(response_format)
    return looks_like_param
def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
    """Return whether tool-call arguments for `input_tool` can be auto-parsed.

    A `PydanticFunctionTool` is always parseable. A plain function definition
    is parseable only when its `strict` flag is truthy.
    """
    definition = cast(object, input_tool.get("function"))
    if not isinstance(definition, PydanticFunctionTool):
        return cast(FunctionDefinition, definition).get("strict") or False
    return True
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
    """Parse the JSON string `content` into an instance of `response_format`.

    Pydantic `BaseModel` subclasses go through `model_parse_json`. Other
    dataclass-like types go through `pydantic.TypeAdapter(...).validate_json`,
    which is only allowed when `PYDANTIC_V2` is true.

    Raises:
        TypeError: if a dataclass-like type is used without Pydantic v2, or if
            `response_format` is neither a BaseModel nor dataclass-like.
    """
    if is_basemodel_type(response_format):
        return cast(ResponseFormatT, model_parse_json(response_format, content))

    if not is_dataclass_like_type(response_format):
        raise TypeError(f"Unable to automatically parse response format type {response_format}")

    if PYDANTIC_V2:
        return pydantic.TypeAdapter(response_format).validate_json(content)
    raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
def type_to_response_format_param(
    response_format: type | completion_create_params.ResponseFormat | NotGiven,
) -> ResponseFormatParam | NotGiven:
    """Translate a user-supplied `response_format` into the request parameter.

    * `NOT_GIVEN` is returned as `NOT_GIVEN`.
    * A raw response-format param dict is returned unchanged.
    * A Pydantic `BaseModel` subclass or dataclass-like type becomes a
      `json_schema` response format named after the type, with
      `strict: True` and a schema produced by `to_strict_json_schema`.

    Raises:
        TypeError: for any other kind of value.
    """
    if not is_given(response_format):
        return NOT_GIVEN

    if is_response_format_param(response_format):
        return response_format

    # Type checkers don't narrow on the negation of a `TypeGuard`, but at this
    # point the value can only be a `type`.
    model_type = cast(type, response_format)

    schema_source: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]
    if is_basemodel_type(model_type):
        schema_source = model_type
    elif is_dataclass_like_type(model_type):
        schema_source = pydantic.TypeAdapter(model_type)
    else:
        raise TypeError(f"Unsupported response_format type - {model_type}")

    return {
        "type": "json_schema",
        "json_schema": {
            "schema": to_strict_json_schema(schema_source),
            "name": model_type.__name__,
            "strict": True,
        },
    }