Skip to content

Commit ac36c92

Browse files
committed
Initial upgrade implementation
1 parent 6c91aef commit ac36c92

File tree

2 files changed

+183
-18
lines changed

2 files changed

+183
-18
lines changed

http/server.lua

+82-16
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ local errno = require 'errno'
1616
local DETACHED = 101
1717

1818
local function errorf(fmt, ...)
19-
error(string.format(fmt, ...))
19+
error(string.format(fmt, ...), 3)
2020
end
2121

2222
local function sprintf(fmt, ...)
@@ -28,11 +28,23 @@ local function assertf(ok, fmt, ...)
2828
if select('#', ...) > 0 then
2929
fmt = tostring(fmt):format(...)
3030
end
31-
error(fmt, 2)
31+
error(fmt, 3)
3232
end
3333
return ok
3434
end
3535

36+
local function is_callable(obj)
37+
local t_obj = type(obj)
38+
if t_obj == 'function' then
39+
return true
40+
end
41+
if t_obj == 'table' then
42+
local mt = getmetatable(obj)
43+
return (type(mt) == 'table' and type(mt.__call) == 'function')
44+
end
45+
return false
46+
end
47+
3648
local function uri_escape(str)
3749
local res = {}
3850
if type(str) == 'table' then
@@ -924,7 +936,7 @@ local function url_for_httpd(httpd, name, args, query)
924936
end
925937
end
926938

927-
local function httpd_http11_parse_request(session, request_raw)
939+
local function httpd_parse_request(request_raw)
928940
local request_parsed = lib._parse_request(request_raw)
929941
if request_parsed.error then
930942
return nil, request_parsed.error
@@ -934,6 +946,14 @@ local function httpd_http11_parse_request(session, request_raw)
934946
request_parsed.path:find("./", nil, true) ~= nil then
935947
return nil, "invalid uri"
936948
end
949+
return request_parsed
950+
end
951+
952+
local function httpd_http11_parse_request(session, request_raw)
953+
local request_parsed, err = httpd_parse_request(request_raw)
954+
if not request_parsed then
955+
return nil, err
956+
end
937957
request_parsed.httpd = session.server
938958
request_parsed.s = session.socket
939959
request_parsed.peer = session.peer
@@ -970,6 +990,30 @@ local function httpd_http11_handler(session)
970990
return
971991
end
972992

993+
if p.headers['upgrade'] then
994+
local proto_name = p.headers['upgrade']:lower()
995+
local proto = session.server.upgrades[proto_name]
996+
if not proto then
997+
session:write('HTTP/1.1 400 Bad Request\r\n\r\n')
998+
return false
999+
else
1000+
local ok, upgrade_ok = pcall(proto.upgrade, session, p)
1001+
if not ok then
1002+
log.error("Failed to upgrade to '%s': %s", p.headers['upgrade'],
1003+
upgrade_ok)
1004+
session:write('HTTP/1.1 500 Internal Error\r\n\r\n')
1005+
return false
1006+
elseif not upgrade_ok then
1007+
-- TODO: should we close connection, or we should retry again
1008+
return false
1009+
end
1010+
1011+
session.ctx.proto = proto.name
1012+
session.ctx.handler = proto.handler
1013+
return true
1014+
end
1015+
end
1016+
9731017
if p.headers['expect'] == '100-continue' then
9741018
session:write('HTTP/1.1 100 Continue\r\n\r\n')
9751019
elseif p.headers['expect'] then
@@ -1174,14 +1218,32 @@ local function httpd_start(self)
11741218
return self
11751219
end
11761220

1221+
local function httpd_register_extension(self, ext_type, opts)
1222+
if ext_type:lower() == 'upgrade' then
1223+
assertf(type(opts) == 'table',
1224+
"Upgrade extension argument should be table")
1225+
assertf(type(opts.name) == 'string',
1226+
"Upgrade extension name should be %s", 'options.name', 'string')
1227+
assertf(is_callable(opts.upgrade),
1228+
"Upgrade extension callback should be callable")
1229+
assertf(is_callable(opts.handler),
1230+
"Upgrade extension handler should be callable")
1231+
1232+
self.upgrades[opts.name:lower()] = table.copy(opts)
1233+
else
1234+
errorf('Unknown extension type: %s', ext_type)
1235+
end
1236+
end
1237+
11771238
local httpd_methods = {
1178-
stop = httpd_stop,
1179-
start = httpd_start,
1180-
route = add_route,
1181-
match = match_route,
1182-
helper = set_helper,
1183-
hook = set_hook,
1184-
url_for = url_for_httpd,
1239+
stop = httpd_stop,
1240+
start = httpd_start,
1241+
route = add_route,
1242+
match = match_route,
1243+
helper = set_helper,
1244+
hook = set_hook,
1245+
url_for = url_for_httpd,
1246+
register_extension = httpd_register_extension,
11851247
}
11861248

