Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.

Commit 4ab3642

Browse files
committed
功能: 实现外部API服务和设置对话框更新
- 在 `PredictOption` 结构体中添加 `Stop` 字段
- 添加一个新文件 `external_api_service.go`,其中包含 `NLP` 接口的实现
- 添加一个名为 `isExternalApi` 的函数,用于检查文件是否为外部API
- 修改 `modelLoad`,如果文件是外部API,则使用 `ExternalApiService`
- 在 `index.html` 中的设置对话框中添加一个 `external_api` 字段
1 parent da259f1 commit 4ab3642

File tree

5 files changed

+198
-19
lines changed

5 files changed

+198
-19
lines changed

common.go

+1
Original file line numberDiff line numberDiff line change
@@ -109,5 +109,6 @@ type PredictOption struct {
109109
Tokens int
110110
MaxTokens int
111111
Threads int
112+
Stop []string
112113
StreamFn func(outputText string) (stop bool)
113114
}

external_api_service.go

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"strings"
7+
8+
"github.com/parnurzeal/gorequest"
9+
"github.com/tidwall/gjson"
10+
"github.com/wailovet/nuwa"
11+
)
12+
13+
// ExternalApiService implements the NLP interface by delegating text
// generation to an external llama.cpp-compatible HTTP server instead
// of running a local model.
type ExternalApiService struct {
	// Url is the base address of the external API server
	// (e.g. "http://host:port/"); assigned by StartUp.
	Url string
}
16+
17+
// Free implements NLP. The external server owns its own resources, so
// there is nothing to release on this side.
func (e *ExternalApiService) Free() {

}
21+
22+
// IsReady implements NLP. The external API is assumed to be available;
// no health check is performed, so this always reports true.
func (e *ExternalApiService) IsReady() bool {
	return true
}
26+
27+
// ModelFile implements NLP. For an external API the "model file" is the
// server URL rather than a local file path.
func (e *ExternalApiService) ModelFile() string {
	return e.Url
}
31+
32+
// Predict implements NLP
33+
func (e *ExternalApiService) Predict(p Prompts, his ChatHistory, opts *PredictOption) (string, error) {
34+
text := fmt.Sprintln(p.Instruct)
35+
36+
if his == nil || len(his) == 0 {
37+
38+
} else {
39+
if his[len(his)-1].Role == "assistant" {
40+
if len(his)-1 > 0 {
41+
his = his[:len(his)-1]
42+
} else {
43+
return "", fmt.Errorf("history is nil")
44+
}
45+
}
46+
47+
for _, v := range his {
48+
if v.Role == "assistant" {
49+
text += fmt.Sprintln(p.AssistantPrefix, v.Content)
50+
} else {
51+
text += fmt.Sprintln(p.UserPrefix, v.Content)
52+
}
53+
}
54+
text += p.AssistantPrefix
55+
}
56+
57+
log.Println(text)
58+
59+
req := gorequest.New()
60+
61+
if !strings.HasSuffix(e.Url, "/") {
62+
e.Url += "/"
63+
}
64+
sendUrl := e.Url + "completion"
65+
66+
postJson := map[string]interface{}{
67+
"prompt": text,
68+
"batch_size": 64,
69+
"as_loop": true,
70+
"n_keep": -1,
71+
"interactive": true,
72+
"stop": []string{"\n### Human:"},
73+
}
74+
if opts != nil {
75+
if opts.BatchSize > 0 {
76+
postJson["batch_size"] = opts.BatchSize
77+
}
78+
if opts.Penalty > 0 {
79+
postJson["penalty"] = opts.Penalty
80+
}
81+
if opts.Temperature > 0 {
82+
postJson["temperature"] = opts.Temperature
83+
}
84+
if opts.TopP > 0 {
85+
postJson["top_p"] = opts.TopP
86+
}
87+
if opts.TopK > 0 {
88+
postJson["top_k"] = opts.TopK
89+
}
90+
if opts.Tokens > 0 {
91+
postJson["n_predict"] = opts.Tokens
92+
}
93+
if opts.Threads > 0 {
94+
postJson["threads"] = opts.Threads
95+
}
96+
if opts.Stop != nil {
97+
postJson["stop"] = opts.Stop
98+
}
99+
} else {
100+
return "", fmt.Errorf("opts is nil")
101+
}
102+
103+
req.Post(sendUrl).Send(nuwa.Helper().JsonEncode(postJson)).End()
104+
105+
nextTokenUrl := e.Url + "next-token"
106+
107+
message := ""
108+
isStop := false
109+
for {
110+
var errs []error
111+
var result string
112+
if isStop {
113+
_, result, errs = req.Get(nextTokenUrl + "?stop=true").End()
114+
} else {
115+
116+
_, result, errs = req.Get(nextTokenUrl).End()
117+
}
118+
if errs != nil {
119+
return "", errs[0]
120+
}
121+
// message += result.data.content;
122+
// if (result.data.stop) {
123+
// console.log("Completed");
124+
// // make sure to add the completion to the prompt.
125+
// prompt += `### Assistant: ${message}`;
126+
// break;
127+
// }
128+
message += gjson.Get(result, "content").String()
129+
if gjson.Get(result, "stop").Bool() {
130+
break
131+
}
132+
133+
log.Println("web.result:", result)
134+
135+
if !isStop {
136+
isStop = opts.StreamFn(message)
137+
}
138+
}
139+
140+
return message, nil
141+
}
142+
143+
// StartUp implements NLP. modelfile carries the external server's base
// URL (callers route here via isExternalApi); the value is stored as-is
// and no connection is attempted, so this never fails.
func (e *ExternalApiService) StartUp(modelfile string) error {
	e.Url = modelfile
	return nil
}

