Commit 53e5b02

refactor(ai): ai-proxy and ai-proxy-multi (#12030)
1 parent 2938f41 commit 53e5b02

16 files changed: +432 -563 lines

apisix/cli/config.lua

+2 -2
@@ -219,13 +219,13 @@ local _M = {
         "ai-prompt-decorator",
         "ai-prompt-guard",
         "ai-rag",
+        "ai-proxy-multi",
+        "ai-proxy",
         "ai-aws-content-moderation",
         "proxy-mirror",
         "proxy-rewrite",
         "workflow",
         "api-breaker",
-        "ai-proxy",
-        "ai-proxy-multi",
         "limit-conn",
         "limit-count",
         "limit-req",

apisix/plugins/ai-aws-content-moderation.lua

+1 -1
@@ -72,7 +72,7 @@ local schema = {
 
 local _M = {
     version = 0.1,
-    priority = 1040, -- TODO: might change
+    priority = 1050,
     name = "ai-aws-content-moderation",
     schema = schema,
 }

apisix/plugins/ai-drivers/openai-base.lua

+122 -6
@@ -20,12 +20,20 @@ local mt = {
     __index = _M
 }
 
+local CONTENT_TYPE_JSON = "application/json"
+
 local core = require("apisix.core")
 local http = require("resty.http")
 local url = require("socket.url")
+local schema = require("apisix.plugins.ai-drivers.schema")
+local ngx_re = require("ngx.re")
+
+local ngx_print = ngx.print
+local ngx_flush = ngx.flush
 
 local pairs = pairs
 local type = type
+local ipairs = ipairs
 local setmetatable = setmetatable
 
 
@@ -40,6 +48,26 @@ function _M.new(opts)
 end
 
 
+function _M.validate_request(ctx)
+    local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
+    if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
+        return nil, "unsupported content-type: " .. ct .. ", only application/json is supported"
+    end
+
+    local request_table, err = core.request.get_json_request_body_table()
+    if not request_table then
+        return nil, err
+    end
+
+    local ok, err = core.schema.check(schema.chat_request_schema, request_table)
+    if not ok then
+        return nil, "request format doesn't match schema: " .. err
+    end
+
+    return request_table, nil
+end
+
+
 function _M.request(self, conf, request_table, extra_opts)
     local httpc, err = http.new()
     if not httpc then
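
The new validate_request helper gives drivers a shared entry point for checking an incoming chat request. Below is a minimal sketch of how a consuming plugin such as ai-proxy might call it from its access phase; the caller-side names are assumptions for illustration and are not part of this commit:

    -- assumed caller-side code, for illustration only
    local ai_driver = require("apisix.plugins.ai-drivers.openai-base")

    function _M.access(conf, ctx)
        -- rejects non-JSON bodies and bodies that fail chat_request_schema
        local request_table, err = ai_driver.validate_request(ctx)
        if not request_table then
            return 400, err
        end
        ctx.ai_request_table = request_table
    end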
@@ -54,11 +82,11 @@ function _M.request(self, conf, request_table, extra_opts)
     end
 
     local ok, err = httpc:connect({
-        scheme = endpoint and parsed_url.scheme or "https",
-        host = endpoint and parsed_url.host or self.host,
-        port = endpoint and parsed_url.port or self.port,
+        scheme = parsed_url and parsed_url.scheme or "https",
+        host = parsed_url and parsed_url.host or self.host,
+        port = parsed_url and parsed_url.port or self.port,
         ssl_verify = conf.ssl_verify,
-        ssl_server_name = endpoint and parsed_url.host or self.host,
+        ssl_server_name = parsed_url and parsed_url.host or self.host,
         pool_size = conf.keepalive and conf.keepalive_pool,
     })
 
@@ -75,7 +103,7 @@ function _M.request(self, conf, request_table, extra_opts)
         end
     end
 
-    local path = (endpoint and parsed_url.path or self.path)
+    local path = (parsed_url and parsed_url.path or self.path)
 
     local headers = extra_opts.headers
     headers["Content-Type"] = "application/json"
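
Both hunks above switch the nil guard from endpoint to parsed_url. A hedged reading of the intent, with assumed surrounding context that these hunks do not show: parsed_url is derived from the configured override endpoint, so it is the value that may be nil and whose fields are indexed, and guarding on it directly avoids indexing a nil table. A minimal sketch of the assumed pattern:

    -- assumed surrounding context, not taken verbatim from the file
    local endpoint = extra_opts.endpoint
    local parsed_url = endpoint and url.parse(endpoint)

    -- guard on the parsed table itself rather than the raw endpoint string
    local scheme = parsed_url and parsed_url.scheme or "https"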
@@ -106,7 +134,95 @@ function _M.request(self, conf, request_table, extra_opts)
         return nil, err
     end
 
-    return res, nil, httpc
+    return res, nil
 end
 
+
+function _M.read_response(ctx, res)
+    local body_reader = res.body_reader
+    if not body_reader then
+        core.log.error("AI service sent no response body")
+        return 500
+    end
+
+    local content_type = res.headers["Content-Type"]
+    core.response.set_header("Content-Type", content_type)
+
+    if core.string.find(content_type, "text/event-stream") then
+        while true do
+            local chunk, err = body_reader() -- will read chunk by chunk
+            if err then
+                core.log.error("failed to read response chunk: ", err)
+                break
+            end
+            if not chunk then
+                break
+            end
+
+            ngx_print(chunk)
+            ngx_flush(true)
+
+            local events, err = ngx_re.split(chunk, "\n")
+            if err then
+                core.log.warn("failed to split response chunk [", chunk, "] to events: ", err)
+                goto CONTINUE
+            end
+
+            for _, event in ipairs(events) do
+                if not core.string.find(event, "data:") or core.string.find(event, "[DONE]") then
+                    goto CONTINUE
+                end
+
+                local parts, err = ngx_re.split(event, ":", nil, nil, 2)
+                if err then
+                    core.log.warn("failed to split data event [", event, "] to parts: ", err)
+                    goto CONTINUE
+                end
+
+                if #parts ~= 2 then
+                    core.log.warn("malformed data event: ", event)
+                    goto CONTINUE
+                end
+
+                local data, err = core.json.decode(parts[2])
+                if err then
+                    core.log.warn("failed to decode data event [", parts[2], "] to json: ", err)
+                    goto CONTINUE
+                end
+
+                -- usage field is null for non-last events, null is parsed as userdata type
+                if data and data.usage and type(data.usage) ~= "userdata" then
+                    ctx.ai_token_usage = {
+                        prompt_tokens = data.usage.prompt_tokens or 0,
+                        completion_tokens = data.usage.completion_tokens or 0,
+                        total_tokens = data.usage.total_tokens or 0,
+                    }
+                end
+            end
+
+            ::CONTINUE::
+        end
+        return
+    end
+
+    local raw_res_body, err = res:read_body()
+    if not raw_res_body then
+        core.log.error("failed to read response body: ", err)
+        return 500
+    end
+    local res_body, err = core.json.decode(raw_res_body)
+    if err then
+        core.log.warn("invalid response body from ai service: ", raw_res_body, " err: ", err,
+                      ", it will cause token usage not available")
+    else
+        ctx.ai_token_usage = {
+            prompt_tokens = res_body.usage and res_body.usage.prompt_tokens or 0,
+            completion_tokens = res_body.usage and res_body.usage.completion_tokens or 0,
+            total_tokens = res_body.usage and res_body.usage.total_tokens or 0,
+        }
+    end
+    return res.status, raw_res_body
+end
+
+
 return _M
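
With this refactor, the driver's request() no longer hands back the httpc object, and response handling moves into read_response(), which either proxies text/event-stream chunks straight to the client while collecting usage from the streamed events, or returns the status and raw body while recording usage in ctx.ai_token_usage. A minimal sketch of the assumed caller-side flow (hypothetical names, not part of this commit):

    -- assumed caller-side code, for illustration only
    local res, err = driver:request(conf, request_table, extra_opts)
    if not res then
        core.log.error("failed to send request to AI service: ", err)
        return 500
    end

    -- streams SSE responses directly and returns nothing; otherwise
    -- returns (status, body); token usage ends up in ctx.ai_token_usage
    local code, body = driver.read_response(ctx, res)
    return code, body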

apisix/plugins/ai-drivers/schema.lua

+44
@@ -0,0 +1,44 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local _M = {}
+
+_M.chat_request_schema = {
+    type = "object",
+    properties = {
+        messages = {
+            type = "array",
+            minItems = 1,
+            items = {
+                properties = {
+                    role = {
+                        type = "string",
+                        enum = {"system", "user", "assistant"}
+                    },
+                    content = {
+                        type = "string",
+                        minLength = "1",
+                    },
+                },
+                additionalProperties = false,
+                required = {"role", "content"},
+            },
+        }
+    },
+    required = {"messages"}
+}
+
+return _M
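
A minimal sketch of validating a request body against the new chat_request_schema; the example payload below is hypothetical, and the check uses the same core.schema.check call seen in openai-base.lua:

    -- standalone illustration, not part of the commit
    local core = require("apisix.core")
    local schema = require("apisix.plugins.ai-drivers.schema")

    -- hypothetical chat payload: an array of role/content messages
    local body = {
        messages = {
            { role = "system", content = "You are a helpful assistant." },
            { role = "user", content = "Hello" },
        },
    }

    local ok, err = core.schema.check(schema.chat_request_schema, body)
    if not ok then
        -- a message with an unknown role, or with extra fields, is rejected
        -- because of the enum and additionalProperties = false constraints
        ngx.log(ngx.WARN, "invalid chat request: ", err)
    end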