11871249
local httpd_mt = {
@@ -1216,10 +1278,11 @@ local function httpd_new(host, port, options)
12161278
is_run = false,
12171279
options = options,
12181280

1219-
routes = { },
1220-
iroutes = { },
1221-
helpers = { url_for = url_for_helper, },
1222-
hooks = { },
1281+
routes = { },
1282+
iroutes = { },
1283+
helpers = { url_for = url_for_helper, },
1284+
hooks = { },
1285+
upgrades = { },
12231286

12241287
-- caches
12251288
cache = { tpl = {}, ctx = {}, static = {}, },
@@ -1229,6 +1292,9 @@ local function httpd_new(host, port, options)
12291292
end
12301293

12311294
return {
1232-
DETACHED = DETACHED,
1233-
new = httpd_new
1295+
DETACHED = DETACHED,
1296+
new = httpd_new,
1297+
parse_headers = httpd_parse_request,
1298+
uri_escape = uri_escape,
1299+
uri_unescape = uri_unescape,
12341300
}

test/http.test.lua

+101-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ local json = require('json')
99
local yaml = require 'yaml'
1010
local urilib = require('uri')
1111

12-
local test = tap.test("http")
13-
test:plan(7)
12+
local socket = require('socket')
13+
14+
local test = tap.test("http"); test:plan(8)
15+
1416
test:test("split_uri", function(test)
1517
test:plan(65)
1618
local function check(uri, rhs)
@@ -370,4 +372,101 @@ test:test("server requests", function(test)
370372
httpd:stop()
371373
end)
372374

375+
test:test("upgrade", function(test)
376+
test:plan(4)
377+
378+
local log = require('log')
379+
380+
local httpd = cfgserv()
381+
httpd:start()
382+
httpd:register_extension('upgrade', {
383+
name = 'exist-error-upgrade',
384+
upgrade = function() error('error') end,
385+
handler = function() end,
386+
})
387+
httpd:register_extension('upgrade', {
388+
name = 'exist-fails-upgrade',
389+
upgrade = function(session)
390+
session:write('HTTP/1.1 426 Upgrade Required\r\n\r\n')
391+
return false
392+
end,
393+
handler = function() end,
394+
})
395+
396+
local switching_header = 'HTTP/1.1 101 Switching Protocols\r\n' ..
397+
'Upgrade: exist\r\n' ..
398+
'Connection: Upgrade\r\n\r\n'
399+
httpd:register_extension('upgrade', {
400+
name = 'exists',
401+
upgrade = function(session)
402+
session:write(switching_header)
403+
return true
404+
end,
405+
handler = function(session)
406+
while true do
407+
local in_data = session:read(24)
408+
if in_data == '' then
409+
return false
410+
end
411+
session:write(in_data)
412+
end
413+
end,
414+
})
415+
416+
test:test("upgrade failed, no protocol", function(test)
417+
test:plan(1)
418+
local r = http_client.get('http://127.0.0.1:12345/abc', {
419+
headers = { upgrade = 'non-existent' }
420+
})
421+
test:is(r.status, 400, 'Error code is 400')
422+
end)
423+
424+
test:test("upgrade failed, error while upgrade", function(test)
425+
test:plan(1)
426+
local r = http_client.get('http://127.0.0.1:12345/abc', {
427+
headers = { upgrade = 'exist-error-upgrade' }
428+
})
429+
test:is(r.status, 500, 'Error code is 500')
430+
end)
431+
432+
test:test("upgrade failed, upgrade return false", function(test)
433+
test:plan(1)
434+
local r = http_client.get('http://127.0.0.1:12345/abc', {
435+
headers = { upgrade = 'exist-fails-upgrade' }
436+
})
437+
test:is(r.status, 426, 'Error code is 426')
438+
end)
439+
440+
local ws_get_r = "GET /abc HTTP/1.1\r\nUpgrade:exists\r\n\r\n"
441+
442+
test:test("upgrade success, simple tcp echo", function(test)
443+
test:plan(3)
444+
local sck = socket.tcp_connect('127.0.0.1', 12345)
445+
sck:write(ws_get_r)
446+
local data = ''
447+
while true do
448+
local tdata = sck:read({ delimiter = { '\n\n', '\r\n\r\n' } })
449+
if not tdata or tdata == '' or
450+
tdata:endswith('\r\n\r\n') or tdata:endswith('\n\n') then
451+
if tdata then
452+
data = data .. tdata
453+
end
454+
break
455+
end
456+
end
457+
458+
if not data:endswith('\r\n\r\n') then
459+
test:fail('automatic fail')
460+
else
461+
test:is(#data, #switching_header, 'right http upgrade len')
462+
end
463+
464+
local msg = ('x'):rep(24)
465+
sck:write(msg)
466+
local res = sck:read(#msg)
467+
test:is(#res, #msg, 'echo is ok')
468+
test:is(res, msg, 'echo is ok')
469+
end)
470+
end)
471+
373472
os.exit(test:check() == true and 0 or 1)

0 commit comments

Comments
 (0)