main.go

+34-15
Original file line numberDiff line numberDiff line change
@@ -39,28 +39,41 @@ var wsConnMapLock sync.RWMutex //ws连接池锁
3939

4040
var adminHome, _ = os.UserHomeDir() //用户目录
4141

42+
// isExternalApi reports whether filename is actually the URL of an
// external API server (http/https) rather than a local model file path.
func isExternalApi(filename string) bool {
	// Return the boolean directly instead of if/return true/return false.
	return strings.HasPrefix(filename, "http://") ||
		strings.HasPrefix(filename, "https://")
}
48+
4249
func modelLoad(filename string) error {
4350
if lm != nil {
4451
lm.Free()
4552
}
4653

47-
_modelType := modelType(filename)
54+
if isExternalApi(filename) {
55+
lm = &ExternalApiService{}
56+
lm.StartUp(filename)
57+
} else {
58+
_modelType := modelType(filename)
4859

49-
log.Println("modelType:", _modelType)
60+
log.Println("modelType:", _modelType)
5061

51-
switch _modelType {
52-
case "llama":
53-
lm = &Llama{}
54-
case "rwkv":
55-
lm = &RWKV{}
56-
default:
57-
lm = &Llama{}
58-
}
62+
switch _modelType {
63+
case "llama":
64+
lm = &Llama{}
65+
case "rwkv":
66+
lm = &RWKV{}
67+
default:
68+
lm = &Llama{}
69+
}
5970

60-
err := lm.StartUp(filename)
61-
if err != nil {
62-
return err
71+
err := lm.StartUp(filename)
72+
if err != nil {
73+
return err
74+
}
6375
}
76+
6477
return nil
6578
}
6679

@@ -88,11 +101,16 @@ func serviceStartUp() {
88101

89102
nuwa.Http().HandleFunc("/model/reload", func(ctx nuwa.HttpContext) {
90103
basePath := ctx.REQUEST["base_path"]
104+
// external_api
105+
filename := ctx.REQUEST["external_api"]
106+
91107
if basePath == "" {
92108
basePath, _ = nuwa.Helper().GetCurrentPath()
93109
}
94-
filename := ctx.ParamRequired("filename")
95-
filename = filepath.Join(basePath, filename)
110+
if filename == "" {
111+
filename = ctx.REQUEST["filename"]
112+
filename = filepath.Join(basePath, filename)
113+
}
96114
err := modelLoad(filename)
97115
ctx.CheckErrDisplayByError(err)
98116
ctx.DisplayByData(filename)
@@ -254,6 +272,7 @@ func serviceStartUp() {
254272
TopP: topP,
255273
Tokens: tokens,
256274
Threads: threads,
275+
Stop: []string{stop_words},
257276
StreamFn: func(outputText string) (stop bool) {
258277
if strings.HasSuffix(outputText, stop_words) {
259278
return true

src/index.html

+10-4
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,16 @@
504504
</div>
505505

506506
<el-dialog title="设置" :visible="active_screen=='setting'" :show-close="false" center width="800px">
507-
<div style="height: 400px;width: 740px;">
507+
<div style="height: 500px;width: 740px;">
508508
<el-form ref="form" label-width="100px" :inline="true">
509+
<el-form-item label="外部api地址">
510+
<el-input size="mini" v-model="external_api" placeholder="llama.cpp服务端外部api地址"
511+
style="width: 600px;"></el-input>
512+
</el-form-item>
513+
</el-form>
514+
515+
516+
<el-form ref="form" label-width="100px" :inline="true" v-if="external_api==''">
509517
<el-form-item label="模型文件夹">
510518
<el-input size="mini" v-model="nlp_model_base_path" placeholder="留空为当前执行程序所在文件夹"
511519
style="width: 200px;"></el-input>
@@ -517,7 +525,6 @@
517525
</el-option>
518526
</el-select>
519527
</el-form-item>
520-
521528
</el-form>
522529

523530

@@ -565,8 +572,7 @@
565572
<el-form-item label="batch">
566573
<el-input-number size="mini" v-model="predict_config.batch" :step="1">
567574
</el-input-number>
568-
<el-tooltip class="item" effect="dark" content="CUDA必须要大于32才会生效"
569-
placement="top-start">
575+
<el-tooltip class="item" effect="dark" content="CUDA必须要大于32才会生效" placement="top-start">
570576
<svg t="1678602279133" class="icon" viewBox="0 0 1024 1024" version="1.1"
571577
xmlns="http://www.w3.org/2000/svg" p-id="4047" width="12" height="12">
572578
<path

src/index.js

+6
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ function initVue() {
8787
session_id: "",
8888
load_lock: false,
8989
nlp_model: "",
90+
external_api: "",
9091
nlp_model_base_path: "",
9192
nlp_model_list: [],
9293
history: [],
@@ -138,6 +139,9 @@ function initVue() {
138139
nlp_model: function (val) {
139140
storageSetValue("config.nlp_model", val);
140141
},
142+
external_api: function (val) {
143+
storageSetValue("config.external_api", val);
144+
},
141145
predict_config: {
142146
handler: function (val) {
143147
storageSetValue("config.predict_config", JSON.stringify(val));
@@ -324,6 +328,7 @@ async function loadPage() {
324328
async function loadConfig() {
325329
app.nlp_model_base_path = await storageGetValue("config.nlp_model_base_path");
326330
app.nlp_model = await storageGetValue("config.nlp_model");
331+
app.external_api = await storageGetValue("config.external_api");
327332
let predict_config = await storageGetValue("config.predict_config");
328333
if (predict_config) {
329334
app.predict_config = JSON.parse(predict_config);
@@ -333,6 +338,7 @@ async function loadConfig() {
333338
async function loadModel() {
334339
app.load_lock = true;
335340
let res = await axios.post(`${host}/model/reload`, {
341+
external_api: app.external_api,
336342
base_path: app.nlp_model_base_path,
337343
filename: app.nlp_model,
338344
});

0 commit comments

Comments
 (0)