Skip to content

Commit 6a68a2c

Browse files
committed
fix: set max tokens
1 parent bf14762 commit 6a68a2c

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

main.go

+24-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/*
22
* @Author: Vincent Yang
33
* @Date: 2024-04-16 22:58:22
4-
* @LastEditors: Vincent Yang
5-
* @LastEditTime: 2024-04-18 20:13:27
4+
* @LastEditors: Vincent Young
5+
* @LastEditTime: 2024-04-19 03:45:05
66
* @FilePath: /cohere2openai/main.go
77
* @Telegram: https://t.me/missuo
88
* @GitHub: https://github.com/missuo
@@ -45,6 +45,7 @@ func cohereRequest(c *gin.Context, openAIReq OpenAIRequest) {
4545
ChatHistory: []ChatMessage{},
4646
Message: "",
4747
Stream: openAIReq.Stream,
48+
MaxTokens: openAIReq.MaxTokens,
4849
}
4950

5051
for _, msg := range openAIReq.Messages {
@@ -67,6 +68,7 @@ func cohereRequest(c *gin.Context, openAIReq OpenAIRequest) {
6768
}
6869

6970
reqBody, _ := json.Marshal(cohereReq)
71+
fmt.Println(string(reqBody))
7072
req, err := http.NewRequest("POST", "https://api.cohere.ai/v1/chat", bytes.NewBuffer(reqBody))
7173
if err != nil {
7274
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -180,6 +182,7 @@ func cohereNonStreamRequest(c *gin.Context, openAIReq OpenAIRequest) {
180182
ChatHistory: []ChatMessage{},
181183
Message: "",
182184
Stream: openAIReq.Stream,
185+
MaxTokens: openAIReq.MaxTokens,
183186
}
184187

185188
for _, msg := range openAIReq.Messages {
@@ -261,6 +264,25 @@ func handler(c *gin.Context) {
261264
if !isInSlice(openAIReq.Model, allowModels) {
262265
openAIReq.Model = "command-r-plus"
263266
}
267+
268+
// Set max tokens based on model
269+
switch openAIReq.Model {
270+
case "command-light":
271+
openAIReq.MaxTokens = 4000
272+
case "command":
273+
openAIReq.MaxTokens = 4000
274+
case "command-light-nightly":
275+
openAIReq.MaxTokens = 4000
276+
case "command-nightly":
277+
openAIReq.MaxTokens = 4000
278+
case "command-r":
279+
openAIReq.MaxTokens = 4000
280+
case "command-r-plus":
281+
openAIReq.MaxTokens = 4000
282+
default:
283+
openAIReq.MaxTokens = 4096
284+
}
285+
264286
if openAIReq.Stream {
265287
cohereRequest(c, openAIReq)
266288
} else {

types.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/*
22
* @Author: Vincent Yang
33
* @Date: 2024-04-16 22:58:27
4-
* @LastEditors: Vincent Yang
5-
* @LastEditTime: 2024-04-18 04:34:55
4+
* @LastEditors: Vincent Young
5+
* @LastEditTime: 2024-04-19 03:34:12
66
* @FilePath: /cohere2openai/types.go
77
* @Telegram: https://t.me/missuo
88
* @GitHub: https://github.com/missuo
@@ -18,14 +18,16 @@ type OpenAIRequest struct {
1818
Role string `json:"role"`
1919
Content string `json:"content"`
2020
} `json:"messages"`
21-
Stream bool `json:"stream"`
21+
Stream bool `json:"stream"`
22+
MaxTokens int64 `json:"max_tokens"`
2223
}
2324

2425
type CohereRequest struct {
2526
Model string `json:"model"`
2627
ChatHistory []ChatMessage `json:"chat_history"`
2728
Message string `json:"message"`
2829
Stream bool `json:"stream"`
30+
MaxTokens int64 `json:"max_tokens"`
2931
}
3032

3133
type ChatMessage struct {

0 commit comments

Comments
 (0)