Skip to content

Commit b69cf36

Browse files
committed
Initial upgrade implementation
1 parent d3a11ce commit b69cf36

File tree

2 files changed

+186
-17
lines changed

2 files changed

+186
-17
lines changed

Diff for: http/server.lua

+85-15
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ local errno = require 'errno'
1616
local DETACHED = 101
1717

1818
local function errorf(fmt, ...)
19-
error(string.format(fmt, ...))
19+
error(string.format(fmt, ...), 3)
2020
end
2121

2222
local function sprintf(fmt, ...)
@@ -35,6 +35,18 @@ local function is_callable(obj)
3535
return false
3636
end
3737

38+
local function is_callable(obj)
39+
local t_obj = type(obj)
40+
if t_obj == 'function' then
41+
return true
42+
end
43+
if t_obj == 'table' then
44+
local mt = getmetatable(obj)
45+
return (type(mt) == 'table' and type(mt.__call) == 'function')
46+
end
47+
return false
48+
end
49+
3850
local function uri_escape(str)
3951
local res = {}
4052
if type(str) == 'table' then
@@ -930,7 +942,7 @@ local function url_for_httpd(httpd, name, args, query)
930942
end
931943
end
932944

933-
local function httpd_http11_parse_request(session, request_raw)
945+
local function httpd_parse_request(request_raw)
934946
local request_parsed = lib._parse_request(request_raw)
935947
if request_parsed.error then
936948
return nil, request_parsed.error
@@ -940,6 +952,14 @@ local function httpd_http11_parse_request(session, request_raw)
940952
request_parsed.path:find("./", nil, true) ~= nil then
941953
return nil, "invalid uri"
942954
end
955+
return request_parsed
956+
end
957+
958+
local function httpd_http11_parse_request(session, request_raw)
959+
local request_parsed, err = httpd_parse_request(request_raw)
960+
if not request_parsed then
961+
return nil, err
962+
end
943963
request_parsed.httpd = session.server
944964
request_parsed.s = session.socket
945965
request_parsed.peer = session.peer
@@ -976,6 +996,30 @@ local function httpd_http11_handler(session)
976996
return
977997
end
978998

999+
if p.headers['upgrade'] then
1000+
local proto_name = p.headers['upgrade']:lower()
1001+
local proto = session.server.upgrades[proto_name]
1002+
if not proto then
1003+
session:write('HTTP/1.1 400 Bad Request\r\n\r\n')
1004+
return false
1005+
else
1006+
local ok, upgrade_ok = pcall(proto.upgrade, session, p)
1007+
if not ok then
1008+
log.error("Failed to upgrade to '%s': %s", p.headers['upgrade'],
1009+
upgrade_ok)
1010+
session:write('HTTP/1.1 500 Internal Error\r\n\r\n')
1011+
return false
1012+
elseif not upgrade_ok then
1013+
-- TODO: should we close connection, or we should retry again
1014+
return false
1015+
end
1016+
1017+
session.ctx.proto = proto.name
1018+
session.ctx.handler = proto.handler
1019+
return true
1020+
end
1021+
end
1022+
9791023
if p.headers['expect'] == '100-continue' then
9801024
session:write('HTTP/1.1 100 Continue\r\n\r\n')
9811025
elseif p.headers['expect'] then
@@ -1180,14 +1224,36 @@ local function httpd_start(self)
11801224
return self
11811225
end
11821226

1227+
local function httpd_register_extension(self, ext_type, opts)
1228+
if ext_type:lower() == 'upgrade' then
1229+
if not (type(opts) == 'table') then
1230+
errorf("Upgrade extension argument should be table")
1231+
end
1232+
if not (type(opts.name) == 'string') then
1233+
errorf("Upgrade extension name should be %s", 'options.name', 'string')
1234+
end
1235+
if not is_callable(opts.upgrade) then
1236+
errorf("Upgrade extension callback should be callable")
1237+
end
1238+
if not is_callable(opts.handler) then
1239+
errorf("Upgrade extension handler should be callable")
1240+
end
1241+
1242+
self.upgrades[opts.name:lower()] = table.copy(opts)
1243+
else
1244+
errorf('Unknown extension type: %s', ext_type)
1245+
end
1246+
end
1247+
11831248
local httpd_methods = {
1184-
stop = httpd_stop,
1185-
start = httpd_start,
1186-
route = add_route,
1187-
match = match_route,
1188-
helper = set_helper,
1189-
hook = set_hook,
1190-
url_for = url_for_httpd,
1249+
stop = httpd_stop,
1250+
start = httpd_start,
1251+
route = add_route,
1252+
match = match_route,
1253+
helper = set_helper,
1254+
hook = set_hook,
1255+
url_for = url_for_httpd,
1256+
register_extension = httpd_register_extension,
11911257
}
11921258

