Commit 211cf1b

Fixed some bugs, optimized the code structure, and added control over LLM model selection.
1 parent e21c519 commit 211cf1b

File tree

2 files changed: +392 −279 lines changed


Diff for: _aichat.py

+74 −30
@@ -9,76 +9,120 @@
 # ark:
 # claude:
 
-class LLM_ai:
-    _MODELS = {
-        "zhipuai": {"url": "",  # SDK use default
+# model info: model name, max tokens, llm name
+class AI_models:
+    _MODEL = {
+        "glm-4-flash": {"llm": "zhipuai",
                     "model": "glm-4-flash",
-                    "key": "zhipuai-key",
                     "max_tokens": 8000,
                     },
-        "kimi": {"url": "https://api-sg.moonshot.ai/v1",  # use openAI SDK
+        "glm-4-5020": {"llm": "zhipuai",
+                    "model": "glm-4-5020",
+                    "max_tokens": 8000,
+                    },
+        "kimi8k": {"llm": "kimi",  # use openAI SDK
                  "model": "moonshot-v1-8k",
-                 "key": "sk-kimi-key",
                  "max_tokens": 8000,
                  },
-        "ark": {"url": "https://ark.cn-beijing.volces.com/api/v3",  # Volcengine Ark LLMs; Douyin, Doubao, and Coze are one family
-                "model": "ep-20240929221043-jsbgc",
-                "key": "ark-key",
+        "doubao4": {"llm": "ark",  # Volcengine Ark LLMs; Douyin, Doubao, and Coze are one family
+                "model": "ep-20241005223718-nl742",  # Doubao-pro-4k
                 "max_tokens": 4000,
                 },
-        "claude": {"url": "https://api.gptapi.us/v1/chat/completions",  # Legend's testing bed
+        "doubao32": {"llm": "ark",
+                "model": "ep-20240929221043-jsbgc",  # Doubao-pro-32k
+                "max_tokens": 4000,
+                },
+        "claude35": {"llm": "claude",  # Legend's testing bed
                 "model": "claude-3-5-sonnet",
-                "key": "sk-claud-key",
                 "max_tokens": 8000,
                 },
     }
+
+    _working_model = {
+        "llm": "",
+        "model": "",
+        "max_tokens": 0
+    }
+
+    # nickname: 1) a key in _MODEL; 2) "llm.model.9999" for a dedicated llm/model; a bad format raises at runtime
+    def __init__(self, model_nickname: str):
+        if model_nickname in self._MODEL:
+            self._working_model = self._MODEL[model_nickname]
+        else:
+            llm, model, n = model_nickname.split('.')
+            self._working_model = {
+                "llm": llm,
+                "model": model,
+                "max_tokens": int(n)
+            }
+
+    def __getattr__(self, attr: str):
+        return self._working_model[attr]
+
+
+# maintain llm-url, get api-key, issue chat
+class LLM_ai:
+
+    _LLM = {
+        "zhipuai": {"url": "",  # SDK use default
+                    },
+        "kimi": {"url": "https://api-sg.moonshot.ai/v1",  # use openAI SDK
+                 },
+        "ark": {"url": "https://ark.cn-beijing.volces.com/api/v3",  # Volcengine Ark LLMs; Douyin, Doubao, and Coze are one family
+                },
+        "claude": {"url": "https://api.gptapi.us/v1/chat/completions",  # Legend's testing bed
+                   },
+    }
+
+    # working one
     _llm = None
+    _model = None
+    _max_tokens = None
     _client = None
 
 
-    def __init__(self, llm: str, model=""):
+    def __init__(self, llm: str, model="", max_tokens=2000):
         self._llm = llm
-        self._MODELS[llm]["key"] = st.secrets.ai_keys[llm]
+        ai_key = st.secrets.ai_keys[llm]
 
         if llm == "zhipuai":
-            self._client = ZhipuAI(api_key=self._MODELS[llm]["key"])
+            self._client = ZhipuAI(api_key=ai_key)  # zhipuai SDK uses its default url
         elif llm == "kimi":
-            self._client = openai.OpenAI(api_key=self._MODELS[llm]["key"],
-                                         base_url=self._MODELS[llm]["url"])
+            self._client = openai.OpenAI(api_key=ai_key, base_url=self._LLM[llm]["url"])
         elif llm == "ark":
-            self._client = Ark(api_key=self._MODELS[llm]["key"],
-                               base_url=self._MODELS[llm]["url"])
+            self._client = Ark(api_key=ai_key, base_url=self._LLM[llm]["url"])
         elif llm == "claude":
-            self._client = openai.OpenAI(api_key=self._MODELS[llm]["key"],
-                                         base_url=self._MODELS[llm]["url"])
+            self._client = openai.OpenAI(api_key=ai_key, base_url=self._LLM[llm]["url"])
         else:
             raise ValueError(f"Invalid llm: {llm}")
 
         if not model == "":
-            self._MODELS[llm]["model"] = model
+            self._model = model
+        self._max_tokens = max_tokens
 
     # use stream mode if pass_chunk is not None
     def chat(self, prompt: str, t: str, pass_chunk=None):
         if pass_chunk is None:
             response = self._client.chat.completions.create(
-                model=self._MODELS[self._llm]["model"],
+                model=self._model,
                 messages=[
                     {"role": "system", "content": prompt},
                     {"role": "user", "content": t}
                 ],
                 temperature=0.3,
+                max_tokens=self._max_tokens,
             )
             return response.choices[0].message.content
         else:
             response = self._client.chat.completions.create(
-                model=self._MODELS[self._llm]["model"],
-                messages=[
-                    {"role": "system", "content": prompt},
-                    {"role": "user", "content": t},
-                ],
-                temperature=0.3,
-                stream=True,
-                max_tokens=self._MODELS[self._llm]["max_tokens"],
+                model=self._model,
+                messages=[
+                    {"role": "system", "content": prompt},
+                    {"role": "user", "content": t},
+                ],
+                temperature=0.3,
+                stream=True,
+                max_tokens=self._max_tokens,
             )
 
         answer = ""
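The new AI_models class is what gives callers control over model selection. Below is a minimal usage sketch (hypothetical caller code, not part of this commit) showing the two nickname forms the constructor accepts; the dotted string is illustrative:

# Form 1: the nickname is a key in AI_models._MODEL.
m = AI_models("glm-4-flash")

# Form 2: a dotted "llm.model.max_tokens" string for a model not in the table
# (this model name is made up for illustration). A malformed string raises
# ValueError at runtime, from the 3-way split() unpacking or from int().
m2 = AI_models("kimi.moonshot-v1-32k.4000")

# __getattr__ forwards attribute reads to _working_model, so the resolved
# fields feed straight into LLM_ai.
ai = LLM_ai(m.llm, model=m.model, max_tokens=m.max_tokens)
print(ai.chat("You are a helpful assistant.", "Hello"))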

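The streaming branch is only partially visible here (the hunk ends just after answer = ""), but the signature suggests pass_chunk is a per-fragment callback. A hypothetical consumer, under that assumption:

# Assumes chat() invokes pass_chunk on each streamed text fragment while it
# accumulates the full answer; the callback below is illustrative only.
def show_chunk(text: str):
    print(text, end="", flush=True)  # render partial output as it arrives

ai = LLM_ai("kimi", model="moonshot-v1-8k", max_tokens=8000)
answer = ai.chat("You are a helpful assistant.", "Summarize this file.", pass_chunk=show_chunk)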