Skip to content

Commit 622ed9c

Browse files
committed
Implement loop.start_tls()
Side change: no longer defer "start_reading()" call after "connection_made()". The reading should start synchronously to copy asyncio behaviour. The race condition in sslproto.py that prompted that change has been fixed.
1 parent eb2afa6 commit 622ed9c

File tree

4 files changed

+456
-24
lines changed

4 files changed

+456
-24
lines changed

tests/test_tcp.py

+373
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,9 @@ class _TestSSL(tb.SSLTestCase):
11521152
ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
11531153
ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')
11541154

1155+
PAYLOAD_SIZE = 1024 * 100
1156+
TIMEOUT = 60
1157+
11551158
def test_create_server_ssl_1(self):
11561159
CNT = 0 # number of clients that were successful
11571160
TOTAL_CNT = 25 # total number of clients that test will create
@@ -1418,6 +1421,21 @@ async def client(addr):
14181421
with self.assertRaises(exc_type):
14191422
self.loop.run_until_complete(client(srv.addr))
14201423

1424+
def test_start_tls_wrong_args(self):
1425+
if self.implementation == 'asyncio':
1426+
raise unittest.SkipTest()
1427+
1428+
async def main():
1429+
with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
1430+
await self.loop.start_tls(None, None, None)
1431+
1432+
sslctx = self._create_server_ssl_context(
1433+
self.ONLYCERT, self.ONLYKEY)
1434+
with self.assertRaisesRegex(TypeError, 'is not supported'):
1435+
await self.loop.start_tls(None, None, sslctx)
1436+
1437+
self.loop.run_until_complete(main())
1438+
14211439
def test_ssl_handshake_timeout(self):
14221440
if self.implementation == 'asyncio':
14231441
raise unittest.SkipTest()
@@ -1480,6 +1498,361 @@ def test_ssl_connect_accepted_socket(self):
14801498
Test_UV_TCP.test_connect_accepted_socket(
14811499
self, server_context, client_context)
14821500

