
Commit 72dbe85
Add support for Volcengine online LLMs (#2165)

* use oai adaptive bridge function to handle vol engine
* add vol engine deepseek v3

Co-authored-by: binary-husky <[email protected]>
1 parent 4a79aa6 commit 72dbe85

4 files changed: +143 -48 lines changed

config.py

Lines changed: 6 additions & 1 deletion
@@ -43,7 +43,8 @@
     "gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
     "gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
     "gemini-1.5-pro", "chatglm3", "chatglm4",
-    "deepseek-chat", "deepseek-coder", "deepseek-reasoner"
+    "deepseek-chat", "deepseek-coder", "deepseek-reasoner",
+    "volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226",
 ]

 EMBEDDING_MODEL = "text-embedding-3-small"
@@ -267,6 +268,10 @@
 YIMODEL_API_KEY = ""


+# Volcengine online LLMs; get an api-key at https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
+ARK_API_KEY = "00000000-0000-0000-0000-000000000000" # Volcengine API KEY
+
+
 # Zidong Taichu LLM, https://ai-maas.wair.ac.cn
 TAICHU_API_KEY = ""
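Read together, the two hunks suggest the following sketch of a user configuration for the new provider. The key is the placeholder from the diff, and the "(max_token=...)" suffix is the optional override documented in bridge_all.py below; both lines are illustrative, not prescribed defaults.

```python
# Sketch of a config.py setup enabling the new Volcengine models (values illustrative).
AVAIL_LLM_MODELS = [
    "volcengine-deepseek-r1-250120",
    "volcengine-deepseek-v3-241226(max_token=6666)",  # optional per-model max_token override
]
ARK_API_KEY = "00000000-0000-0000-0000-000000000000"  # placeholder; substitute a real Ark key
```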

request_llms/bridge_all.py

Lines changed: 60 additions & 2 deletions
@@ -80,6 +80,7 @@ def decode(self, *args, **kwargs):
 yimodel_endpoint = "https://api.lingyiwanwu.com/v1/chat/completions"
 deepseekapi_endpoint = "https://api.deepseek.com/v1/chat/completions"
 grok_model_endpoint = "https://api.x.ai/v1/chat/completions"
+volcengine_endpoint = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"

 if not AZURE_ENDPOINT.endswith('/'): AZURE_ENDPOINT += '/'
 azure_endpoint = AZURE_ENDPOINT + f'openai/deployments/{AZURE_ENGINE}/chat/completions?api-version=2023-05-15'
@@ -102,6 +103,7 @@ def decode(self, *args, **kwargs):
 if yimodel_endpoint in API_URL_REDIRECT: yimodel_endpoint = API_URL_REDIRECT[yimodel_endpoint]
 if deepseekapi_endpoint in API_URL_REDIRECT: deepseekapi_endpoint = API_URL_REDIRECT[deepseekapi_endpoint]
 if grok_model_endpoint in API_URL_REDIRECT: grok_model_endpoint = API_URL_REDIRECT[grok_model_endpoint]
+if volcengine_endpoint in API_URL_REDIRECT: volcengine_endpoint = API_URL_REDIRECT[volcengine_endpoint]

 # fetch the tokenizer
 tokenizer_gpt35 = LazyloadTiktoken("gpt-3.5-turbo")
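Because the new endpoint goes through the same API_URL_REDIRECT hook as the others, it can be rerouted from config.py like any existing endpoint. A minimal sketch, assuming a self-hosted gateway; the gateway URL is made up:

```python
# Hypothetical config.py excerpt: reroute the Volcengine endpoint through a gateway.
# The dict key must match the default endpoint URL exactly.
API_URL_REDIRECT = {
    "https://ark.cn-beijing.volces.com/api/v3/chat/completions":
        "https://my-gateway.example.com/ark/v3/chat/completions",  # made-up gateway URL
}
```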
@@ -954,7 +956,7 @@ def decode(self, *args, **kwargs):
 try:
     grok_beta_128k_noui, grok_beta_128k_ui = get_predict_function(
         api_key_conf_name="GROK_API_KEY", max_output_token=8192, disable_proxy=False
-        )
+    )

     model_info.update({
         "grok-beta": {
@@ -1089,8 +1091,10 @@ def decode(self, *args, **kwargs):
     })
 except:
     logger.error(trimmed_format_exc())
+
 # -=-=-=-=-=-=- HighFlyer / DeepSeek online LLM API -=-=-=-=-=-=-
-if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS or "deepseek-reasoner" in AVAIL_LLM_MODELS:
+deepseek_models = ["deepseek-chat", "deepseek-coder", "deepseek-reasoner"]
+if any(item in deepseek_models for item in AVAIL_LLM_MODELS):
     try:
         deepseekapi_noui, deepseekapi_ui = get_predict_function(
             api_key_conf_name="DEEPSEEK_API_KEY", max_output_token=4096, disable_proxy=False
@@ -1127,6 +1131,60 @@ def decode(self, *args, **kwargs):
     })
 except:
     logger.error(trimmed_format_exc())
+
+# -=-=-=-=-=-=- Volcengine alignment support -=-=-=-=-=-=-
+for model in [m for m in AVAIL_LLM_MODELS if m.startswith("volcengine-")]:
+    # This interface was designed for more flexible access to the Volcengine multi-model
+    # management console, e.g. AVAIL_LLM_MODELS = ["volcengine-deepseek-r1-250120(max_token=6666)"], where
+    #   "volcengine-"        is the prefix (required)
+    #   "deepseek-r1-250120" is the model name (required)
+    #   "(max_token=6666)"   is the configuration (optional)
+    model_info_extend = dict(model_info)  # work on a copy, so the unprefixed names below do not leak into model_info
+    model_info_extend.update({
+        "deepseek-r1-250120": {
+            "max_token": 16384,
+            "enable_reasoning": True,
+            "can_multi_thread": True,
+            "endpoint": volcengine_endpoint,
+            "tokenizer": tokenizer_gpt35,
+            "token_cnt": get_token_num_gpt35,
+        },
+        "deepseek-v3-241226": {
+            "max_token": 16384,
+            "enable_reasoning": False,
+            "can_multi_thread": True,
+            "endpoint": volcengine_endpoint,
+            "tokenizer": tokenizer_gpt35,
+            "token_cnt": get_token_num_gpt35,
+        },
+    })
+    try:
+        origin_model_name, max_token_tmp = read_one_api_model_name(model)
+        # if this is a known model, try to fetch its info
+        original_model_info = model_info_extend.get(origin_model_name.replace("volcengine-", "", 1), None)
+    except:
+        logger.error(f"The max_token configuration of volcengine model {model} is not an integer; please check the config file.")
+        continue
+
+    volcengine_noui, volcengine_ui = get_predict_function(api_key_conf_name="ARK_API_KEY", max_output_token=8192, disable_proxy=True, model_remove_prefix=["volcengine-"])
+
+    this_model_info = {
+        "fn_with_ui": volcengine_ui,
+        "fn_without_ui": volcengine_noui,
+        "endpoint": volcengine_endpoint,
+        "can_multi_thread": True,
+        "max_token": 64000,
+        "tokenizer": tokenizer_gpt35,
+        "token_cnt": get_token_num_gpt35,
+    }
+
+    # carry over other known attributes of the model
+    for attribute in ("has_multimodal_capacity", "enable_reasoning"):
+        if original_model_info is not None and original_model_info.get(attribute, None) is not None:
+            this_model_info.update({attribute: original_model_info.get(attribute, None)})
+    model_info.update({model: this_model_info})
+
 # -=-=-=-=-=-=- one-api alignment support -=-=-=-=-=-=-
 for model in [m for m in AVAIL_LLM_MODELS if m.startswith("one-api-")]:
     # This interface was designed for more flexible access to the one-api multi-model management console, e.g. AVAIL_LLM_MODELS = ["one-api-mixtral-8x7b(max_token=6666)"]
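read_one_api_model_name is the same suffix parser the one-api path uses: it splits the optional "(max_token=...)" decoration off the model string, and it raises when the number is malformed, which the except/continue above absorbs. A hypothetical stand-in showing that contract (not the repo's implementation):

```python
import re

def parse_model_spec(model: str, default_max_token: int = 4096) -> tuple[str, int]:
    """Hypothetical stand-in for read_one_api_model_name (not the repo's code).

    "volcengine-deepseek-r1-250120(max_token=6666)" -> ("volcengine-deepseek-r1-250120", 6666)
    "volcengine-deepseek-r1-250120"                 -> ("volcengine-deepseek-r1-250120", 4096)
    """
    match = re.match(r"^(.+?)\(max_token=(.+)\)$", model)
    if match is None:
        return model, default_max_token
    # int() raises on a malformed suffix, mirroring the except branch above
    return match.group(1), int(match.group(2))

assert parse_model_spec("volcengine-deepseek-r1-250120(max_token=6666)") == \
    ("volcengine-deepseek-r1-250120", 6666)
```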

request_llms/oai_std_model_template.py

Lines changed: 17 additions & 4 deletions
@@ -57,7 +57,7 @@ def decode_chunk(chunk):
             finish_reason = chunk["error"]["code"]
         except:
             finish_reason = "API_ERROR"
-        return response, reasoning_content, finish_reason
+        return response, reasoning_content, finish_reason, str(chunk)

     try:
         if chunk["choices"][0]["delta"]["content"] is not None:
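decode_chunk now also returns the stringified chunk, so every call site has to unpack four values; the extra element is mainly useful for logging when finish_reason signals an error. A minimal caller sketch; the consume_chunk wrapper and the loguru import are assumptions for illustration, not code from this commit:

```python
from loguru import logger  # assumed: the diff uses a logger object elsewhere

def consume_chunk(chunk, decode_chunk):
    # decode_chunk is the patched helper above; it now returns a 4-tuple
    response, reasoning_content, finish_reason, decoded_chunk = decode_chunk(chunk)
    if finish_reason == "API_ERROR":
        # the new fourth element preserves the raw payload for diagnostics
        logger.error(f"Abnormal chunk received: {decoded_chunk}")
    return response, reasoning_content, finish_reason
```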
@@ -122,7 +122,8 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
 def get_predict_function(
     api_key_conf_name,
     max_output_token,
-    disable_proxy = False
+    disable_proxy = False,
+    model_remove_prefix = [],
 ):
     """
     Generate response functions for an OpenAI-style API, taking the parameters:
@@ -137,6 +138,16 @@

     APIKEY = get_conf(api_key_conf_name)

+    def remove_prefix(model_name):
+        # Strip the model-name prefix: "volcengine-deepseek-r1-250120" becomes "deepseek-r1-250120"
+        if not model_remove_prefix:
+            return model_name
+        model_without_prefix = model_name
+        for prefix in model_remove_prefix:
+            if model_without_prefix.startswith(prefix):
+                model_without_prefix = model_without_prefix[len(prefix):]
+        return model_without_prefix
+
     def predict_no_ui_long_connection(
         inputs,
         llm_kwargs,
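remove_prefix is a closure over model_remove_prefix and cannot be imported on its own, so here is a standalone, runnable copy of the same logic for illustration (the default prefix tuple is an assumption):

```python
# Standalone copy of the helper above; in the repo it is a closure inside get_predict_function.
def remove_prefix(model_name, model_remove_prefix=("volcengine-",)):
    if not model_remove_prefix:
        return model_name
    model_without_prefix = model_name
    for prefix in model_remove_prefix:
        if model_without_prefix.startswith(prefix):
            model_without_prefix = model_without_prefix[len(prefix):]
    return model_without_prefix

assert remove_prefix("volcengine-deepseek-r1-250120") == "deepseek-r1-250120"
assert remove_prefix("deepseek-chat") == "deepseek-chat"  # unprefixed names pass through unchanged
```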
@@ -164,9 +175,11 @@ def predict_no_ui_long_connection(
         raise RuntimeError(f"APIKEY is empty; please check {APIKEY} in the config file")
     if inputs == "":
         inputs = "你好👋"
+
+
     headers, payload = generate_message(
         input=inputs,
-        model=llm_kwargs["llm_model"],
+        model=remove_prefix(llm_kwargs["llm_model"]),
         key=APIKEY,
         history=history,
         max_output_token=max_output_token,
@@ -302,7 +315,7 @@ def predict(

     headers, payload = generate_message(
         input=inputs,
-        model=llm_kwargs["llm_model"],
+        model=remove_prefix(llm_kwargs["llm_model"]),
         key=APIKEY,
         history=history,
         max_output_token=max_output_token,
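The only caller that passes the new parameter in this commit is the Volcengine loop in bridge_all.py; a usage sketch mirroring that hunk:

```python
from request_llms.oai_std_model_template import get_predict_function

# Build the two predict functions for the Volcengine bridge; the prefix list makes
# the HTTP payload carry "deepseek-r1-250120" while the UI keeps the full name.
volcengine_noui, volcengine_ui = get_predict_function(
    api_key_conf_name="ARK_API_KEY",      # config entry added by this commit
    max_output_token=8192,
    disable_proxy=True,                   # the Ark endpoint is reached directly
    model_remove_prefix=["volcengine-"],  # stripped by remove_prefix before each request
)
```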

tests/test_llms.py

Lines changed: 60 additions & 41 deletions
@@ -11,46 +11,65 @@ def validate_path():


 validate_path()  # validate path so you can run from base directory
+if __name__ == "__main__":
+    # from request_llms.bridge_taichu import predict_no_ui_long_connection
+    from request_llms.bridge_volcengine import predict_no_ui_long_connection
+    # from request_llms.bridge_cohere import predict_no_ui_long_connection
+    # from request_llms.bridge_spark import predict_no_ui_long_connection
+    # from request_llms.bridge_zhipu import predict_no_ui_long_connection
+    # from request_llms.bridge_chatglm3 import predict_no_ui_long_connection
+    llm_kwargs = {
+        "llm_model": "volcengine",
+        "max_length": 4096,
+        "top_p": 1,
+        "temperature": 1,
+    }

-if "在线模型":
-    if __name__ == "__main__":
-        from request_llms.bridge_taichu import predict_no_ui_long_connection
-        # from request_llms.bridge_cohere import predict_no_ui_long_connection
-        # from request_llms.bridge_spark import predict_no_ui_long_connection
-        # from request_llms.bridge_zhipu import predict_no_ui_long_connection
-        # from request_llms.bridge_chatglm3 import predict_no_ui_long_connection
-        llm_kwargs = {
-            "llm_model": "taichu",
-            "max_length": 4096,
-            "top_p": 1,
-            "temperature": 1,
-        }
-
-        result = predict_no_ui_long_connection(
-            inputs="请问什么是质子?", llm_kwargs=llm_kwargs, history=["你好", "我好!"], sys_prompt="系统"
-        )
-        print("final result:", result)
-
-
-if "本地模型":
-    if __name__ == "__main__":
-        # from request_llms.bridge_newbingfree import predict_no_ui_long_connection
-        # from request_llms.bridge_moss import predict_no_ui_long_connection
-        # from request_llms.bridge_jittorllms_pangualpha import predict_no_ui_long_connection
-        # from request_llms.bridge_jittorllms_llama import predict_no_ui_long_connection
-        # from request_llms.bridge_claude import predict_no_ui_long_connection
-        # from request_llms.bridge_internlm import predict_no_ui_long_connection
-        # from request_llms.bridge_deepseekcoder import predict_no_ui_long_connection
-        # from request_llms.bridge_qwen_7B import predict_no_ui_long_connection
-        # from request_llms.bridge_qwen_local import predict_no_ui_long_connection
-        llm_kwargs = {
-            "max_length": 4096,
-            "top_p": 1,
-            "temperature": 1,
-        }
-        result = predict_no_ui_long_connection(
-            inputs="请问什么是质子?", llm_kwargs=llm_kwargs, history=["你好", "我好!"], sys_prompt=""
-        )
-        print("final result:", result)
+    result = predict_no_ui_long_connection(
+        inputs="请问什么是质子?", llm_kwargs=llm_kwargs, history=["你好", "我好!"], sys_prompt="系统"
+    )
+    print("final result:", result)
+
+# if "在线模型":
+#     if __name__ == "__main__":
+#         # from request_llms.bridge_taichu import predict_no_ui_long_connection
+#         from request_llms.bridge_volcengine import predict_no_ui_long_connection
+#         # from request_llms.bridge_cohere import predict_no_ui_long_connection
+#         # from request_llms.bridge_spark import predict_no_ui_long_connection
+#         # from request_llms.bridge_zhipu import predict_no_ui_long_connection
+#         # from request_llms.bridge_chatglm3 import predict_no_ui_long_connection
+#         llm_kwargs = {
+#             "llm_model": "ep-20250222011816-5cq8z",
+#             "max_length": 4096,
+#             "top_p": 1,
+#             "temperature": 1,
+#         }
+
+#         result = predict_no_ui_long_connection(
+#             inputs="请问什么是质子?", llm_kwargs=llm_kwargs, history=["你好", "我好!"], sys_prompt="系统"
+#         )
+#         print("final result:", result)


+# if "本地模型":
+#     if __name__ == "__main__":
+#         # from request_llms.bridge_newbingfree import predict_no_ui_long_connection
+#         # from request_llms.bridge_moss import predict_no_ui_long_connection
+#         # from request_llms.bridge_jittorllms_pangualpha import predict_no_ui_long_connection
+#         # from request_llms.bridge_jittorllms_llama import predict_no_ui_long_connection
+#         # from request_llms.bridge_claude import predict_no_ui_long_connection
+#         # from request_llms.bridge_internlm import predict_no_ui_long_connection
+#         # from request_llms.bridge_deepseekcoder import predict_no_ui_long_connection
+#         # from request_llms.bridge_qwen_7B import predict_no_ui_long_connection
+#         # from request_llms.bridge_qwen_local import predict_no_ui_long_connection
+#         llm_kwargs = {
+#             "max_length": 4096,
+#             "top_p": 1,
+#             "temperature": 1,
+#         }
+#         result = predict_no_ui_long_connection(
+#             inputs="请问什么是质子?", llm_kwargs=llm_kwargs, history=["你好", "我好!"], sys_prompt=""
+#         )
+#         print("final result:", result)
