@@ -1152,6 +1152,9 @@ class _TestSSL(tb.SSLTestCase):
1152
1152
ONLYCERT = tb ._cert_fullname (__file__ , 'ssl_cert.pem' )
1153
1153
ONLYKEY = tb ._cert_fullname (__file__ , 'ssl_key.pem' )
1154
1154
1155
+ PAYLOAD_SIZE = 1024 * 100
1156
+ TIMEOUT = 60
1157
+
1155
1158
def test_create_server_ssl_1 (self ):
1156
1159
CNT = 0 # number of clients that were successful
1157
1160
TOTAL_CNT = 25 # total number of clients that test will create
@@ -1418,6 +1421,21 @@ async def client(addr):
1418
1421
with self .assertRaises (exc_type ):
1419
1422
self .loop .run_until_complete (client (srv .addr ))
1420
1423
1424
+ def test_start_tls_wrong_args (self ):
1425
+ if self .implementation == 'asyncio' :
1426
+ raise unittest .SkipTest ()
1427
+
1428
+ async def main ():
1429
+ with self .assertRaisesRegex (TypeError , 'SSLContext, got' ):
1430
+ await self .loop .start_tls (None , None , None )
1431
+
1432
+ sslctx = self ._create_server_ssl_context (
1433
+ self .ONLYCERT , self .ONLYKEY )
1434
+ with self .assertRaisesRegex (TypeError , 'is not supported' ):
1435
+ await self .loop .start_tls (None , None , sslctx )
1436
+
1437
+ self .loop .run_until_complete (main ())
1438
+
1421
1439
def test_ssl_handshake_timeout (self ):
1422
1440
if self .implementation == 'asyncio' :
1423
1441
raise unittest .SkipTest ()
@@ -1480,6 +1498,361 @@ def test_ssl_connect_accepted_socket(self):
1480
1498
Test_UV_TCP .test_connect_accepted_socket (
1481
1499
self , server_context , client_context )
1482
1500
1501
+ def test_start_tls_client_corrupted_ssl (self ):
1502
+ if self .implementation == 'asyncio' :
1503
+ raise unittest .SkipTest ()
1504
+
1505
+ self .loop .set_exception_handler (lambda loop , ctx : None )
1506
+
1507
+ sslctx = self ._create_server_ssl_context (self .ONLYCERT , self .ONLYKEY )
1508
+ client_sslctx = self ._create_client_ssl_context ()
1509
+
1510
+ def server (sock ):
1511
+ orig_sock = sock .dup ()
1512
+ try :
1513
+ sock .starttls (
1514
+ sslctx ,
1515
+ server_side = True )
1516
+ sock .sendall (b'A\n ' )
1517
+ sock .recv_all (1 )
1518
+ orig_sock .send (b'please corrupt the SSL connection' )
1519
+ except ssl .SSLError :
1520
+ pass
1521
+ finally :
1522
+ sock .close ()
1523
+ orig_sock .close ()
1524
+
1525
+ async def client (addr ):
1526
+ reader , writer = await asyncio .open_connection (
1527
+ * addr ,
1528
+ ssl = client_sslctx ,
1529
+ server_hostname = '' ,
1530
+ loop = self .loop )
1531
+
1532
+ self .assertEqual (await reader .readline (), b'A\n ' )
1533
+ writer .write (b'B' )
1534
+ with self .assertRaises (ssl .SSLError ):
1535
+ await reader .readline ()
1536
+ writer .close ()
1537
+ return 'OK'
1538
+
1539
+ with self .tcp_server (server ,
1540
+ max_clients = 1 ,
1541
+ backlog = 1 ) as srv :
1542
+
1543
+ res = self .loop .run_until_complete (client (srv .addr ))
1544
+
1545
+ self .assertEqual (res , 'OK' )
1546
+
1547
+ def test_start_tls_client_reg_proto_1 (self ):
1548
+ if self .implementation == 'asyncio' :
1549
+ raise unittest .SkipTest ()
1550
+
1551
+ HELLO_MSG = b'1' * self .PAYLOAD_SIZE
1552
+
1553
+ server_context = self ._create_server_ssl_context (
1554
+ self .ONLYCERT , self .ONLYKEY )
1555
+ client_context = self ._create_client_ssl_context ()
1556
+
1557
+ def serve (sock ):
1558
+ sock .settimeout (self .TIMEOUT )
1559
+
1560
+ data = sock .recv_all (len (HELLO_MSG ))
1561
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1562
+
1563
+ sock .starttls (server_context , server_side = True )
1564
+
1565
+ sock .sendall (b'O' )
1566
+ data = sock .recv_all (len (HELLO_MSG ))
1567
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1568
+
1569
+ sock .shutdown (socket .SHUT_RDWR )
1570
+ sock .close ()
1571
+
1572
+ class ClientProto (asyncio .Protocol ):
1573
+ def __init__ (self , on_data , on_eof ):
1574
+ self .on_data = on_data
1575
+ self .on_eof = on_eof
1576
+ self .con_made_cnt = 0
1577
+
1578
+ def connection_made (proto , tr ):
1579
+ proto .con_made_cnt += 1
1580
+ # Ensure connection_made gets called only once.
1581
+ self .assertEqual (proto .con_made_cnt , 1 )
1582
+
1583
+ def data_received (self , data ):
1584
+ self .on_data .set_result (data )
1585
+
1586
+ def eof_received (self ):
1587
+ self .on_eof .set_result (True )
1588
+
1589
+ async def client (addr ):
1590
+ await asyncio .sleep (0.5 , loop = self .loop )
1591
+
1592
+ on_data = self .loop .create_future ()
1593
+ on_eof = self .loop .create_future ()
1594
+
1595
+ tr , proto = await self .loop .create_connection (
1596
+ lambda : ClientProto (on_data , on_eof ), * addr )
1597
+
1598
+ tr .write (HELLO_MSG )
1599
+ new_tr = await self .loop .start_tls (tr , proto , client_context )
1600
+
1601
+ self .assertEqual (await on_data , b'O' )
1602
+ new_tr .write (HELLO_MSG )
1603
+ await on_eof
1604
+
1605
+ new_tr .close ()
1606
+
1607
+ with self .tcp_server (serve , timeout = self .TIMEOUT ) as srv :
1608
+ self .loop .run_until_complete (
1609
+ asyncio .wait_for (client (srv .addr ), loop = self .loop , timeout = 10 ))
1610
+
1611
+ def test_start_tls_client_buf_proto_1 (self ):
1612
+ if self .implementation == 'asyncio' :
1613
+ raise unittest .SkipTest ()
1614
+
1615
+ HELLO_MSG = b'1' * self .PAYLOAD_SIZE
1616
+
1617
+ server_context = self ._create_server_ssl_context (
1618
+ self .ONLYCERT , self .ONLYKEY )
1619
+ client_context = self ._create_client_ssl_context ()
1620
+
1621
+ client_con_made_calls = 0
1622
+
1623
+ def serve (sock ):
1624
+ sock .settimeout (self .TIMEOUT )
1625
+
1626
+ data = sock .recv_all (len (HELLO_MSG ))
1627
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1628
+
1629
+ sock .starttls (server_context , server_side = True )
1630
+
1631
+ sock .sendall (b'O' )
1632
+ data = sock .recv_all (len (HELLO_MSG ))
1633
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1634
+
1635
+ sock .sendall (b'2' )
1636
+ data = sock .recv_all (len (HELLO_MSG ))
1637
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1638
+
1639
+ sock .shutdown (socket .SHUT_RDWR )
1640
+ sock .close ()
1641
+
1642
+ class ClientProtoFirst (asyncio .BaseProtocol ):
1643
+ def __init__ (self , on_data ):
1644
+ self .on_data = on_data
1645
+ self .buf = bytearray (1 )
1646
+
1647
+ def connection_made (self , tr ):
1648
+ nonlocal client_con_made_calls
1649
+ client_con_made_calls += 1
1650
+
1651
+ def get_buffer (self , sizehint ):
1652
+ return self .buf
1653
+
1654
+ def buffer_updated (self , nsize ):
1655
+ assert nsize == 1
1656
+ self .on_data .set_result (bytes (self .buf [:nsize ]))
1657
+
1658
+ def eof_received (self ):
1659
+ pass
1660
+
1661
+ class ClientProtoSecond (asyncio .Protocol ):
1662
+ def __init__ (self , on_data , on_eof ):
1663
+ self .on_data = on_data
1664
+ self .on_eof = on_eof
1665
+ self .con_made_cnt = 0
1666
+
1667
+ def connection_made (self , tr ):
1668
+ nonlocal client_con_made_calls
1669
+ client_con_made_calls += 1
1670
+
1671
+ def data_received (self , data ):
1672
+ self .on_data .set_result (data )
1673
+
1674
+ def eof_received (self ):
1675
+ self .on_eof .set_result (True )
1676
+
1677
+ async def client (addr ):
1678
+ await asyncio .sleep (0.5 , loop = self .loop )
1679
+
1680
+ on_data1 = self .loop .create_future ()
1681
+ on_data2 = self .loop .create_future ()
1682
+ on_eof = self .loop .create_future ()
1683
+
1684
+ tr , proto = await self .loop .create_connection (
1685
+ lambda : ClientProtoFirst (on_data1 ), * addr )
1686
+
1687
+ tr .write (HELLO_MSG )
1688
+ new_tr = await self .loop .start_tls (tr , proto , client_context )
1689
+
1690
+ self .assertEqual (await on_data1 , b'O' )
1691
+ new_tr .write (HELLO_MSG )
1692
+
1693
+ new_tr .set_protocol (ClientProtoSecond (on_data2 , on_eof ))
1694
+ self .assertEqual (await on_data2 , b'2' )
1695
+ new_tr .write (HELLO_MSG )
1696
+ await on_eof
1697
+
1698
+ new_tr .close ()
1699
+
1700
+ # connection_made() should be called only once -- when
1701
+ # we establish connection for the first time. Start TLS
1702
+ # doesn't call connection_made() on application protocols.
1703
+ self .assertEqual (client_con_made_calls , 1 )
1704
+
1705
+ with self .tcp_server (serve , timeout = self .TIMEOUT ) as srv :
1706
+ self .loop .run_until_complete (
1707
+ asyncio .wait_for (client (srv .addr ),
1708
+ loop = self .loop , timeout = self .TIMEOUT ))
1709
+
1710
+ def test_start_tls_slow_client_cancel (self ):
1711
+ if self .implementation == 'asyncio' :
1712
+ raise unittest .SkipTest ()
1713
+
1714
+ HELLO_MSG = b'1' * self .PAYLOAD_SIZE
1715
+
1716
+ client_context = self ._create_client_ssl_context ()
1717
+ server_waits_on_handshake = self .loop .create_future ()
1718
+
1719
+ def serve (sock ):
1720
+ sock .settimeout (self .TIMEOUT )
1721
+
1722
+ data = sock .recv_all (len (HELLO_MSG ))
1723
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1724
+
1725
+ try :
1726
+ self .loop .call_soon_threadsafe (
1727
+ server_waits_on_handshake .set_result , None )
1728
+ data = sock .recv_all (1024 * 1024 )
1729
+ except ConnectionAbortedError :
1730
+ pass
1731
+ finally :
1732
+ sock .close ()
1733
+
1734
+ class ClientProto (asyncio .Protocol ):
1735
+ def __init__ (self , on_data , on_eof ):
1736
+ self .on_data = on_data
1737
+ self .on_eof = on_eof
1738
+ self .con_made_cnt = 0
1739
+
1740
+ def connection_made (proto , tr ):
1741
+ proto .con_made_cnt += 1
1742
+ # Ensure connection_made gets called only once.
1743
+ self .assertEqual (proto .con_made_cnt , 1 )
1744
+
1745
+ def data_received (self , data ):
1746
+ self .on_data .set_result (data )
1747
+
1748
+ def eof_received (self ):
1749
+ self .on_eof .set_result (True )
1750
+
1751
+ async def client (addr ):
1752
+ await asyncio .sleep (0.5 , loop = self .loop )
1753
+
1754
+ on_data = self .loop .create_future ()
1755
+ on_eof = self .loop .create_future ()
1756
+
1757
+ tr , proto = await self .loop .create_connection (
1758
+ lambda : ClientProto (on_data , on_eof ), * addr )
1759
+
1760
+ tr .write (HELLO_MSG )
1761
+
1762
+ await server_waits_on_handshake
1763
+
1764
+ with self .assertRaises (asyncio .TimeoutError ):
1765
+ await asyncio .wait_for (
1766
+ self .loop .start_tls (tr , proto , client_context ),
1767
+ 0.5 ,
1768
+ loop = self .loop )
1769
+
1770
+ with self .tcp_server (serve , timeout = self .TIMEOUT ) as srv :
1771
+ self .loop .run_until_complete (
1772
+ asyncio .wait_for (client (srv .addr ), loop = self .loop , timeout = 10 ))
1773
+
1774
+ def test_start_tls_server_1 (self ):
1775
+ if self .implementation == 'asyncio' :
1776
+ raise unittest .SkipTest ()
1777
+
1778
+ HELLO_MSG = b'1' * self .PAYLOAD_SIZE
1779
+
1780
+ server_context = self ._create_server_ssl_context (
1781
+ self .ONLYCERT , self .ONLYKEY )
1782
+ client_context = self ._create_client_ssl_context ()
1783
+
1784
+ def client (sock , addr ):
1785
+ sock .settimeout (self .TIMEOUT )
1786
+
1787
+ sock .connect (addr )
1788
+ data = sock .recv_all (len (HELLO_MSG ))
1789
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1790
+
1791
+ sock .starttls (client_context )
1792
+ sock .sendall (HELLO_MSG )
1793
+
1794
+ sock .shutdown (socket .SHUT_RDWR )
1795
+ sock .close ()
1796
+
1797
+ class ServerProto (asyncio .Protocol ):
1798
+ def __init__ (self , on_con , on_eof , on_con_lost ):
1799
+ self .on_con = on_con
1800
+ self .on_eof = on_eof
1801
+ self .on_con_lost = on_con_lost
1802
+ self .data = b''
1803
+
1804
+ def connection_made (self , tr ):
1805
+ self .on_con .set_result (tr )
1806
+
1807
+ def data_received (self , data ):
1808
+ self .data += data
1809
+
1810
+ def eof_received (self ):
1811
+ self .on_eof .set_result (1 )
1812
+
1813
+ def connection_lost (self , exc ):
1814
+ if exc is None :
1815
+ self .on_con_lost .set_result (None )
1816
+ else :
1817
+ self .on_con_lost .set_exception (exc )
1818
+
1819
+ async def main (proto , on_con , on_eof , on_con_lost ):
1820
+ tr = await on_con
1821
+ tr .write (HELLO_MSG )
1822
+
1823
+ self .assertEqual (proto .data , b'' )
1824
+
1825
+ new_tr = await self .loop .start_tls (
1826
+ tr , proto , server_context ,
1827
+ server_side = True ,
1828
+ ssl_handshake_timeout = self .TIMEOUT )
1829
+
1830
+ await on_eof
1831
+ await on_con_lost
1832
+ self .assertEqual (proto .data , HELLO_MSG )
1833
+ new_tr .close ()
1834
+
1835
+ async def run_main ():
1836
+ on_con = self .loop .create_future ()
1837
+ on_eof = self .loop .create_future ()
1838
+ on_con_lost = self .loop .create_future ()
1839
+ proto = ServerProto (on_con , on_eof , on_con_lost )
1840
+
1841
+ server = await self .loop .create_server (
1842
+ lambda : proto , '127.0.0.1' , 0 )
1843
+ addr = server .sockets [0 ].getsockname ()
1844
+
1845
+ with self .tcp_client (lambda sock : client (sock , addr ),
1846
+ timeout = self .TIMEOUT ):
1847
+ await asyncio .wait_for (
1848
+ main (proto , on_con , on_eof , on_con_lost ),
1849
+ loop = self .loop , timeout = self .TIMEOUT )
1850
+
1851
+ server .close ()
1852
+ await server .wait_closed ()
1853
+
1854
+ self .loop .run_until_complete (run_main ())
1855
+
1483
1856
1484
1857
class Test_UV_TCPSSL (_TestSSL , tb .UVTestCase ):
1485
1858
pass
0 commit comments