1501+
def test_start_tls_client_corrupted_ssl(self):
1502+
if self.implementation == 'asyncio':
1503+
raise unittest.SkipTest()
1504+
1505+
self.loop.set_exception_handler(lambda loop, ctx: None)
1506+
1507+
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
1508+
client_sslctx = self._create_client_ssl_context()
1509+
1510+
def server(sock):
1511+
orig_sock = sock.dup()
1512+
try:
1513+
sock.starttls(
1514+
sslctx,
1515+
server_side=True)
1516+
sock.sendall(b'A\n')
1517+
sock.recv_all(1)
1518+
orig_sock.send(b'please corrupt the SSL connection')
1519+
except ssl.SSLError:
1520+
pass
1521+
finally:
1522+
sock.close()
1523+
orig_sock.close()
1524+
1525+
async def client(addr):
1526+
reader, writer = await asyncio.open_connection(
1527+
*addr,
1528+
ssl=client_sslctx,
1529+
server_hostname='',
1530+
loop=self.loop)
1531+
1532+
self.assertEqual(await reader.readline(), b'A\n')
1533+
writer.write(b'B')
1534+
with self.assertRaises(ssl.SSLError):
1535+
await reader.readline()
1536+
writer.close()
1537+
return 'OK'
1538+
1539+
with self.tcp_server(server,
1540+
max_clients=1,
1541+
backlog=1) as srv:
1542+
1543+
res = self.loop.run_until_complete(client(srv.addr))
1544+
1545+
self.assertEqual(res, 'OK')
1546+
1547+
def test_start_tls_client_reg_proto_1(self):
1548+
if self.implementation == 'asyncio':
1549+
raise unittest.SkipTest()
1550+
1551+
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1552+
1553+
server_context = self._create_server_ssl_context(
1554+
self.ONLYCERT, self.ONLYKEY)
1555+
client_context = self._create_client_ssl_context()
1556+
1557+
def serve(sock):
1558+
sock.settimeout(self.TIMEOUT)
1559+
1560+
data = sock.recv_all(len(HELLO_MSG))
1561+
self.assertEqual(len(data), len(HELLO_MSG))
1562+
1563+
sock.starttls(server_context, server_side=True)
1564+
1565+
sock.sendall(b'O')
1566+
data = sock.recv_all(len(HELLO_MSG))
1567+
self.assertEqual(len(data), len(HELLO_MSG))
1568+
1569+
sock.shutdown(socket.SHUT_RDWR)
1570+
sock.close()
1571+
1572+
class ClientProto(asyncio.Protocol):
1573+
def __init__(self, on_data, on_eof):
1574+
self.on_data = on_data
1575+
self.on_eof = on_eof
1576+
self.con_made_cnt = 0
1577+
1578+
def connection_made(proto, tr):
1579+
proto.con_made_cnt += 1
1580+
# Ensure connection_made gets called only once.
1581+
self.assertEqual(proto.con_made_cnt, 1)
1582+
1583+
def data_received(self, data):
1584+
self.on_data.set_result(data)
1585+
1586+
def eof_received(self):
1587+
self.on_eof.set_result(True)
1588+
1589+
async def client(addr):
1590+
await asyncio.sleep(0.5, loop=self.loop)
1591+
1592+
on_data = self.loop.create_future()
1593+
on_eof = self.loop.create_future()
1594+
1595+
tr, proto = await self.loop.create_connection(
1596+
lambda: ClientProto(on_data, on_eof), *addr)
1597+
1598+
tr.write(HELLO_MSG)
1599+
new_tr = await self.loop.start_tls(tr, proto, client_context)
1600+
1601+
self.assertEqual(await on_data, b'O')
1602+
new_tr.write(HELLO_MSG)
1603+
await on_eof
1604+
1605+
new_tr.close()
1606+
1607+
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
1608+
self.loop.run_until_complete(
1609+
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
1610+
1611+
def test_start_tls_client_buf_proto_1(self):
1612+
if self.implementation == 'asyncio':
1613+
raise unittest.SkipTest()
1614+
1615+
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1616+
1617+
server_context = self._create_server_ssl_context(
1618+
self.ONLYCERT, self.ONLYKEY)
1619+
client_context = self._create_client_ssl_context()
1620+
1621+
client_con_made_calls = 0
1622+
1623+
def serve(sock):
1624+
sock.settimeout(self.TIMEOUT)
1625+
1626+
data = sock.recv_all(len(HELLO_MSG))
1627+
self.assertEqual(len(data), len(HELLO_MSG))
1628+
1629+
sock.starttls(server_context, server_side=True)
1630+
1631+
sock.sendall(b'O')
1632+
data = sock.recv_all(len(HELLO_MSG))
1633+
self.assertEqual(len(data), len(HELLO_MSG))
1634+
1635+
sock.sendall(b'2')
1636+
data = sock.recv_all(len(HELLO_MSG))
1637+
self.assertEqual(len(data), len(HELLO_MSG))
1638+
1639+
sock.shutdown(socket.SHUT_RDWR)
1640+
sock.close()
1641+
1642+
class ClientProtoFirst(asyncio.BaseProtocol):
1643+
def __init__(self, on_data):
1644+
self.on_data = on_data
1645+
self.buf = bytearray(1)
1646+
1647+
def connection_made(self, tr):
1648+
nonlocal client_con_made_calls
1649+
client_con_made_calls += 1
1650+
1651+
def get_buffer(self, sizehint):
1652+
return self.buf
1653+
1654+
def buffer_updated(self, nsize):
1655+
assert nsize == 1
1656+
self.on_data.set_result(bytes(self.buf[:nsize]))
1657+
1658+
def eof_received(self):
1659+
pass
1660+
1661+
class ClientProtoSecond(asyncio.Protocol):
1662+
def __init__(self, on_data, on_eof):
1663+
self.on_data = on_data
1664+
self.on_eof = on_eof
1665+
self.con_made_cnt = 0
1666+
1667+
def connection_made(self, tr):
1668+
nonlocal client_con_made_calls
1669+
client_con_made_calls += 1
1670+
1671+
def data_received(self, data):
1672+
self.on_data.set_result(data)
1673+
1674+
def eof_received(self):
1675+
self.on_eof.set_result(True)
1676+
1677+
async def client(addr):
1678+
await asyncio.sleep(0.5, loop=self.loop)
1679+
1680+
on_data1 = self.loop.create_future()
1681+
on_data2 = self.loop.create_future()
1682+
on_eof = self.loop.create_future()
1683+
1684+
tr, proto = await self.loop.create_connection(
1685+
lambda: ClientProtoFirst(on_data1), *addr)
1686+
1687+
tr.write(HELLO_MSG)
1688+
new_tr = await self.loop.start_tls(tr, proto, client_context)
1689+
1690+
self.assertEqual(await on_data1, b'O')
1691+
new_tr.write(HELLO_MSG)
1692+
1693+
new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
1694+
self.assertEqual(await on_data2, b'2')
1695+
new_tr.write(HELLO_MSG)
1696+
await on_eof
1697+
1698+
new_tr.close()
1699+
1700+
# connection_made() should be called only once -- when
1701+
# we establish connection for the first time. Start TLS
1702+
# doesn't call connection_made() on application protocols.
1703+
self.assertEqual(client_con_made_calls, 1)
1704+
1705+
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
1706+
self.loop.run_until_complete(
1707+
asyncio.wait_for(client(srv.addr),
1708+
loop=self.loop, timeout=self.TIMEOUT))
1709+
1710+
def test_start_tls_slow_client_cancel(self):
1711+
if self.implementation == 'asyncio':
1712+
raise unittest.SkipTest()
1713+
1714+
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1715+
1716+
client_context = self._create_client_ssl_context()
1717+
server_waits_on_handshake = self.loop.create_future()
1718+
1719+
def serve(sock):
1720+
sock.settimeout(self.TIMEOUT)
1721+
1722+
data = sock.recv_all(len(HELLO_MSG))
1723+
self.assertEqual(len(data), len(HELLO_MSG))
1724+
1725+
try:
1726+
self.loop.call_soon_threadsafe(
1727+
server_waits_on_handshake.set_result, None)
1728+
data = sock.recv_all(1024 * 1024)
1729+
except ConnectionAbortedError:
1730+
pass
1731+
finally:
1732+
sock.close()
1733+
1734+
class ClientProto(asyncio.Protocol):
1735+
def __init__(self, on_data, on_eof):
1736+
self.on_data = on_data
1737+
self.on_eof = on_eof
1738+
self.con_made_cnt = 0
1739+
1740+
def connection_made(proto, tr):
1741+
proto.con_made_cnt += 1
1742+
# Ensure connection_made gets called only once.
1743+
self.assertEqual(proto.con_made_cnt, 1)
1744+
1745+
def data_received(self, data):
1746+
self.on_data.set_result(data)
1747+
1748+
def eof_received(self):
1749+
self.on_eof.set_result(True)
1750+
1751+
async def client(addr):
1752+
await asyncio.sleep(0.5, loop=self.loop)
1753+
1754+
on_data = self.loop.create_future()
1755+
on_eof = self.loop.create_future()
1756+
1757+
tr, proto = await self.loop.create_connection(
1758+
lambda: ClientProto(on_data, on_eof), *addr)
1759+
1760+
tr.write(HELLO_MSG)
1761+
1762+
await server_waits_on_handshake
1763+
1764+
with self.assertRaises(asyncio.TimeoutError):
1765+
await asyncio.wait_for(
1766+
self.loop.start_tls(tr, proto, client_context),
1767+
0.5,
1768+
loop=self.loop)
1769+
1770+
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
1771+
self.loop.run_until_complete(
1772+
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
1773+
1774+
def test_start_tls_server_1(self):
1775+
if self.implementation == 'asyncio':
1776+
raise unittest.SkipTest()
1777+
1778+
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1779+
1780+
server_context = self._create_server_ssl_context(
1781+
self.ONLYCERT, self.ONLYKEY)
1782+
client_context = self._create_client_ssl_context()
1783+
1784+
def client(sock, addr):
1785+
sock.settimeout(self.TIMEOUT)
1786+
1787+
sock.connect(addr)
1788+
data = sock.recv_all(len(HELLO_MSG))
1789+
self.assertEqual(len(data), len(HELLO_MSG))
1790+
1791+
sock.starttls(client_context)
1792+
sock.sendall(HELLO_MSG)
1793+
1794+
sock.shutdown(socket.SHUT_RDWR)
1795+
sock.close()
1796+
1797+
class ServerProto(asyncio.Protocol):
1798+
def __init__(self, on_con, on_eof, on_con_lost):
1799+
self.on_con = on_con
1800+
self.on_eof = on_eof
1801+
self.on_con_lost = on_con_lost
1802+
self.data = b''
1803+
1804+
def connection_made(self, tr):
1805+
self.on_con.set_result(tr)
1806+
1807+
def data_received(self, data):
1808+
self.data += data
1809+
1810+
def eof_received(self):
1811+
self.on_eof.set_result(1)
1812+
1813+
def connection_lost(self, exc):
1814+
if exc is None:
1815+
self.on_con_lost.set_result(None)
1816+
else:
1817+
self.on_con_lost.set_exception(exc)
1818+
1819+
async def main(proto, on_con, on_eof, on_con_lost):
1820+
tr = await on_con
1821+
tr.write(HELLO_MSG)
1822+
1823+
self.assertEqual(proto.data, b'')
1824+
1825+
new_tr = await self.loop.start_tls(
1826+
tr, proto, server_context,
1827+
server_side=True,
1828+
ssl_handshake_timeout=self.TIMEOUT)
1829+
1830+
await on_eof
1831+
await on_con_lost
1832+
self.assertEqual(proto.data, HELLO_MSG)
1833+
new_tr.close()
1834+
1835+
async def run_main():
1836+
on_con = self.loop.create_future()
1837+
on_eof = self.loop.create_future()
1838+
on_con_lost = self.loop.create_future()
1839+
proto = ServerProto(on_con, on_eof, on_con_lost)
1840+
1841+
server = await self.loop.create_server(
1842+
lambda: proto, '127.0.0.1', 0)
1843+
addr = server.sockets[0].getsockname()
1844+
1845+
with self.tcp_client(lambda sock: client(sock, addr),
1846+
timeout=self.TIMEOUT):
1847+
await asyncio.wait_for(
1848+
main(proto, on_con, on_eof, on_con_lost),
1849+
loop=self.loop, timeout=self.TIMEOUT)
1850+
1851+
server.close()
1852+
await server.wait_closed()
1853+
1854+
self.loop.run_until_complete(run_main())
1855+
14831856

14841857
class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
14851858
pass

0 commit comments

Comments
 (0)