From c537e7a717880526531f9a2040c34dbce4276b41 Mon Sep 17 00:00:00 2001
From: LuoChen
Date: Sat, 10 Aug 2024 12:28:29 +0800
Subject: [PATCH] redesign: use @easy_sync.sync_compatible to support sync & async

---
 src/ai_powered/chat_bot.py                  | 18 +++---
 src/ai_powered/decorators.py                | 10 +--
 src/ai_powered/llm/adapter_selector.py      | 16 +++--
 .../llm/adapters/generic_adapter.py         | 47 +++-----------
 src/ai_powered/llm/connection.py            | 61 +++++++++++++++++++
 src/ai_powered/llm/definitions.py           |  8 +--
 src/ai_powered/utils/function_wraps.py      | 18 ++++++
 test/examples/chat_bot/simple_chatbot.py    |  4 +-
 test/examples/chat_bot/use_calculator.py    |  4 +-
 test/examples/chat_bot/use_google_search.py |  2 +-
 test/poc/copy_type_hints.py                 | 32 ++++++++++
 11 files changed, 152 insertions(+), 68 deletions(-)
 create mode 100644 src/ai_powered/llm/connection.py
 create mode 100644 src/ai_powered/utils/function_wraps.py
 create mode 100644 test/poc/copy_type_hints.py

diff --git a/src/ai_powered/chat_bot.py b/src/ai_powered/chat_bot.py
index 813ace0..3a43eb8 100644
--- a/src/ai_powered/chat_bot.py
+++ b/src/ai_powered/chat_bot.py
@@ -1,14 +1,16 @@
 from dataclasses import dataclass, field
 from typing import Any, ClassVar
 import openai
+from easy_sync import sync_compatible
 from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
 from ai_powered.colors import gray
 from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES
 from ai_powered.llm.known_models import complete_model_config
+from ai_powered.llm.connection import LlmConnection
 from ai_powered.tool_call import ChatCompletionToolParam, MakeTool
 
-default_client = openai.OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)
+default_connection = LlmConnection(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)
 model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)
 
 @dataclass
@@ -17,7 +19,7 @@ class ChatBot:
     system_prompt : ClassVar[str] = "" # if not empty, it will be prepended to the conversation
     tools: ClassVar[tuple[MakeTool[..., Any], ...]] = ()
 
-    client: ClassVar[openai.OpenAI] = default_client
+    connection: ClassVar[LlmConnection] = default_connection
     conversation : list[ChatCompletionMessageParam] = field(default_factory=lambda:[])
 
     def __post_init__(self):
@@ -25,11 +27,12 @@ def __post_init__(self):
         self._tool_dict = {tool.fn.__name__: tool for tool in self.tools}
         self._tool_schemas : list[ChatCompletionToolParam] | openai.NotGiven = [ t.schema() for t in self.tools ] if len(self.tools) > 0 else openai.NOT_GIVEN
 
-    def chat_continue(self) -> str:
+    @sync_compatible
+    async def chat_continue(self) -> str:
         if DEBUG:
             print(gray(f"{self.conversation =}"))
 
-        response = self.client.chat.completions.create(
+        response = await self.connection.chat_completions(
             model = model_config.model_name,
             messages = [*self._system_prompt, *self.conversation],
             tools = self._tool_schemas,
@@ -48,12 +51,13 @@ def chat_continue(self) -> str:
                 function_message = using_tool.call(tool_call) #type: ignore #TODO: async & parallel
                 self.conversation.append(function_message)
 
-            return self.chat_continue()
+            return await self.chat_continue()
         else:
             message_content = assistant_message.content
             assert message_content is not None
             return message_content
 
-    def chat(self, message: str) -> str:
+    @sync_compatible
+    async def chat(self, message: str) -> str:
         self.conversation.append({"role": "user", "content": message})
-        return self.chat_continue()
+        return await self.chat_continue()
diff --git a/src/ai_powered/decorators.py b/src/ai_powered/decorators.py
index 9a991f9..0dbd6cb 100644
--- a/src/ai_powered/decorators.py
+++ b/src/ai_powered/decorators.py
@@ -6,6 +6,7 @@
 import json
 import msgspec
 
+from ai_powered.llm.connection import LlmConnection
 from ai_powered.llm.definitions import ModelFeature
 from ai_powered.llm.adapter_selector import FunctionSimulatorSelector
 from ai_powered.llm.known_models import complete_model_config
@@ -56,8 +57,7 @@ def ai_powered(fn : Callable[P, Awaitable[R]] | Callable[P, R]) -> Callable[P, A
         print(f"{param_name} (json schema): {schema}")
         print(f"return (json schema): {return_schema}")
 
-    client = openai.OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
-    async_client = openai.AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
+    connection = LlmConnection(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
     model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)
     model_name = model_config.model_name
     model_features: set[ModelFeature] = model_config.supported_features
@@ -77,7 +77,7 @@ def ai_powered(fn : Callable[P, Awaitable[R]] | Callable[P, R]) -> Callable[P, A
 
     fn_simulator = FunctionSimulatorSelector(
         function_name, f"{sig}", docstring, parameters_schema, return_schema,
-        client, async_client, model_name, model_features, model_options
+        connection, model_name, model_features, model_options
     )
 
     if DEBUG:
@@ -92,7 +92,7 @@ def wrapper_fn(*args: P.args, **kwargs: P.kwargs) -> R:
         if DEBUG:
             print(f"{real_arg_str =}")
 
-        resp_str = fn_simulator.query_model(real_arg_str)
+        resp_str = fn_simulator.query_model(real_arg_str).wait()
 
         if DEBUG:
             print(f"{resp_str =}")
@@ -114,7 +114,7 @@ async def wrapper_fn_async(*args: P.args, **kwargs: P.kwargs) -> R:
             print(f"{real_arg_str =}")
 
         # NOTE: the main logic
-        resp_str = await fn_simulator.query_model_async(real_arg_str)
+        resp_str = await fn_simulator.query_model(real_arg_str)
 
         if DEBUG:
             print(f"{resp_str =}")
diff --git a/src/ai_powered/llm/adapter_selector.py b/src/ai_powered/llm/adapter_selector.py
index d867db0..7a15755 100644
--- a/src/ai_powered/llm/adapter_selector.py
+++ b/src/ai_powered/llm/adapter_selector.py
@@ -1,3 +1,4 @@
+from easy_sync import sync_compatible
 from ai_powered.llm.adapters.generic_adapter import GenericFunctionSimulator
 from ai_powered.llm.adapters.tools_adapter import ToolsFunctionSimulator
 from ai_powered.llm.adapters.chat_adapter import ChatFunctionSimulator
@@ -14,26 +15,23 @@ def _select_impl(self) -> GenericFunctionSimulator:
         if ModelFeature.structured_outputs in self.model_features:
             return StructuredOutputFunctionSimulator(
                 self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
-                self.client, self.async_client, self.model_name, self.model_features, self.model_options
+                self.connection, self.model_name, self.model_features, self.model_options
             )
         elif ModelFeature.tools in self.model_features:
             return ToolsFunctionSimulator(
                 self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
-                self.client, self.async_client, self.model_name, self.model_features, self.model_options
+                self.connection, self.model_name, self.model_features, self.model_options
             )
         else:
             return ChatFunctionSimulator(
                 self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
-                self.client, self.async_client, self.model_name, self.model_features, self.model_options
+                self.connection, self.model_name, self.model_features, self.model_options
             )
 
     def __post_init__(self):
         super().__post_init__()
         self._selected_impl = self._select_impl()
 
-    def query_model(self, arguments_json: str) -> str:
-        return self._selected_impl.query_model(arguments_json)
-
-    async def query_model_async(self, arguments_json: str) -> str:
-        result = await self._selected_impl.query_model_async(arguments_json)
-        return result
+    @sync_compatible
+    async def query_model(self, arguments_json: str) -> str:
+        return await self._selected_impl.query_model(arguments_json)
diff --git a/src/ai_powered/llm/adapters/generic_adapter.py b/src/ai_powered/llm/adapters/generic_adapter.py
index 9c06dd7..217ca2a 100644
--- a/src/ai_powered/llm/adapters/generic_adapter.py
+++ b/src/ai_powered/llm/adapters/generic_adapter.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass, field
 import json
 from typing import Any, Iterable, Set
+from easy_sync import sync_compatible
 import openai
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam
@@ -9,6 +10,7 @@
 from openai.types.chat.completion_create_params import ResponseFormat
 from ai_powered.colors import green, red, yellow
 from ai_powered.constants import DEBUG, SYSTEM_PROMPT
+from ai_powered.llm.connection import LlmConnection
 from ai_powered.llm.definitions import FunctionSimulator, ModelFeature
 from ai_powered.tool_call import ChatCompletionToolParam
 
@@ -16,8 +18,7 @@ class GenericFunctionSimulator (FunctionSimulator, ABC):
     ''' implementation of FunctionSimulator for OpenAI compatible models '''
 
-    client: openai.OpenAI
-    async_client: openai.AsyncOpenAI
+    connection: LlmConnection
     model_name: str
     model_features: Set[ModelFeature]
     model_options: dict[str, Any]
 
@@ -58,9 +59,10 @@ def _param_tool_choice_maker(self) -> ChatCompletionToolChoiceOptionParam | open
         ''' to be overridden '''
         return openai.NOT_GIVEN
 
-    def _chat_completion_query(self, arguments_json: str) -> ChatCompletion:
+    @sync_compatible
+    async def _chat_completion_query(self, arguments_json: str) -> ChatCompletion:
         ''' default impl is provided '''
-        return self.client.chat.completions.create(
+        return await self.connection.chat_completions(
             model = self.model_name,
             messages = [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": arguments_json}
             ],
             tools = self._param_tools,
             tool_choice = self._param_tool_choice,
@@ -71,20 +73,6 @@ def _chat_completion_query(self, arguments_json: str) -> ChatCompletion:
             response_format=self._param_response_format,
         )
 
-    async def _chat_completion_query_async(self, arguments_json: str) -> ChatCompletion:
-        ''' default impl is provided '''
-        result = await self.async_client.chat.completions.create(
-            model = self.model_name,
-            messages = [
-                {"role": "system", "content": self.system_prompt},
-                {"role": "user", "content": arguments_json}
-            ],
-            tools = self._param_tools,
-            tool_choice = self._param_tool_choice,
-            response_format=self._param_response_format,
-        )
-        return result
-
     def _response_message_parser(self, response_message: ChatCompletionMessage) -> str:
         ''' to be overridden '''
         if DEBUG:
@@ -92,14 +80,15 @@ def _response_message_parser(self, response_message: ChatCompletionMessage) -> s
         raise NotImplementedError
 
     #@override
-    def query_model(self, arguments_json: str) -> str:
+    @sync_compatible
+    async def query_model(self, arguments_json: str) -> str:
 
         if DEBUG:
             print(yellow(f"{arguments_json =}"))
             print(yellow(f"request.tools = {self._param_tools}"))
             print(green(f"[fn {self.function_name}] request prepared."))
 
-        response = self._chat_completion_query(arguments_json)
+        response = await self._chat_completion_query(arguments_json)
 
         if DEBUG:
             print(yellow(f"{response =}"))
@@ -108,21 +97,3 @@ def query_model(self, arguments_json: str) -> str:
         response_message = response.choices[0].message
         result_str = self._response_message_parser(response_message)
         return result_str
-
-    #@override
-    async def query_model_async(self, arguments_json: str) -> str:
-
-        if DEBUG:
-            print(yellow(f"{arguments_json =}"))
-            print(yellow(f"request.tools = {self._param_tools}"))
-            print(green(f"[fn {self.function_name}] request prepared."))
-
-        response = await self._chat_completion_query_async(arguments_json)
-
-        if DEBUG:
-            print(yellow(f"[query_model_async()] {response =}"))
-            print(green(f"[fn {self.function_name}] response received."))
-
-        response_message = response.choices[0].message
-        result_str = self._response_message_parser(response_message)
-        return result_str
diff --git a/src/ai_powered/llm/connection.py b/src/ai_powered/llm/connection.py
new file mode 100644
index 0000000..112f224
--- /dev/null
+++ b/src/ai_powered/llm/connection.py
@@ -0,0 +1,61 @@
+from functools import partial
+from typing import Any, Mapping, Union
+from easy_sync import Waitable, sync_compatible
+import httpx
+import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from ai_powered.utils.function_wraps import wraps_arguments_type
+
+
+class LlmConnection:
+    sync_client: openai.OpenAI
+    async_client: openai.AsyncOpenAI
+    base_url: str | httpx.URL | None
+
+    def __init__(self,
+        api_key: str | None = None,
+        organization: str | None = None,
+        project: str | None = None,
+        base_url: str | httpx.URL | None = None,
+        timeout: Union[float, httpx.Timeout, None, openai.NotGiven] = openai.NOT_GIVEN,
+        max_retries: int = openai.DEFAULT_MAX_RETRIES,
+        default_headers: Mapping[str, str] | None = None,
+        default_query: Mapping[str, object] | None = None,
+        sync_http_client: httpx.Client | None = None,
+        async_http_client: httpx.AsyncClient | None = None,
+    ):
+        self.base_url = base_url
+
+        self.sync_client = openai.OpenAI(
+            api_key=api_key,
+            organization=organization,
+            project=project,
+            base_url=base_url,
+            timeout=timeout,
+            max_retries=max_retries,
+            default_headers=default_headers,
+            default_query=default_query,
+            http_client=sync_http_client,
+        )
+
+        self.async_client = openai.AsyncOpenAI(
+            api_key=api_key,
+            organization=organization,
+            project=project,
+            base_url=base_url,
+            timeout=timeout,
+            max_retries=max_retries,
+            default_headers=default_headers,
+            default_query=default_query,
+            http_client=async_http_client,
+        )
+
+        async_fn = partial(self.async_client.chat.completions.create, stream=False)
+        sync_fn = partial(self.sync_client.chat.completions.create, stream=False)
+
+        f = sync_compatible(sync_fn=sync_fn)(async_fn) #type: ignore
+        self._chat_completions = f
+
+    @wraps_arguments_type(openai.AsyncOpenAI().chat.completions.create)
+    def chat_completions(self, *args: list[Any], **kwargs: dict[str, Any]) -> Waitable[ChatCompletion]:
+        return self._chat_completions(*args, **kwargs)
diff --git a/src/ai_powered/llm/definitions.py b/src/ai_powered/llm/definitions.py
index a576792..74c8feb 100644
--- a/src/ai_powered/llm/definitions.py
+++ b/src/ai_powered/llm/definitions.py
@@ -3,6 +3,8 @@
 import enum
 from typing import Any, Optional
 
+from easy_sync import sync_compatible
+
 class ModelFeature (enum.Enum):
     '''
     Ollama Doc: https://ollama.fan/reference/openai/#supported-features
@@ -31,8 +33,6 @@ class FunctionSimulator (ABC):
     parameters_schema: dict[str, Any]
     return_schema: dict[str, Any]
 
-    def query_model(self, arguments_json: str) -> str:
-        ...
-
-    async def query_model_async(self, arguments_json: str) -> str:
+    @sync_compatible
+    async def query_model(self, arguments_json: str) -> str:
         ...
diff --git a/src/ai_powered/utils/function_wraps.py b/src/ai_powered/utils/function_wraps.py
new file mode 100644
index 0000000..15e5777
--- /dev/null
+++ b/src/ai_powered/utils/function_wraps.py
@@ -0,0 +1,18 @@
+from typing import Any, Callable, TypeVar
+from typing_extensions import ParamSpec
+
+
+P = ParamSpec('P')
+R = TypeVar('R')
+
+
+def wraps(origin_fn: Callable[P, R]) -> Callable[ [Callable[..., Any]], Callable[P, R]]:
+    def wrapper(fn: Callable[..., Any]) -> Callable[P, R]:
+        return fn #type: ignore
+    return wrapper
+
+
+def wraps_arguments_type(origin_fn: Callable[P, Any]) -> Callable[ [Callable[..., R]], Callable[P, R]]:
+    def wrapper(fn: Callable[..., R]) -> Callable[P, R]:
+        return fn #type: ignore
+    return wrapper
diff --git a/test/examples/chat_bot/simple_chatbot.py b/test/examples/chat_bot/simple_chatbot.py
index ca0158c..933c5bf 100644
--- a/test/examples/chat_bot/simple_chatbot.py
+++ b/test/examples/chat_bot/simple_chatbot.py
@@ -3,6 +3,6 @@
 
 def test_simple_chatbot():
     bot = ChatBot()
-    print(green(bot.chat('hello, please tell me the result of 2^10 + 3^4')))
-    print(green(bot.chat('and what is above result divided by 2?')))
+    print(green(bot.chat('hello, please tell me the result of 2^10 + 3^4').wait()))
+    print(green(bot.chat('and what is above result divided by 2?').wait()))
     print(gray(f"{bot.conversation}"))
diff --git a/test/examples/chat_bot/use_calculator.py b/test/examples/chat_bot/use_calculator.py
index 082eb30..824879d 100644
--- a/test/examples/chat_bot/use_calculator.py
+++ b/test/examples/chat_bot/use_calculator.py
@@ -19,6 +19,6 @@ class MyChatBot (ChatBot):
 
 def test_use_calculator():
     bot = MyChatBot()
-    print(green(bot.chat('hello, please tell me the result of 2^10 + 3^4')))
-    print(green(bot.chat('and what is above result divided by 2?')))
+    print(green(bot.chat('hello, please tell me the result of 2^10 + 3^4').wait()))
+    print(green(bot.chat('and what is above result divided by 2?').wait()))
     print(gray(f"{bot.conversation}"))
diff --git a/test/examples/chat_bot/use_google_search.py b/test/examples/chat_bot/use_google_search.py
index 3111deb..e51be97 100644
--- a/test/examples/chat_bot/use_google_search.py
+++ b/test/examples/chat_bot/use_google_search.py
@@ -10,5 +10,5 @@ class MyChatBot (ChatBot):
 
 def test_use_google_search():
     bot = MyChatBot()
-    print(green(bot.chat("what's USD price in CNY today?")))
+    print(green(bot.chat("what's USD price in CNY today?").wait()))
     print(gray(f"{bot.conversation}"))
diff --git a/test/poc/copy_type_hints.py b/test/poc/copy_type_hints.py
new file mode 100644
index 0000000..344b0be
--- /dev/null
+++ b/test/poc/copy_type_hints.py
@@ -0,0 +1,32 @@
+import inspect
+from typing import Any, get_type_hints
+
+# Suppose this is your function g, with a complex parameter type signature
+def g(a: int, b: str, c: float) -> bool:
+    # function body
+    return True
+
+# Get the signature of g; it carries the parameter
+# annotations (type hints) that we want to copy
+sig_g = inspect.signature(g)
+
+
+# Define f, reusing g's parameter type definitions
+def f(*args: Any, **kwargs: Any):
+    # call g, then post-process its result
+    result = g(*args, **kwargs)
+    # post-processing logic (example: negate the result)
+    processed_result = not result
+    return processed_result
+
+# Use function annotations to reuse g's parameter type definitions
+f.__annotations__ = {
+    name: param.annotation
+    for name, param in sig_g.parameters.items()
+}
+f.__annotations__['return'] = sig_g.return_annotation  # set the return type
+
+# Now f's parameter type definitions are the same as g's
+print(get_type_hints(f))
+
+print(f(1, "hello", 3.14))  # f still requires arguments that are valid for g
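
---

Usage sketch (illustrative, not part of the diff): after this redesign every entry
point is a single @sync_compatible coroutine, so the caller picks sync or async at
the call site. The .wait() form below is the same pattern the updated tests use;
the async form assumes only standard asyncio.

    import asyncio
    from ai_powered.chat_bot import ChatBot

    bot = ChatBot()

    # synchronous caller: .wait() resolves the Waitable on the spot,
    # served by the dedicated sync OpenAI client
    print(bot.chat("hello, what is 2^10?").wait())

    # asynchronous caller: await the very same method
    async def main() -> None:
        print(await bot.chat("and that result divided by 2?"))

    asyncio.run(main())

Design note: LlmConnection registers a genuine sync implementation
(sync_fn = partial(sync_client.chat.completions.create, stream=False)), so the
sync path never has to spin up an event loop around the async client.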