Skip to content

Commit 2e6d793

Browse files
committed
Eliminate race conditions in sync connection tests.
These tests started two threads and relied on the first one going far enough by the time the second one starts. run_in_thread() introduces a delay to ensure the operations happen in the expected order, in addition to making the tests shorter.
1 parent 13c0d99 commit 2e6d793

File tree

1 file changed

+67
-103
lines changed

1 file changed

+67
-103
lines changed

tests/sync/test_connection.py

Lines changed: 67 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
from ..protocol import RecordingProtocol
2121
from ..utils import MS
2222
from .connection import InterceptingConnection
23+
from .utils import ThreadTestCase
2324

2425

2526
# Connection implements symmetrical behavior between clients and servers.
2627
# All tests run on the client side and the server side to validate this.
2728

2829

29-
class ClientConnectionTests(unittest.TestCase):
30+
class ClientConnectionTests(ThreadTestCase):
3031
LOCAL = CLIENT
3132
REMOTE = SERVER
3233

@@ -196,38 +197,28 @@ def test_recv_non_utf8_text(self):
196197

197198
def test_recv_during_recv(self):
198199
"""recv raises ConcurrencyError when called concurrently."""
199-
recv_thread = threading.Thread(target=self.connection.recv)
200-
recv_thread.start()
201-
202-
with self.assertRaises(ConcurrencyError) as raised:
203-
self.connection.recv()
200+
with self.run_in_thread(self.connection.recv):
201+
with self.assertRaises(ConcurrencyError) as raised:
202+
self.connection.recv()
203+
self.remote_connection.send("")
204204
self.assertEqual(
205205
str(raised.exception),
206206
"cannot call recv while another thread "
207207
"is already running recv or recv_streaming",
208208
)
209209

210-
self.remote_connection.send("")
211-
recv_thread.join()
212-
213210
def test_recv_during_recv_streaming(self):
214211
"""recv raises ConcurrencyError when called concurrently with recv_streaming."""
215-
recv_streaming_thread = threading.Thread(
216-
target=lambda: list(self.connection.recv_streaming())
217-
)
218-
recv_streaming_thread.start()
219-
220-
with self.assertRaises(ConcurrencyError) as raised:
221-
self.connection.recv()
212+
with self.run_in_thread(lambda: list(self.connection.recv_streaming())):
213+
with self.assertRaises(ConcurrencyError) as raised:
214+
self.connection.recv()
215+
self.remote_connection.send("")
222216
self.assertEqual(
223217
str(raised.exception),
224218
"cannot call recv while another thread "
225219
"is already running recv or recv_streaming",
226220
)
227221

228-
self.remote_connection.send("")
229-
recv_streaming_thread.join()
230-
231222
# Test recv_streaming.
232223

233224
def test_recv_streaming_text(self):
@@ -305,40 +296,30 @@ def test_recv_streaming_non_utf8_text(self):
305296

306297
def test_recv_streaming_during_recv(self):
307298
"""recv_streaming raises ConcurrencyError when called concurrently with recv."""
308-
recv_thread = threading.Thread(target=self.connection.recv)
309-
recv_thread.start()
310-
311-
with self.assertRaises(ConcurrencyError) as raised:
312-
for _ in self.connection.recv_streaming():
313-
self.fail("did not raise")
299+
with self.run_in_thread(self.connection.recv):
300+
with self.assertRaises(ConcurrencyError) as raised:
301+
for _ in self.connection.recv_streaming():
302+
self.fail("did not raise")
303+
self.remote_connection.send("")
314304
self.assertEqual(
315305
str(raised.exception),
316306
"cannot call recv_streaming while another thread "
317307
"is already running recv or recv_streaming",
318308
)
319309

320-
self.remote_connection.send("")
321-
recv_thread.join()
322-
323310
def test_recv_streaming_during_recv_streaming(self):
324311
"""recv_streaming raises ConcurrencyError when called concurrently."""
325-
recv_streaming_thread = threading.Thread(
326-
target=lambda: list(self.connection.recv_streaming())
327-
)
328-
recv_streaming_thread.start()
329-
330-
with self.assertRaises(ConcurrencyError) as raised:
331-
for _ in self.connection.recv_streaming():
332-
self.fail("did not raise")
312+
with self.run_in_thread(lambda: list(self.connection.recv_streaming())):
313+
with self.assertRaises(ConcurrencyError) as raised:
314+
for _ in self.connection.recv_streaming():
315+
self.fail("did not raise")
316+
self.remote_connection.send("")
333317
self.assertEqual(
334318
str(raised.exception),
335319
r"cannot call recv_streaming while another thread "
336320
r"is already running recv or recv_streaming",
337321
)
338322