11931259
local httpd_mt = {
@@ -1223,10 +1289,11 @@ local function httpd_new(host, port, options)
12231289
is_run = false,
12241290
options = options,
12251291

1226-
routes = { },
1227-
iroutes = { },
1228-
helpers = { url_for = url_for_helper, },
1229-
hooks = { },
1292+
routes = { },
1293+
iroutes = { },
1294+
helpers = { url_for = url_for_helper, },
1295+
hooks = { },
1296+
upgrades = { },
12301297

12311298
-- caches
12321299
cache = { tpl = {}, ctx = {}, static = {}, },
@@ -1236,6 +1303,9 @@ local function httpd_new(host, port, options)
12361303
end
12371304

12381305
return {
1239-
DETACHED = DETACHED,
1240-
new = httpd_new
1306+
DETACHED = DETACHED,
1307+
new = httpd_new,
1308+
parse_headers = httpd_parse_request,
1309+
uri_escape = uri_escape,
1310+
uri_unescape = uri_unescape,
12411311
}

Diff for: test/http.test.lua

+101-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ local json = require('json')
99
local yaml = require 'yaml'
1010
local urilib = require('uri')
1111

12-
local test = tap.test("http")
13-
test:plan(7)
12+
local socket = require('socket')
13+
14+
local test = tap.test("http"); test:plan(8)
15+
1416
test:test("split_uri", function(test)
1517
test:plan(65)
1618
local function check(uri, rhs)
@@ -388,4 +390,101 @@ test:test("server requests", function(test)
388390
httpd:stop()
389391
end)
390392

393+
test:test("upgrade", function(test)
394+
test:plan(4)
395+
396+
local log = require('log')
397+
398+
local httpd = cfgserv()
399+
httpd:start()
400+
httpd:register_extension('upgrade', {
401+
name = 'exist-error-upgrade',
402+
upgrade = function() error('error') end,
403+
handler = function() end,
404+
})
405+
httpd:register_extension('upgrade', {
406+
name = 'exist-fails-upgrade',
407+
upgrade = function(session)
408+
session:write('HTTP/1.1 426 Upgrade Required\r\n\r\n')
409+
return false
410+
end,
411+
handler = function() end,
412+
})
413+
414+
local switching_header = 'HTTP/1.1 101 Switching Protocols\r\n' ..
415+
'Upgrade: exist\r\n' ..
416+
'Connection: Upgrade\r\n\r\n'
417+
httpd:register_extension('upgrade', {
418+
name = 'exists',
419+
upgrade = function(session)
420+
session:write(switching_header)
421+
return true
422+
end,
423+
handler = function(session)
424+
while true do
425+
local in_data = session:read(24)
426+
if in_data == '' then
427+
return false
428+
end
429+
session:write(in_data)
430+
end
431+
end,
432+
})
433+
434+
test:test("upgrade failed, no protocol", function(test)
435+
test:plan(1)
436+
local r = http_client.get('http://127.0.0.1:12345/abc', {
437+
headers = { upgrade = 'non-existent' }
438+
})
439+
test:is(r.status, 400, 'Error code is 400')
440+
end)
441+
442+
test:test("upgrade failed, error while upgrade", function(test)
443+
test:plan(1)
444+
local r = http_client.get('http://127.0.0.1:12345/abc', {
445+
headers = { upgrade = 'exist-error-upgrade' }
446+
})
447+
test:is(r.status, 500, 'Error code is 500')
448+
end)
449+
450+
test:test("upgrade failed, upgrade return false", function(test)
451+
test:plan(1)
452+
local r = http_client.get('http://127.0.0.1:12345/abc', {
453+
headers = { upgrade = 'exist-fails-upgrade' }
454+
})
455+
test:is(r.status, 426, 'Error code is 426')
456+
end)
457+
458+
local ws_get_r = "GET /abc HTTP/1.1\r\nUpgrade:exists\r\n\r\n"
459+
460+
test:test("upgrade success, simple tcp echo", function(test)
461+
test:plan(3)
462+
local sck = socket.tcp_connect('127.0.0.1', 12345)
463+
sck:write(ws_get_r)
464+
local data = ''
465+
while true do
466+
local tdata = sck:read({ delimiter = { '\n\n', '\r\n\r\n' } })
467+
if not tdata or tdata == '' or
468+
tdata:endswith('\r\n\r\n') or tdata:endswith('\n\n') then
469+
if tdata then
470+
data = data .. tdata
471+
end
472+
break
473+
end
474+
end
475+
476+
if not data:endswith('\r\n\r\n') then
477+
test:fail('automatic fail')
478+
else
479+
test:is(#data, #switching_header, 'right http upgrade len')
480+
end
481+
482+
local msg = ('x'):rep(24)
483+
sck:write(msg)
484+
local res = sck:read(#msg)
485+
test:is(#res, #msg, 'echo is ok')
486+
test:is(res, msg, 'echo is ok')
487+
end)
488+
end)
489+
391490
os.exit(test:check() == true and 0 or 1)

0 commit comments

Comments
 (0)