@@ -20,12 +20,20 @@ local mt = {
20
20
__index = _M
21
21
}
22
22
23
+ local CONTENT_TYPE_JSON = " application/json"
24
+
23
25
local core = require (" apisix.core" )
24
26
local http = require (" resty.http" )
25
27
local url = require (" socket.url" )
28
+ local schema = require (" apisix.plugins.ai-drivers.schema" )
29
+ local ngx_re = require (" ngx.re" )
30
+
31
+ local ngx_print = ngx .print
32
+ local ngx_flush = ngx .flush
26
33
27
34
local pairs = pairs
28
35
local type = type
36
+ local ipairs = ipairs
29
37
local setmetatable = setmetatable
30
38
31
39
@@ -40,6 +48,26 @@ function _M.new(opts)
40
48
end
41
49
42
50
51
+ function _M .validate_request (ctx )
52
+ local ct = core .request .header (ctx , " Content-Type" ) or CONTENT_TYPE_JSON
53
+ if not core .string .has_prefix (ct , CONTENT_TYPE_JSON ) then
54
+ return nil , " unsupported content-type: " .. ct .. " , only application/json is supported"
55
+ end
56
+
57
+ local request_table , err = core .request .get_json_request_body_table ()
58
+ if not request_table then
59
+ return nil , err
60
+ end
61
+
62
+ local ok , err = core .schema .check (schema .chat_request_schema , request_table )
63
+ if not ok then
64
+ return nil , " request format doesn't match schema: " .. err
65
+ end
66
+
67
+ return request_table , nil
68
+ end
69
+
70
+
43
71
function _M .request (self , conf , request_table , extra_opts )
44
72
local httpc , err = http .new ()
45
73
if not httpc then
@@ -54,11 +82,11 @@ function _M.request(self, conf, request_table, extra_opts)
54
82
end
55
83
56
84
local ok , err = httpc :connect ({
57
- scheme = endpoint and parsed_url .scheme or " https" ,
58
- host = endpoint and parsed_url .host or self .host ,
59
- port = endpoint and parsed_url .port or self .port ,
85
+ scheme = parsed_url and parsed_url .scheme or " https" ,
86
+ host = parsed_url and parsed_url .host or self .host ,
87
+ port = parsed_url and parsed_url .port or self .port ,
60
88
ssl_verify = conf .ssl_verify ,
61
- ssl_server_name = endpoint and parsed_url .host or self .host ,
89
+ ssl_server_name = parsed_url and parsed_url .host or self .host ,
62
90
pool_size = conf .keepalive and conf .keepalive_pool ,
63
91
})
64
92
@@ -75,7 +103,7 @@ function _M.request(self, conf, request_table, extra_opts)
75
103
end
76
104
end
77
105
78
- local path = (endpoint and parsed_url .path or self .path )
106
+ local path = (parsed_url and parsed_url .path or self .path )
79
107
80
108
local headers = extra_opts .headers
81
109
headers [" Content-Type" ] = " application/json"
@@ -106,7 +134,95 @@ function _M.request(self, conf, request_table, extra_opts)
106
134
return nil , err
107
135
end
108
136
109
- return res , nil , httpc
137
+ return res , nil
110
138
end
111
139
140
+
141
+ function _M .read_response (ctx , res )
142
+ local body_reader = res .body_reader
143
+ if not body_reader then
144
+ core .log .error (" AI service sent no response body" )
145
+ return 500
146
+ end
147
+
148
+ local content_type = res .headers [" Content-Type" ]
149
+ core .response .set_header (" Content-Type" , content_type )
150
+
151
+ if core .string .find (content_type , " text/event-stream" ) then
152
+ while true do
153
+ local chunk , err = body_reader () -- will read chunk by chunk
154
+ if err then
155
+ core .log .error (" failed to read response chunk: " , err )
156
+ break
157
+ end
158
+ if not chunk then
159
+ break
160
+ end
161
+
162
+ ngx_print (chunk )
163
+ ngx_flush (true )
164
+
165
+ local events , err = ngx_re .split (chunk , " \n " )
166
+ if err then
167
+ core .log .warn (" failed to split response chunk [" , chunk , " ] to events: " , err )
168
+ goto CONTINUE
169
+ end
170
+
171
+ for _ , event in ipairs (events ) do
172
+ if not core .string .find (event , " data:" ) or core .string .find (event , " [DONE]" ) then
173
+ goto CONTINUE
174
+ end
175
+
176
+ local parts , err = ngx_re .split (event , " :" , nil , nil , 2 )
177
+ if err then
178
+ core .log .warn (" failed to split data event [" , event , " ] to parts: " , err )
179
+ goto CONTINUE
180
+ end
181
+
182
+ if # parts ~= 2 then
183
+ core .log .warn (" malformed data event: " , event )
184
+ goto CONTINUE
185
+ end
186
+
187
+ local data , err = core .json .decode (parts [2 ])
188
+ if err then
189
+ core .log .warn (" failed to decode data event [" , parts [2 ], " ] to json: " , err )
190
+ goto CONTINUE
191
+ end
192
+
193
+ -- usage field is null for non-last events, null is parsed as userdata type
194
+ if data and data .usage and type (data .usage ) ~= " userdata" then
195
+ ctx .ai_token_usage = {
196
+ prompt_tokens = data .usage .prompt_tokens or 0 ,
197
+ completion_tokens = data .usage .completion_tokens or 0 ,
198
+ total_tokens = data .usage .total_tokens or 0 ,
199
+ }
200
+ end
201
+ end
202
+
203
+ :: CONTINUE::
204
+ end
205
+ return
206
+ end
207
+
208
+ local raw_res_body , err = res :read_body ()
209
+ if not raw_res_body then
210
+ core .log .error (" failed to read response body: " , err )
211
+ return 500
212
+ end
213
+ local res_body , err = core .json .decode (raw_res_body )
214
+ if err then
215
+ core .log .warn (" invalid response body from ai service: " , raw_res_body , " err: " , err ,
216
+ " , it will cause token usage not available" )
217
+ else
218
+ ctx .ai_token_usage = {
219
+ prompt_tokens = res_body .usage and res_body .usage .prompt_tokens or 0 ,
220
+ completion_tokens = res_body .usage and res_body .usage .completion_tokens or 0 ,
221
+ total_tokens = res_body .usage and res_body .usage .total_tokens or 0 ,
222
+ }
223
+ end
224
+ return res .status , raw_res_body
225
+ end
226
+
227
+
112
228
return _M
0 commit comments