Skip to content

Commit 8179784

Browse files
authored
feat: add ai-request-rewrite plugin (#12036)
1 parent 32c9baf commit 8179784

File tree

7 files changed

+1406
-1
lines changed

7 files changed

+1406
-1
lines changed

apisix/cli/config.lua

+1
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ local _M = {
271271
"serverless-post-function",
272272
"ext-plugin-post-req",
273273
"ext-plugin-post-resp",
274+
"ai-request-rewrite",
274275
},
275276
stream_plugins = { "ip-restriction", "limit-conn", "mqtt-proxy", "syslog" },
276277
plugin_attr = {

apisix/plugins/ai-request-rewrite.lua

+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
local core = require("apisix.core")
18+
local require = require
19+
local pcall = pcall
20+
local ngx = ngx
21+
local req_set_body_data = ngx.req.set_body_data
22+
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST
23+
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
24+
25+
local plugin_name = "ai-request-rewrite"
26+
27+
local auth_item_schema = {
28+
type = "object",
29+
patternProperties = {
30+
["^[a-zA-Z0-9._-]+$"] = {
31+
type = "string"
32+
}
33+
}
34+
}
35+
36+
local auth_schema = {
37+
type = "object",
38+
properties = {
39+
header = auth_item_schema,
40+
query = auth_item_schema
41+
},
42+
additionalProperties = false
43+
}
44+
45+
local model_options_schema = {
46+
description = "Key/value settings for the model",
47+
type = "object",
48+
properties = {
49+
model = {
50+
type = "string",
51+
description = "Model to execute. Examples: \"gpt-3.5-turbo\" for openai, " ..
52+
"\"deepseek-chat\" for deekseek, or \"qwen-turbo\" for openai-compatible services"
53+
}
54+
},
55+
additionalProperties = true
56+
}
57+
58+
local schema = {
59+
type = "object",
60+
properties = {
61+
prompt = {
62+
type = "string",
63+
description = "The prompt to rewrite client request."
64+
},
65+
provider = {
66+
type = "string",
67+
description = "Name of the AI service provider.",
68+
enum = {"openai", "openai-compatible", "deepseek"} -- add more providers later
69+
},
70+
auth = auth_schema,
71+
options = model_options_schema,
72+
timeout = {
73+
type = "integer",
74+
minimum = 1,
75+
maximum = 60000,
76+
default = 30000,
77+
description = "Total timeout in milliseconds for requests to LLM service, " ..
78+
"including connect, send, and read timeouts."
79+
},
80+
keepalive = {
81+
type = "boolean",
82+
default = true
83+
},
84+
keepalive_pool = {
85+
type = "integer",
86+
minimum = 1,
87+
default = 30
88+
},
89+
ssl_verify = {
90+
type = "boolean",
91+
default = true
92+
},
93+
override = {
94+
type = "object",
95+
properties = {
96+
endpoint = {
97+
type = "string",
98+
description = "To be specified to override " ..
99+
"the endpoint of the AI service provider."
100+
}
101+
}
102+
}
103+
},
104+
required = {"prompt", "provider", "auth"}
105+
}
106+
107+
local _M = {
108+
version = 0.1,
109+
priority = 1073,
110+
name = plugin_name,
111+
schema = schema
112+
}
113+
114+
local function request_to_llm(conf, request_table, ctx)
115+
local ok, ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. conf.provider)
116+
if not ok then
117+
return nil, nil, "failed to load ai-driver: " .. conf.provider
118+
end
119+
120+
local extra_opts = {
121+
endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
122+
query_params = conf.auth.query or {},
123+
headers = (conf.auth.header or {}),
124+
model_options = conf.options
125+
}
126+
127+
local res, err, httpc = ai_driver:request(conf, request_table, extra_opts)
128+
if err then
129+
return nil, nil, err
130+
end
131+
132+
local resp_body, err = res:read_body()
133+
httpc:close()
134+
if err then
135+
return nil, nil, err
136+
end
137+
138+
return res, resp_body
139+
end
140+
141+
142+
local function parse_llm_response(res_body)
143+
local response_table, err = core.json.decode(res_body)
144+
145+
if err then
146+
return nil, "failed to decode llm response " .. ", err: " .. err
147+
end
148+
149+
if not response_table.choices or not response_table.choices[1] then
150+
return nil, "'choices' not in llm response"
151+
end
152+
153+
local message = response_table.choices[1].message
154+
if not message then
155+
return nil, "'message' not in llm response choices"
156+
end
157+
158+
return message.content
159+
end
160+
161+
162+
function _M.check_schema(conf)
163+
-- openai-compatible should be used with override.endpoint
164+
if conf.provider == "openai-compatible" then
165+
local override = conf.override
166+
167+
if not override or not override.endpoint then
168+
return false, "override.endpoint is required for openai-compatible provider"
169+
end
170+
end
171+
172+
return core.schema.check(schema, conf)
173+
end
174+
175+
176+
function _M.access(conf, ctx)
177+
local client_request_body, err = core.request.get_body()
178+
if err then
179+
core.log.warn("failed to get request body: ", err)
180+
return HTTP_BAD_REQUEST
181+
end
182+
183+
if not client_request_body then
184+
core.log.warn("missing request body")
185+
return
186+
end
187+
188+
-- Prepare request for LLM service
189+
local ai_request_table = {
190+
messages = {
191+
{
192+
role = "system",
193+
content = conf.prompt
194+
},
195+
{
196+
role = "user",
197+
content = client_request_body
198+
}
199+
},
200+
stream = false
201+
}
202+
203+
-- Send request to LLM service
204+
local res, resp_body, err = request_to_llm(conf, ai_request_table, ctx)
205+
if err then
206+
core.log.error("failed to request to LLM service: ", err)
207+
return HTTP_INTERNAL_SERVER_ERROR
208+
end
209+
210+
-- Handle LLM response
211+
if res.status > 299 then
212+
core.log.error("LLM service returned error status: ", res.status)
213+
return HTTP_INTERNAL_SERVER_ERROR
214+
end
215+
216+
-- Parse LLM response
217+
local llm_response, err = parse_llm_response(resp_body)
218+
if err then
219+
core.log.error("failed to parse LLM response: ", err)
220+
return HTTP_INTERNAL_SERVER_ERROR
221+
end
222+
223+
req_set_body_data(llm_response)
224+
end
225+
226+
return _M

docs/en/latest/config.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@
103103
"plugins/ai-proxy",
104104
"plugins/ai-proxy-multi",
105105
"plugins/attach-consumer-label",
106-
"plugins/ai-rag"
106+
"plugins/ai-rag",
107+
"plugins/ai-request-rewrite"
107108
]
108109
},
109110
{

0 commit comments

Comments
 (0)