Skip to content

Commit 36bdcde

Browse files
authored
Merge pull request #758 from hfyydd/feature/add-new-llm
add deepseek to support
2 parents 6e01e12 + ac15075 commit 36bdcde

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

src/vanna/deepseek/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .deepseek_chat import DeepSeekChat
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
3+
from openai import OpenAI
4+
5+
from ..base import VannaBase
6+
7+
8+
9+
# from vanna.chromadb import ChromaDB_VectorStore
10+
11+
# class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
12+
# def __init__(self, config=None):
13+
# ChromaDB_VectorStore.__init__(self, config=config)
14+
# DeepSeekChat.__init__(self, config=config)
15+
16+
# vn = DeepSeekVanna(config={"api_key": "sk-************", "model": "deepseek-chat"})
17+
18+
19+
class DeepSeekChat(VannaBase):
20+
def __init__(self, config=None):
21+
if config is None:
22+
raise ValueError(
23+
"For DeepSeek, config must be provided with an api_key and model"
24+
)
25+
if "api_key" not in config:
26+
raise ValueError("config must contain a DeepSeek api_key")
27+
28+
if "model" not in config:
29+
raise ValueError("config must contain a DeepSeek model")
30+
31+
api_key = config["api_key"]
32+
model = config["model"]
33+
self.model = model
34+
self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")
35+
36+
def system_message(self, message: str) -> any:
37+
return {"role": "system", "content": message}
38+
39+
def user_message(self, message: str) -> any:
40+
return {"role": "user", "content": message}
41+
42+
def assistant_message(self, message: str) -> any:
43+
return {"role": "assistant", "content": message}
44+
45+
def generate_sql(self, question: str, **kwargs) -> str:
46+
# 使用父类的 generate_sql
47+
sql = super().generate_sql(question, **kwargs)
48+
49+
# 替换 "\_" 为 "_"
50+
sql = sql.replace("\\_", "_")
51+
52+
return sql
53+
54+
def submit_prompt(self, prompt, **kwargs) -> str:
55+
chat_response = self.client.chat.completions.create(
56+
model=self.model,
57+
messages=prompt,
58+
)
59+
60+
return chat_response.choices[0].message.content

0 commit comments

Comments
 (0)