|
7 | 7 |
|
8 | 8 | import asyncio
|
9 | 9 | import contextlib
|
| 10 | +import copy |
10 | 11 | import gc
|
11 | 12 | import ipaddress
|
12 | 13 | import os
|
|
17 | 18 | import stat
|
18 | 19 | import tempfile
|
19 | 20 | import textwrap
|
| 21 | +import time |
20 | 22 | import unittest
|
21 | 23 | import unittest.mock
|
22 | 24 | import urllib.parse
|
23 | 25 | import warnings
|
24 | 26 | import weakref
|
| 27 | +from unittest import mock |
25 | 28 |
|
26 | 29 | import asyncpg
|
27 | 30 | from asyncpg import _testbase as tb
|
@@ -1989,6 +1992,58 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self):
|
1989 | 1992 | await con.close()
|
1990 | 1993 |
|
1991 | 1994 |
|
| 1995 | +class TestMisbehavingServer(tb.TestCase): |
| 1996 | + """Tests for client connection behaviour given a misbehaving server.""" |
| 1997 | + |
| 1998 | + async def test_tls_upgrade_extra_data_received(self): |
| 1999 | + data = [ |
| 2000 | + # First, the server writes b"S" to signal it is willing to perform |
| 2001 | + # SSL |
| 2002 | + b"S", |
| 2003 | + # Then, the server writes an unsolicted arbitrary byte afterwards |
| 2004 | + b"N", |
| 2005 | + ] |
| 2006 | + data_received_events = [asyncio.Event() for _ in data] |
| 2007 | + |
| 2008 | + # Patch out the loop's create_connection so we can instrument the proto |
| 2009 | + # we return. |
| 2010 | + old_create_conn = self.loop.create_connection |
| 2011 | + |
| 2012 | + async def _mock_create_conn(*args, **kwargs): |
| 2013 | + transport, proto = await old_create_conn(*args, **kwargs) |
| 2014 | + old_data_received = proto.data_received |
| 2015 | + |
| 2016 | + num_received = 0 |
| 2017 | + |
| 2018 | + def _data_received(*args, **kwargs): |
| 2019 | + nonlocal num_received |
| 2020 | + # Call the original data_received method |
| 2021 | + ret = old_data_received(*args, **kwargs) |
| 2022 | + # Fire the event to signal we've received this datum now. |
| 2023 | + data_received_events[num_received].set() |
| 2024 | + num_received += 1 |
| 2025 | + return ret |
| 2026 | + |
| 2027 | + proto.data_received = _data_received |
| 2028 | + |
| 2029 | + # To deterministically provoke the race we're interested in for |
| 2030 | + # this regression test, wait for all data to be received before |
| 2031 | + # returning from create_connection(). |
| 2032 | + await data_received_events[-1].wait() |
| 2033 | + return transport, proto |
| 2034 | + |
| 2035 | + server = tb.InstrumentedServer(data, data_received_events) |
| 2036 | + conn_spec = await server.start() |
| 2037 | + |
| 2038 | + # The call to connect() should raise a ConnectionResetError as the |
| 2039 | + # server will close the connection after writing all the data. |
| 2040 | + with (mock.patch.object(self.loop, "create_connection", side_effect=_mock_create_conn), |
| 2041 | + self.assertRaises(ConnectionResetError)): |
| 2042 | + await pg_connection.connect(**conn_spec, ssl=True, loop=self.loop) |
| 2043 | + |
| 2044 | + server.stop() |
| 2045 | + |
| 2046 | + |
1992 | 2047 | def _get_connected_host(con):
|
1993 | 2048 | peername = con._transport.get_extra_info('peername')
|
1994 | 2049 | if isinstance(peername, tuple):
|
|
0 commit comments