Skip to content
This repository was archived by the owner on Nov 6, 2024. It is now read-only.

Commit

Permalink
确保模型input和output一致
Browse files Browse the repository at this point in the history
  • Loading branch information
Nekohy committed Nov 2, 2024
1 parent a64cb54 commit 15fb603
Showing 1 changed file with 16 additions and 23 deletions.
39 changes: 16 additions & 23 deletions api/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,17 @@ router.get(config.API_PREFIX + '/v1/models', () =>
);
router.post(config.API_PREFIX + '/v1/chat/completions', (req) => handleCompletion(req));

async function GrpcToPieces(models, message, rules, stream, temperature, top_p) {
async function GrpcToPieces(inputModel,OriginModel,message, rules, stream, temperature, top_p) {
// 在非GPT类型的模型中,temperature和top_p是无效的
// 使用系统的根证书
const credentials = grpc.credentials.createSsl();
let client,request;
if (models.includes('gpt')){
if (inputModel.includes('gpt')){
// 加载proto文件
const packageDefinition = new GRPCHandler(config.GPT_PROTO).packageDefinition;
// 构建请求消息
request = {
models: models,
models: inputModel,
messages: [
{role: 0, message: rules}, // system
{role: 1, message: message} // user
Expand All @@ -119,7 +119,7 @@ async function GrpcToPieces(models, message, rules, stream, temperature, top_p)
const packageDefinition = new GRPCHandler(config.COMMON_PROTO).packageDefinition;
// 构建请求消息
request = {
models: models,
models: inputModel,
args: {
messages: {
unknown: 1,
Expand All @@ -132,7 +132,7 @@ async function GrpcToPieces(models, message, rules, stream, temperature, top_p)
const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.vertex;
client = new GRPCobjects.VertexInferenceService(config.COMMON_GRPC, credentials);
}
return await ConvertOpenai(client,request,models,stream);
return await ConvertOpenai(client,request,inputModel,OriginModel,stream);
}

async function messagesProcess(messages) {
Expand All @@ -159,7 +159,7 @@ async function messagesProcess(messages) {
return { rules, message };
}

async function ConvertOpenai(client,request,model,stream) {
async function ConvertOpenai(client,request,inputModel,OriginModel,stream) {
for (let i = 0; i < config.MAX_RETRY_COUNT; i++) {
try {
if (stream) {
Expand All @@ -175,13 +175,14 @@ async function ConvertOpenai(client,request,model,stream) {
call.destroy()
} else if (response_code === 200) {
let response_message
if (model.includes('gpt')) {
if (inputModel.includes('gpt')) {
response_message = response.body.message_warpper.message.message;
} else {
response_message = response.args.args.args.message;
}
// 否则,将数据块加入流中
controller.enqueue(encoder.encode(`data: ${JSON.stringify(ChatCompletionStreamWithModel(response_message, model))}\n\n`));

controller.enqueue(encoder.encode(`data: ${JSON.stringify(ChatCompletionStreamWithModel(response_message, OriginModel))}\n\n`));
} else {
controller.error(new Error(`Error: stream chunk is not success`));
controller.close()
Expand All @@ -204,12 +205,12 @@ async function ConvertOpenai(client,request,model,stream) {
let response_code = Number(call.response_code);
if (response_code === 200) {
let response_message
if (model.includes('gpt')) {
if (inputModel.includes('gpt')) {
response_message = call.body.message_warpper.message.message;
} else {
response_message = call.args.args.args.message;
}
return new Response(JSON.stringify(ChatCompletionWithModel(response_message, model)), {
return new Response(JSON.stringify(ChatCompletionWithModel(response_message, OriginModel)), {
headers: {
'Content-Type': 'application/json',
},
Expand All @@ -224,16 +225,6 @@ async function ConvertOpenai(client,request,model,stream) {
return error(500, err.message);
}

function renameIfNeeded(input) {
// 替换的正则表达式
const regex = /^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$/;
const match = input.match(regex);
if (match) {
return `${match[1]}@${match[3]}`;
}
return input;
}

function ChatCompletionWithModel(message, model) {
return {
id: 'Chat-Nekohy',
Expand Down Expand Up @@ -279,14 +270,16 @@ async function handleCompletion(request) {
try {
// todo stream逆向接口
// 解析openai格式API请求
const { model: inputModel, messages, stream,temperature,top_p} = await request.json();
const { model: OriginModel, messages, stream,temperature,top_p} = await request.json();
const RegexInput = /^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$/;
const matchInput = OriginModel.match(RegexInput);
const inputModel = matchInput ? `${matchInput[1]}@${matchInput[3]}` : OriginModel;
console.log(inputModel,messages,stream)
const model = renameIfNeeded(inputModel);
// 解析system和user/assistant消息
const { rules, message:content } = await messagesProcess(messages);
console.log(rules,content)
// 响应码,回复的消息
return await GrpcToPieces(model, content, rules, stream, temperature, top_p);
return await GrpcToPieces(inputModel,OriginModel,content, rules, stream, temperature, top_p);
} catch (err) {
return error(500, err.message);
}
Expand Down

0 comments on commit 15fb603

Please sign in to comment.