Skip to content

Commit 686cd1f

Browse files
committed
Improve mocking of transport or socket layer.
1 parent f8ed666 commit 686cd1f

File tree

2 files changed

+24
-30
lines changed

2 files changed

+24
-30
lines changed

tests/asyncio/test_connection.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,21 +1118,28 @@ async def test_keepalive_reports_errors(self):
11181118

11191119
async def test_close_timeout(self):
11201120
"""close_timeout parameter configures close timeout."""
1121-
connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS)
1121+
connection = Connection(
1122+
Protocol(self.LOCAL),
1123+
close_timeout=42 * MS,
1124+
)
11221125
self.assertEqual(connection.close_timeout, 42 * MS)
11231126

11241127
async def test_max_queue(self):
11251128
"""max_queue configures high-water mark of frames buffer."""
1126-
connection = Connection(Protocol(self.LOCAL), max_queue=4)
1127-
transport = Mock()
1128-
connection.connection_made(transport)
1129+
connection = Connection(
1130+
Protocol(self.LOCAL),
1131+
max_queue=4,
1132+
)
1133+
connection.connection_made(Mock(spec=asyncio.Transport))
11291134
self.assertEqual(connection.recv_messages.high, 4)
11301135

11311136
async def test_max_queue_none(self):
11321137
"""max_queue disables high-water mark of frames buffer."""
1133-
connection = Connection(Protocol(self.LOCAL), max_queue=None)
1134-
transport = Mock()
1135-
connection.connection_made(transport)
1138+
connection = Connection(
1139+
Protocol(self.LOCAL),
1140+
max_queue=None,
1141+
)
1142+
connection.connection_made(Mock(spec=asyncio.Transport))
11361143
self.assertEqual(connection.recv_messages.high, None)
11371144
self.assertEqual(connection.recv_messages.low, None)
11381145

@@ -1142,8 +1149,7 @@ async def test_max_queue_tuple(self):
11421149
Protocol(self.LOCAL),
11431150
max_queue=(4, 2),
11441151
)
1145-
transport = Mock()
1146-
connection.connection_made(transport)
1152+
connection.connection_made(Mock(spec=asyncio.Transport))
11471153
self.assertEqual(connection.recv_messages.high, 4)
11481154
self.assertEqual(connection.recv_messages.low, 2)
11491155

@@ -1153,7 +1159,7 @@ async def test_write_limit(self):
11531159
Protocol(self.LOCAL),
11541160
write_limit=4096,
11551161
)
1156-
transport = Mock()
1162+
transport = Mock(spec=asyncio.Transport)
11571163
connection.connection_made(transport)
11581164
transport.set_write_buffer_limits.assert_called_once_with(4096, None)
11591165

@@ -1163,7 +1169,7 @@ async def test_write_limits(self):
11631169
Protocol(self.LOCAL),
11641170
write_limit=(4096, 2048),
11651171
)
1166-
transport = Mock()
1172+
transport = Mock(spec=asyncio.Transport)
11671173
connection.connection_made(transport)
11681174
transport.set_write_buffer_limits.assert_called_once_with(4096, 2048)
11691175

@@ -1177,13 +1183,13 @@ async def test_logger(self):
11771183
"""Connection has a logger attribute."""
11781184
self.assertIsInstance(self.connection.logger, logging.LoggerAdapter)
11791185

1180-
@patch("asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234))
1186+
@patch("asyncio.Transport.get_extra_info", return_value=("sock", 1234))
11811187
async def test_local_address(self, get_extra_info):
11821188
"""Connection provides a local_address attribute."""
11831189
self.assertEqual(self.connection.local_address, ("sock", 1234))
11841190
get_extra_info.assert_called_with("sockname")
11851191

1186-
@patch("asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234))
1192+
@patch("asyncio.Transport.get_extra_info", return_value=("peer", 1234))
11871193
async def test_remote_address(self, get_extra_info):
11881194
"""Connection provides a remote_address attribute."""
11891195
self.assertEqual(self.connection.remote_address, ("peer", 1234))

tests/sync/test_connection.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import time
88
import unittest
99
import uuid
10-
from unittest.mock import patch
10+
from unittest.mock import Mock, patch
1111

1212
from websockets.exceptions import (
1313
ConcurrencyError,
@@ -853,35 +853,26 @@ def test_keepalive_reports_errors(self):
853853

854854
def test_close_timeout(self):
855855
"""close_timeout parameter configures close timeout."""
856-
socket_, remote_socket = socket.socketpair()
857-
self.addCleanup(socket_.close)
858-
self.addCleanup(remote_socket.close)
859856
connection = Connection(
860-
socket_,
857+
Mock(spec=socket.socket),
861858
Protocol(self.LOCAL),
862859
close_timeout=42 * MS,
863860
)
864861
self.assertEqual(connection.close_timeout, 42 * MS)
865862

866863
def test_max_queue(self):
867864
"""max_queue configures high-water mark of frames buffer."""
868-
socket_, remote_socket = socket.socketpair()
869-
self.addCleanup(socket_.close)
870-
self.addCleanup(remote_socket.close)
871865
connection = Connection(
872-
socket_,
866+
Mock(spec=socket.socket),
873867
Protocol(self.LOCAL),
874868
max_queue=4,
875869
)
876870
self.assertEqual(connection.recv_messages.high, 4)
877871

878872
def test_max_queue_none(self):
879873
"""max_queue disables high-water mark of frames buffer."""
880-
socket_, remote_socket = socket.socketpair()
881-
self.addCleanup(socket_.close)
882-
self.addCleanup(remote_socket.close)
883874
connection = Connection(
884-
socket_,
875+
Mock(spec=socket.socket),
885876
Protocol(self.LOCAL),
886877
max_queue=None,
887878
)
@@ -890,11 +881,8 @@ def test_max_queue_none(self):
890881

891882
def test_max_queue_tuple(self):
892883
"""max_queue configures high-water and low-water marks of frames buffer."""
893-
socket_, remote_socket = socket.socketpair()
894-
self.addCleanup(socket_.close)
895-
self.addCleanup(remote_socket.close)
896884
connection = Connection(
897-
socket_,
885+
Mock(spec=socket.socket),
898886
Protocol(self.LOCAL),
899887
max_queue=(4, 2),
900888
)

0 commit comments

Comments
 (0)