Skip to content

Commit 087f79f

Browse files
committed
TLSUpgradeProto: don't set multiple results for an event
In the case of a misbehaving server, the client may receive more than one byte in separate data_received() invocations from the server. While we can't do much sane with this, we should handle it gracefully and not crash with asyncio.InvalidStateError when trying to set another result on the event. Fixes #729
1 parent c2c8d20 commit 087f79f

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

asyncpg/_testbase/__init__.py

+39
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import logging
1414
import os
1515
import re
16+
import socket
1617
import textwrap
1718
import time
1819
import traceback
@@ -525,3 +526,41 @@ def connect_standby(cls, **kwargs):
525526
kwargs
526527
)
527528
return pg_connection.connect(**conn_spec, loop=cls.loop)
529+
530+
531+
class InstrumentedServer:
532+
"""
533+
A socket server for testing.
534+
It will write each item from `data`, and wait for the corresponding event
535+
in `received_events` to notify that it was received before writing the next
536+
item from `data`.
537+
"""
538+
def __init__(self, data, received_events):
539+
assert len(data) == len(received_events)
540+
self._data = data
541+
self._server = None
542+
self._received_events = received_events
543+
544+
async def _handle_client(self, _reader, writer):
545+
for datum, received_event in zip(self._data, self._received_events):
546+
writer.write(datum)
547+
await writer.drain()
548+
await received_event.wait()
549+
550+
writer.close()
551+
await writer.wait_closed()
552+
553+
async def start(self):
554+
"""Start the server."""
555+
self._server = await asyncio.start_server(self._handle_client, 'localhost', 0)
556+
assert len(self._server.sockets) == 1
557+
sock = self._server.sockets[0]
558+
addr, port = sock.getsockname()
559+
return {
560+
'host': addr,
561+
'port': port,
562+
}
563+
564+
def stop(self):
565+
"""Stop the server."""
566+
self._server.close()

asyncpg/connect_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,11 @@ def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
714714
self.ssl_is_advisory = ssl_is_advisory
715715

716716
def data_received(self, data):
717+
if self.on_data.done():
718+
# Only expect to receive one byte here; ignore unsolicited further
719+
# data.
720+
return
721+
717722
if data == b'S':
718723
self.on_data.set_result(True)
719724
elif (self.ssl_is_advisory and

tests/test_connect.py

+55
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import asyncio
99
import contextlib
10+
import copy
1011
import gc
1112
import ipaddress
1213
import os
@@ -17,11 +18,13 @@
1718
import stat
1819
import tempfile
1920
import textwrap
21+
import time
2022
import unittest
2123
import unittest.mock
2224
import urllib.parse
2325
import warnings
2426
import weakref
27+
from unittest import mock
2528

2629
import asyncpg
2730
from asyncpg import _testbase as tb
@@ -1989,6 +1992,58 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self):
19891992
await con.close()
19901993

19911994

1995+
class TestMisbehavingServer(tb.TestCase):
1996+
"""Tests for client connection behaviour given a misbehaving server."""
1997+
1998+
async def test_tls_upgrade_extra_data_received(self):
1999+
data = [
2000+
# First, the server writes b"S" to signal it is willing to perform
2001+
# SSL
2002+
b"S",
2003+
# Then, the server writes an unsolicted arbitrary byte afterwards
2004+
b"N",
2005+
]
2006+
data_received_events = [asyncio.Event() for _ in data]
2007+
2008+
# Patch out the loop's create_connection so we can instrument the proto
2009+
# we return.
2010+
old_create_conn = self.loop.create_connection
2011+
2012+
async def _mock_create_conn(*args, **kwargs):
2013+
transport, proto = await old_create_conn(*args, **kwargs)
2014+
old_data_received = proto.data_received
2015+
2016+
num_received = 0
2017+
2018+
def _data_received(*args, **kwargs):
2019+
nonlocal num_received
2020+
# Call the original data_received method
2021+
ret = old_data_received(*args, **kwargs)
2022+
# Fire the event to signal we've received this datum now.
2023+
data_received_events[num_received].set()
2024+
num_received += 1
2025+
return ret
2026+
2027+
proto.data_received = _data_received
2028+
2029+
# To deterministically provoke the race we're interested in for
2030+
# this regression test, wait for all data to be received before
2031+
# returning from create_connection().
2032+
await data_received_events[-1].wait()
2033+
return transport, proto
2034+
2035+
server = tb.InstrumentedServer(data, data_received_events)
2036+
conn_spec = await server.start()
2037+
2038+
# The call to connect() should raise a ConnectionResetError as the
2039+
# server will close the connection after writing all the data.
2040+
with (mock.patch.object(self.loop, "create_connection", side_effect=_mock_create_conn),
2041+
self.assertRaises(ConnectionResetError)):
2042+
await pg_connection.connect(**conn_spec, ssl=True, loop=self.loop)
2043+
2044+
server.stop()
2045+
2046+
19922047
def _get_connected_host(con):
19932048
peername = con._transport.get_extra_info('peername')
19942049
if isinstance(peername, tuple):

0 commit comments

Comments
 (0)