 # ark:
 # claude:
 
-class LLM_ai:
-    _MODELS = {
-        "zhipuai": {"url": "",  # the SDK uses its default URL
+# model info: model name, max tokens, llm name
+class AI_models:
+    _MODEL = {
+        "glm-4-flash": {"llm": "zhipuai",
             "model": "glm-4-flash",
-            "key": "zhipuai-key",
             "max_tokens": 8000,
             },
-        "kimi": {"url": "https://api-sg.moonshot.ai/v1",  # uses the openAI SDK
+        "glm-4-5020": {"llm": "zhipuai",
+            "model": "glm-4-5020",
+            "max_tokens": 8000,
+            },
+        "kimi8k": {"llm": "kimi",  # uses the openAI SDK
             "model": "moonshot-v1-8k",
-            "key": "sk-kimi-key",
             "max_tokens": 8000,
             },
-        "ark": {"url": "https://ark.cn-beijing.volces.com/api/v3",  # Volcano Ark LLMs; Douyin, Doubao, and Coze belong to the same company
-            "model": "ep-20240929221043-jsbgc",
-            "key": "ark-key",
+        "doubao4": {"llm": "ark",  # Volcano Ark LLMs; Douyin, Doubao, and Coze belong to the same company
+            "model": "ep-20241005223718-nl742",  # Doubao-pro-4k
             "max_tokens": 4000,
             },
-        "claude": {"url": "https://api.gptapi.us/v1/chat/completions",  # Legend's testing bed
+        "doubao32": {"llm": "ark",
+            "model": "ep-20240929221043-jsbgc",  # Doubao-pro-32k
+            "max_tokens": 4000,
+            },
+        "claude35": {"llm": "claude",  # Legend's testing bed
             "model": "claude-3-5-sonnet",
-            "key": "sk-claud-key",
             "max_tokens": 8000,
             },
     }
+
+    _working_model = {
+        "llm": "",
+        "model": "",
+        "max_tokens": 0,
+    }
+
+    # nickname is either 1) a key in _MODEL, or 2) "llm.model.9999" for a dedicated
+    # llm/model pair; a malformed nickname raises at runtime
+    def __init__(self, model_nickname: str):
+        if model_nickname in self._MODEL:
+            self._working_model = self._MODEL[model_nickname]
+        else:
+            llm, model, n = model_nickname.split('.')
+            self._working_model = {
+                "llm": llm,
+                "model": model,
+                "max_tokens": int(n),
+            }
+
+    # expose "llm", "model" and "max_tokens" as attributes of the working model
+    def __getattr__(self, attr: str):
+        return self._working_model[attr]
+
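For reference, a minimal sketch (not part of the commit) of how the new AI_models class resolves a nickname; the dotted nickname and its model name below are made up for illustration:

m = AI_models("kimi8k")                          # nickname defined in _MODEL
print(m.llm, m.model, m.max_tokens)              # -> kimi moonshot-v1-8k 8000

adhoc = AI_models("zhipuai.glm-4-plus.4000")     # hypothetical "llm.model.max_tokens" form
print(adhoc.llm, adhoc.model, adhoc.max_tokens)  # -> zhipuai glm-4-plus 4000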
+# maintain the llm url, fetch the api key, and issue chat requests
+class LLM_ai:
+
+    _LLM = {
+        "zhipuai": {"url": "",  # the SDK uses its default URL
+            },
+        "kimi": {"url": "https://api-sg.moonshot.ai/v1",  # uses the openAI SDK
+            },
+        "ark": {"url": "https://ark.cn-beijing.volces.com/api/v3",  # Volcano Ark LLMs; Douyin, Doubao, and Coze belong to the same company
+            },
+        "claude": {"url": "https://api.gptapi.us/v1/chat/completions",  # Legend's testing bed
+            },
+    }
+
+    # the working selection
     _llm = None
+    _model = None
+    _max_tokens = None
     _client = None
 
 
-    def __init__(self, llm: str, model=""):
+    def __init__(self, llm: str, model="", max_tokens=2000):
         self._llm = llm
-        self._MODELS[llm]["key"] = st.secrets.ai_keys[llm]
+        ai_key = st.secrets.ai_keys[llm]
 
         if llm == "zhipuai":
-            self._client = ZhipuAI(api_key=self._MODELS[llm]["key"])
+            self._client = ZhipuAI(api_key=ai_key)  # the zhipuai SDK uses its default URL
         elif llm == "kimi":
-            self._client = openai.OpenAI(api_key=self._MODELS[llm]["key"],
-                                         base_url=self._MODELS[llm]["url"])
+            self._client = openai.OpenAI(api_key=ai_key, base_url=self._LLM[llm]["url"])
         elif llm == "ark":
-            self._client = Ark(api_key=self._MODELS[llm]["key"],
-                               base_url=self._MODELS[llm]["url"])
+            self._client = Ark(api_key=ai_key, base_url=self._LLM[llm]["url"])
         elif llm == "claude":
-            self._client = openai.OpenAI(api_key=self._MODELS[llm]["key"],
-                                         base_url=self._MODELS[llm]["url"])
+            self._client = openai.OpenAI(api_key=ai_key, base_url=self._LLM[llm]["url"])
         else:
             raise ValueError(f"Invalid llm: {llm}")
 
         if not model == "":
-            self._MODELS[llm]["model"] = model
+            self._model = model
+        self._max_tokens = max_tokens
 
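The constructor now reads the api key from Streamlit secrets instead of storing it in the model table. A minimal sketch of what st.secrets.ai_keys[llm] implies about the secrets file; the layout is inferred from that call, and the values are placeholders:

# .streamlit/secrets.toml is assumed to contain an [ai_keys] table whose keys
# match the llm names, e.g.:
#     [ai_keys]
#     zhipuai = "..."
#     kimi = "..."
#     ark = "..."
#     claude = "..."
import streamlit as st

key = st.secrets.ai_keys["kimi"]  # -> the raw api key string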
     # use streaming mode if pass_chunk is not None
     def chat(self, prompt: str, t: str, pass_chunk=None):
         if pass_chunk is None:
             response = self._client.chat.completions.create(
-                model=self._MODELS[self._llm]["model"],
+                model=self._model,
                 messages=[
                     {"role": "system", "content": prompt},
                     {"role": "user", "content": t}
                 ],
                 temperature=0.3,
+                max_tokens=self._max_tokens,
             )
             return response.choices[0].message.content
         else:
             response = self._client.chat.completions.create(
-                model=self._MODELS[self._llm]["model"],
-                messages=[
-                    {"role": "system", "content": prompt},
-                    {"role": "user", "content": t},
-                ],
-                temperature=0.3,
-                stream=True,
-                max_tokens=self._MODELS[self._llm]["max_tokens"],
+                model=self._model,
+                messages=[
+                    {"role": "system", "content": prompt},
+                    {"role": "user", "content": t},
+                ],
+                temperature=0.3,
+                stream=True,
+                max_tokens=self._max_tokens,
             )
 
             answer = ""
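Taken together, a minimal usage sketch (not from the commit; the prompts are illustrative, and the callback assumes the streaming branch, truncated above at answer = "", forwards each streamed delta to pass_chunk as its comment suggests):

m = AI_models("doubao32")
ai = LLM_ai(m.llm, model=m.model, max_tokens=m.max_tokens)

# blocking call: returns the whole completion at once
print(ai.chat("You are a concise assistant.", "Say hello in one sentence."))

# streaming call: pass_chunk is presumably invoked once per streamed chunk
ai.chat("You are a concise assistant.", "Say hello again.",
        pass_chunk=lambda chunk: print(chunk, end=""))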