Skip to content

Commit facdf08

Browse files
committed
Merge remote-tracking branch 'origin/improve_handshake'
2 parents c371c00 + ed3b2bc commit facdf08

File tree

4 files changed

+273
-59
lines changed

4 files changed

+273
-59
lines changed

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ from trio_websocket import open_websocket_url
3737

3838
async def main():
3939
try:
40-
async with open_websocket_url('ws://localhost/foo') as conn:
41-
await conn.send_message('hello world!')
40+
async with open_websocket_url('ws://localhost/foo') as ws:
41+
await ws.send_message('hello world!')
4242
except OSError as ose:
4343
logging.error('Connection attempt failed: %s', ose)
4444

@@ -61,11 +61,12 @@ to each incoming message with an identical outgoing message.
6161
import trio
6262
from trio_websocket import serve_websocket, ConnectionClosed
6363

64-
async def echo_server(websocket):
64+
async def echo_server(request):
65+
ws = await request.accept()
6566
while True:
6667
try:
67-
message = await websocket.get_message()
68-
await websocket.send_message(message)
68+
message = await ws.get_message()
69+
await ws.send_message(message)
6970
except ConnectionClosed:
7071
break
7172

examples/server.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ async def main(args):
4949
await serve_websocket(handler, host, args.port, ssl_context)
5050

5151

52-
async def handler(websocket):
52+
async def handler(request):
5353
''' Reverse incoming websocket messages and send them back. '''
54-
logging.info('Handler starting (path=%s)' % websocket.path)
54+
logging.info('Handler starting on path "%s"' % request.url.path_qs)
55+
ws = await request.accept()
5556
while True:
5657
try:
57-
message = await websocket.get_message()
58-
await websocket.send_message(message[::-1])
58+
message = await ws.get_message()
59+
await ws.send_message(message[::-1])
5960
except ConnectionClosed:
6061
logging.info('Connection closed')
6162
break

tests/test_connection.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
async def echo_server(nursery):
2929
''' A server that reads one message, sends back the same message,
3030
then closes the connection. '''
31-
serve_fn = partial(serve_websocket, echo_handler, HOST, 0, ssl_context=None)
31+
serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0,
32+
ssl_context=None)
3233
server = await nursery.start(serve_fn)
3334
await yield_(server)
3435

@@ -38,12 +39,20 @@ async def echo_server(nursery):
3839
async def echo_conn(echo_server):
3940
''' Return a client connection instance that is connected to an echo
4041
server. '''
41-
async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) \
42-
as conn:
42+
async with open_websocket(HOST, echo_server.port, RESOURCE,
43+
use_ssl=False) as conn:
4344
await yield_(conn)
4445

4546

46-
async def echo_handler(conn):
47+
async def echo_request_handler(request):
48+
'''
49+
Accept incoming request and then pass off to echo connection handler.
50+
'''
51+
conn = await request.accept()
52+
await echo_conn_handler(conn)
53+
54+
55+
async def echo_conn_handler(conn):
4756
''' A connection handler that reads one message, sends back the same
4857
message, then exits. '''
4958
try:
@@ -95,14 +104,16 @@ async def test_listen_port_ipv6():
95104

96105

97106
async def test_server_has_listeners(nursery):
98-
server = await nursery.start(serve_websocket, echo_handler, HOST, 0, None)
107+
server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0,
108+
None)
99109
assert len(server.listeners) > 0
100110
assert isinstance(server.listeners[0], ListenPort)
101111

102112

103113
async def test_serve(nursery):
104114
task = trio.hazmat.current_task()
105-
server = await nursery.start(serve_websocket, echo_handler, HOST, 0, None)
115+
server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0,
116+
None)
106117
port = server.port
107118
assert server.port != 0
108119
# The server nursery begins with one task (server.listen).
@@ -123,17 +134,19 @@ async def test_serve_ssl(nursery):
123134
ca.configure_trust(client_context)
124135
cert = ca.issue_server_cert(HOST)
125136
cert.configure_cert(server_context)
126-
server = await nursery.start(serve_websocket, echo_handler, HOST, 0,
137+
138+
server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0,
127139
server_context)
128140
port = server.port
129-
async with open_websocket(HOST, port, RESOURCE, client_context) as conn:
141+
async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context
142+
) as conn:
130143
assert not conn.is_closed
131144

132145