339-
self.remote_connection.send("")
340-
recv_streaming_thread.join()
341-
342323
# Test send.
343324

344325
def test_send_text(self):
@@ -411,43 +392,40 @@ def test_send_connection_closed_error(self):
411392

412393
def test_send_during_send(self):
413394
"""send raises ConcurrencyError when called concurrently."""
414-
recv_thread = threading.Thread(target=self.remote_connection.recv)
415-
recv_thread.start()
416-
417-
send_gate = threading.Event()
418-
exit_gate = threading.Event()
419-
420-
def fragments():
421-
yield "😀"
422-
send_gate.set()
423-
exit_gate.wait()
424-
yield "😀"
395+
with self.run_in_thread(self.remote_connection.recv):
396+
send_gate = threading.Event()
397+
exit_gate = threading.Event()
398+
399+
def fragments():
400+
yield "😀"
401+
send_gate.set()
402+
exit_gate.wait()
403+
yield "😀"
404+
405+
send_thread = threading.Thread(
406+
target=self.connection.send,
407+
args=(fragments(),),
408+
)
409+
send_thread.start()
410+
411+
send_gate.wait()
412+
# The check happens in four code paths, depending on the argument.
413+
for message in [
414+
"😀",
415+
b"\x01\x02\xfe\xff",
416+
["😀", "😀"],
417+
[b"\x01\x02", b"\xfe\xff"],
418+
]:
419+
with self.subTest(message=message):
420+
with self.assertRaises(ConcurrencyError) as raised:
421+
self.connection.send(message)
422+
self.assertEqual(
423+
str(raised.exception),
424+
"cannot call send while another thread is already running send",
425+
)
425426

426-
send_thread = threading.Thread(
427-
target=self.connection.send,
428-
args=(fragments(),),
429-
)
430-
send_thread.start()
431-
432-
send_gate.wait()
433-
# The check happens in four code paths, depending on the argument.
434-
for message in [
435-
"😀",
436-
b"\x01\x02\xfe\xff",
437-
["😀", "😀"],
438-
[b"\x01\x02", b"\xfe\xff"],
439-
]:
440-
with self.subTest(message=message):
441-
with self.assertRaises(ConcurrencyError) as raised:
442-
self.connection.send(message)
443-
self.assertEqual(
444-
str(raised.exception),
445-
"cannot call send while another thread is already running send",
446-
)
447-
448-
exit_gate.set()
449-
send_thread.join()
450-
recv_thread.join()
427+
exit_gate.set()
428+
send_thread.join()
451429

452430
def test_send_empty_iterable(self):
453431
"""send does nothing when called with an empty iterable."""
@@ -571,45 +549,31 @@ def closer():
571549
with self.delay_frames_rcvd(4 * MS):
572550
self.connection.close()
573551

574-
close_thread = threading.Thread(target=closer)
575-
close_thread.start()
576-
577-
# Let closer() initiate the closing handshake and send a close frame.
578-
time.sleep(MS)
579-
self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8"))
552+
with self.run_in_thread(closer):
553+
# run_in_thread() waits for MS, which lets closer() send a close frame.
554+
self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8"))
580555

581-
# Connection isn't closed yet.
582-
with self.assertRaises(TimeoutError):
583-
self.connection.recv(timeout=MS)
556+
# Connection isn't closed yet.
557+
with self.assertRaises(TimeoutError):
558+
self.connection.recv(timeout=MS)
584559

585-
self.connection.close()
586-
self.assertNoFrameSent()
587-
588-
# Connection is closed now.
589-
with self.assertRaises(ConnectionClosedOK):
590-
self.connection.recv(timeout=MS)
560+
self.connection.close()
561+
self.assertNoFrameSent()
591562

592-
close_thread.join()
563+
# Connection is closed now.
564+
with self.assertRaises(ConnectionClosedOK):
565+
self.connection.recv(timeout=MS)
593566

594567
def test_close_during_recv(self):
595568
"""close aborts recv when called concurrently with recv."""
596-
597-
def closer():
598-
time.sleep(MS)
599-
self.connection.close()
600-
601-
close_thread = threading.Thread(target=closer)
602-
close_thread.start()
603-
604-
with self.assertRaises(ConnectionClosedOK) as raised:
605-
self.connection.recv()
569+
with self.run_in_thread(self.connection.close):
570+
with self.assertRaises(ConnectionClosedOK) as raised:
571+
self.connection.recv()
606572

607573
exc = raised.exception
608574
self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)")
609575
self.assertIsNone(exc.__cause__)
610576

611-
close_thread.join()
612-
613577
def test_close_during_send(self):
614578
"""close fails the connection when called concurrently with send."""
615579
close_gate = threading.Event()

0 commit comments

Comments
 (0)