Commit

improve: [llm_adapter] fix support for ollama deepseek-v2 which does not support function_call
luochen1990 committed Jul 31, 2024
1 parent 37800b9 commit 21ef687
Showing 2 changed files with 10 additions and 8 deletions.
src/ai_powered/llm_adapter/generic_adapter.py (9 additions & 8 deletions)
@@ -50,8 +50,8 @@ def query_model(self, user_msg: str) -> str:
         response = client.chat.completions.create(
             model = self.model_name,
             messages = messages,
-            tools = tools,
-            tool_choice = {"type": "function", "function": {"name": "return_result"}},
+            tools = tools if "function_call" in self.model_features else openai.NOT_GIVEN,
+            tool_choice = {"type": "function", "function": {"name": "return_result"}} if "function_call" in self.model_features else openai.NOT_GIVEN,
         )
         if DEBUG:
             print(yellow(f"{response =}"))
@@ -66,17 +66,18 @@ def query_model(self, user_msg: str) -> str:
         raw_resp_str = resp_msg.content
         assert raw_resp_str is not None
 
+        raw_resp_str_strip = raw_resp_str.strip()
         # raw_resp_str = "```json\n{"result": 2}\n```"
 
         is_markdown : Callable[[str], bool]= lambda s: s.startswith("```") and s.endswith("```")
 
-        if is_markdown(raw_resp_str):
-            if raw_resp_str.startswith("```json"):
-                unwrapped_resp_str = raw_resp_str[7:-3]
+        if is_markdown(raw_resp_str_strip):
+            if raw_resp_str_strip.startswith("```json"):
+                unwrapped_resp_str = raw_resp_str_strip[7:-3]
             else:
-                unwrapped_resp_str = raw_resp_str[3:-3]
+                unwrapped_resp_str = raw_resp_str_strip[3:-3]
         else:
-            unwrapped_resp_str = raw_resp_str
+            unwrapped_resp_str = raw_resp_str_strip
 
         if DEBUG:
             print(f"{unwrapped_resp_str =}")
@@ -91,7 +92,7 @@ def query_model(self, user_msg: str) -> str:
 
         if DEBUG:
             print(f"{raw_resp_str =}")
-            print(f"{is_markdown(raw_resp_str) =}")
+            print(f"{is_markdown(raw_resp_str_strip) =}")
             print(f"{is_result(unwrapped_resp_str) =}")
             print(f"{result_str =}")
 
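For context, here is a minimal standalone sketch of the two patterns generic_adapter.py now relies on: passing openai.NOT_GIVEN for tools/tool_choice when the model lacks function-call support, and stripping whitespace before unwrapping a fenced JSON reply. This is an illustration under assumed names (query, model_features, messages, tools), not the adapter's actual API.

import openai

def query(client: openai.OpenAI, model_name: str, model_features: set[str],
          messages: list, tools: list) -> str:
    # Only send tools / tool_choice when the model supports function calling;
    # openai.NOT_GIVEN omits the field from the request entirely.
    supports_fc = "function_call" in model_features
    response = client.chat.completions.create(
        model = model_name,
        messages = messages,
        tools = tools if supports_fc else openai.NOT_GIVEN,
        tool_choice = {"type": "function", "function": {"name": "return_result"}} if supports_fc else openai.NOT_GIVEN,
    )
    raw = response.choices[0].message.content
    assert raw is not None

    # Models answering in plain text often wrap JSON in a fenced code block,
    # sometimes with leading/trailing whitespace, hence strip() before the fence check.
    stripped = raw.strip()
    if stripped.startswith("```") and stripped.endswith("```"):
        return stripped[7:-3] if stripped.startswith("```json") else stripped[3:-3]
    return stripped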
src/ai_powered/llm_adapter/known_models.py (1 addition & 0 deletions)
@@ -56,6 +56,7 @@ def equals(s: str) -> Callable[[str], bool]:
         known_model_list = [
             KnownModel("gorilla-llm/gorilla-openfunctions-v2-gguf/gorilla-openfunctions-v2-q4_K_M.gguf", ALL_FEATURES),
             KnownModel("lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf", set()),
+            KnownModel("deepseek-coder-v2", set()),
         ]
     ),
 ]
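A rough sketch of how an entry like the new KnownModel("deepseek-coder-v2", set()) is presumably consumed: the adapter looks the model up by name and, finding no "function_call" feature, falls back to the plain-text path shown above. The KnownModel shape and the lookup helper below are assumptions; only the constructor calls with a name and a feature set are visible in this diff.

from dataclasses import dataclass

ALL_FEATURES = {"function_call"}  # assumed contents of the real constant

@dataclass
class KnownModel:
    # Hypothetical shape inferred from the constructor calls above:
    # a model name plus the set of features it supports.
    name: str
    features: set[str]

known_model_list = [
    KnownModel("gorilla-llm/gorilla-openfunctions-v2-gguf/gorilla-openfunctions-v2-q4_K_M.gguf", ALL_FEATURES),
    KnownModel("lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf", set()),
    KnownModel("deepseek-coder-v2", set()),  # the entry added by this commit
]

def features_of(model_name: str) -> set[str]:
    """Return the feature set of a known model; assume no extra features if unknown."""
    for m in known_model_list:
        if m.name == model_name:
            return m.features
    return set()

# e.g. features_of("deepseek-coder-v2") == set(), so tools/tool_choice are not sent.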