133146
async def test_serve_handler_nursery(nursery):
134147
task = trio.hazmat.current_task()
135148
async with trio.open_nursery() as handler_nursery:
136-
serve_with_nursery = partial(serve_websocket, echo_handler,
149+
serve_with_nursery = partial(serve_websocket, echo_request_handler,
137150
HOST, 0, None, handler_nursery=handler_nursery)
138151
server = await nursery.start(serve_with_nursery)
139152
port = server.port
@@ -149,12 +162,12 @@ async def test_serve_handler_nursery(nursery):
149162
async def test_serve_with_zero_listeners(nursery):
150163
task = trio.hazmat.current_task()
151164
with pytest.raises(ValueError):
152-
server = WebSocketServer(echo_handler, [])
165+
server = WebSocketServer(echo_request_handler, [])
153166

154167

155168
async def test_serve_non_tcp_listener(nursery):
156169
listeners = [MemoryListener()]
157-
server = WebSocketServer(echo_handler, listeners)
170+
server = WebSocketServer(echo_request_handler, listeners)
158171
await nursery.start(server.run)
159172
assert len(server.listeners) == 1
160173
with pytest.raises(RuntimeError):
@@ -165,7 +178,7 @@ async def test_serve_non_tcp_listener(nursery):
165178
async def test_serve_multiple_listeners(nursery):
166179
listener1 = (await trio.open_tcp_listeners(0, host=HOST))[0]
167180
listener2 = MemoryListener()
168-
server = WebSocketServer(echo_handler, [listener1, listener2])
181+
server = WebSocketServer(echo_request_handler, [listener1, listener2])
169182
await nursery.start(server.run)
170183
assert len(server.listeners) == 2
171184
with pytest.raises(RuntimeError):
@@ -213,6 +226,21 @@ async def test_client_connect_url(echo_server, nursery):
213226
assert not conn.is_closed
214227

215228

229+
async def test_handshake_subprotocol(nursery):
230+
async def handler(request):
231+
assert request.proposed_subprotocols == ('chat', 'file')
232+
assert request.subprotocol is None
233+
request.subprotocol = 'chat'
234+
assert request.subprotocol == 'chat'
235+
server_ws = await request.accept()
236+
assert server_ws.subprotocol == 'chat'
237+
238+
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
239+
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False,
240+
subprotocols=('chat', 'file')) as client_ws:
241+
assert client_ws.subprotocol == 'chat'
242+
243+
216244
async def test_client_send_and_receive(echo_conn):
217245
async with echo_conn:
218246
await echo_conn.send_message('This is a test message.')
@@ -293,12 +321,13 @@ async def test_wrap_client_stream(echo_server, nursery):
293321

294322
async def test_wrap_server_stream(nursery):
295323
async def handler(stream):
296-
server = await wrap_server_stream(nursery, stream)
297-
async with server:
298-
assert not server.is_closed
299-
msg = await server.get_message()
324+
request = await wrap_server_stream(nursery, stream)
325+
server_ws = await request.accept()
326+
async with server_ws:
327+
assert not server_ws.is_closed
328+
msg = await server_ws.get_message()
300329
assert msg == 'Hello from client!'
301-
assert server.is_closed
330+
assert server_ws.is_closed
302331
serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST)
303332
listeners = await nursery.start(serve_fn)
304333
port = listeners[0].socket.getsockname()[1]
@@ -307,26 +336,28 @@ async def handler(stream):
307336

308337

309338
async def test_client_does_not_close_handshake(nursery):
310-
async def handler(server):
339+
async def handler(request):
340+
server_ws = await request.accept()
311341
with pytest.raises(ConnectionClosed):
312-
await server.get_message()
342+
await server_ws.get_message()
313343
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
314344
port = server.port
315345
stream = await trio.open_tcp_stream(HOST, server.port)
316-
client = await wrap_client_stream(nursery, stream, HOST, RESOURCE)
317-
async with client:
346+
client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE)
347+
async with client_ws:
318348
await stream.aclose()
319349
with pytest.raises(ConnectionClosed):
320-
await client.send_message('Hello from client!')
350+
await client_ws.send_message('Hello from client!')
321351

322352

323353
async def test_server_does_not_close_handshake(nursery):
324354
async def handler(stream):
325-
server = await wrap_server_stream(nursery, stream)
326-
async with server:
355+
request = await wrap_server_stream(nursery, stream)
356+
server_ws = await request.accept()
357+
async with server_ws:
327358
await stream.aclose()
328359
with pytest.raises(ConnectionClosed):
329-
await server.send_message('Hello from client!')
360+
await server_ws.send_message('Hello from client!')
330361
serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST)
331362
listeners = await nursery.start(serve_fn)
332363
port = listeners[0].socket.getsockname()[1]
@@ -336,7 +367,8 @@ async def handler(stream):
336367

337368

338369
async def test_server_handler_exit(nursery, autojump_clock):
339-
async def handler(connection):
370+
async def handler(request):
371+
server_ws = await request.accept()
340372
await trio.sleep(1)
341373

342374
server = await nursery.start(

0 commit comments

Comments
 (0)