Skip to content

Commit fc1f2cf

Browse files
committed
Clean up the asyncio and sync connection tests.
1 parent 46ad4bb commit fc1f2cf

File tree

2 files changed

+140
-124
lines changed

2 files changed

+140
-124
lines changed

tests/asyncio/test_connection.py

Lines changed: 92 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ async def test_aexit(self):
112112
await self.assertNoFrameSent()
113113
await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8"))
114114

115-
async def test_exit_with_exception(self):
116-
"""__exit__ with an exception closes the connection with code 1011."""
115+
async def test_aexit_with_exception(self):
116+
"""__aexit__ with an exception closes the connection with code 1011."""
117117
with self.assertRaises(RuntimeError):
118118
async with self.connection:
119119
raise RuntimeError
@@ -248,10 +248,9 @@ async def test_recv_during_recv_streaming(self):
248248
)
249249

250250
async def test_recv_cancellation_before_receiving(self):
251-
"""recv can be canceled before receiving a frame."""
251+
"""recv can be canceled before receiving a message."""
252252
recv_task = asyncio.create_task(self.connection.recv())
253253
await asyncio.sleep(0) # let the event loop start recv_task
254-
255254
recv_task.cancel()
256255
await asyncio.sleep(0) # let the event loop cancel recv_task
257256

@@ -260,25 +259,25 @@ async def test_recv_cancellation_before_receiving(self):
260259
self.assertEqual(await self.connection.recv(), "😀")
261260

262261
async def test_recv_cancellation_while_receiving(self):
263-
"""recv cannot be canceled after receiving a frame."""
264-
recv_task = asyncio.create_task(self.connection.recv())
265-
await asyncio.sleep(0) # let the event loop start recv_task
266-
267-
gate = asyncio.get_running_loop().create_future()
262+
"""recv can be canceled while receiving a fragmented message."""
263+
gate = asyncio.Event()
268264

269265
async def fragments():
270266
yield "⏳"
271-
await gate
267+
await gate.wait()
272268
yield "⌛️"
273269

274270
asyncio.create_task(self.remote_connection.send(fragments()))
275-
await asyncio.sleep(MS)
271+
await asyncio.sleep(0)
276272

273+
recv_task = asyncio.create_task(self.connection.recv())
274+
await asyncio.sleep(0) # let the event loop start recv_task
277275
recv_task.cancel()
278276
await asyncio.sleep(0) # let the event loop cancel recv_task
279277

278+
gate.set()
279+
280280
# Running recv again receives the complete message.
281-
gate.set_result(None)
282281
self.assertEqual(await self.connection.recv(), "⏳⌛️")
283282

284283
# Test recv_streaming.
@@ -407,28 +406,31 @@ async def test_recv_streaming_cancellation_before_receiving(self):
407406

408407
async def test_recv_streaming_cancellation_while_receiving(self):
409408
"""recv_streaming cannot be canceled after receiving a frame."""
410-
recv_streaming_task = asyncio.create_task(
411-
alist(self.connection.recv_streaming())
412-
)
413-
await asyncio.sleep(0) # let the event loop start recv_streaming_task
414-
415-
gate = asyncio.get_running_loop().create_future()
409+
gate = asyncio.Event()
416410

417411
async def fragments():
418412
yield "⏳"
419-
await gate
413+
await gate.wait()
420414
yield "⌛️"
421415

422416
asyncio.create_task(self.remote_connection.send(fragments()))
423-
await asyncio.sleep(MS)
417+
await asyncio.sleep(0)
424418

419+
recv_streaming_task = asyncio.create_task(
420+
alist(self.connection.recv_streaming())
421+
)
422+
await asyncio.sleep(0) # let the event loop start recv_streaming_task
423+
await asyncio.sleep(0) # experimentally, two runs of the event loop
424+
await asyncio.sleep(0) # are needed to receive the first fragment
425425
recv_streaming_task.cancel()
426426
await asyncio.sleep(0) # let the event loop cancel recv_streaming_task
427427

428-
gate.set_result(None)
428+
gate.set()
429+
429430
# Running recv_streaming again fails.
430431
with self.assertRaises(ConcurrencyError):
431-
await alist(self.connection.recv_streaming())
432+
async for _ in self.connection.recv_streaming():
433+
self.fail("did not raise")
432434

433435
# Test send.
434436

@@ -556,23 +558,31 @@ async def test_send_connection_closed_error(self):
556558
with self.assertRaises(ConnectionClosedError):
557559
await self.connection.send("😀")
558560

559-
async def test_send_while_send_blocked(self):
561+
async def test_send_during_send(self):
560562
"""send waits for a previous call to send to complete."""
561563
# This test fails if the guard with send_in_progress is removed
562-
# from send() in the case when message is an Iterable.
563-
self.connection.pause_writing()
564-
asyncio.create_task(self.connection.send(["⏳", "⌛️"]))
565-
await asyncio.sleep(MS)
564+
# from send() in the case when message is an AsyncIterable.
565+
gate = asyncio.Event()
566+
567+
async def fragments():
568+
yield "⏳"
569+
await gate.wait()
570+
yield "⌛️"
571+
572+
asyncio.create_task(self.connection.send(fragments()))
573+
await asyncio.sleep(0) # let the event loop start the task
566574
await self.assertFrameSent(
567575
Frame(Opcode.TEXT, "⏳".encode(), fin=False),
568576
)
569577

570578
asyncio.create_task(self.connection.send("✅"))
571-
await asyncio.sleep(MS)
579+
await asyncio.sleep(0) # let the event loop start the task
572580
await self.assertNoFrameSent()
573581

574-
self.connection.resume_writing()
575-
await asyncio.sleep(MS)
582+
gate.set()
583+
await asyncio.sleep(0) # run the event loop
584+
await asyncio.sleep(0) # three times in order
585+
await asyncio.sleep(0) # to send three frames
576586
await self.assertFramesSent(
577587
[
578588
Frame(Opcode.CONT, "⌛️".encode(), fin=False),
@@ -581,28 +591,26 @@ async def test_send_while_send_blocked(self):
581591
]
582592
)
583593

584-
async def test_send_while_send_async_blocked(self):
585-
"""send waits for a previous call to send to complete."""
594+
async def test_send_while_send_blocked(self):
595+
"""send waits for a blocked call to send to complete."""
586596
# This test fails if the guard with send_in_progress is removed
587-
# from send() in the case when message is an AsyncIterable.
597+
# from send() in the case when message is an Iterable.
588598
self.connection.pause_writing()
589599

590-
async def fragments():
591-
yield "⏳"
592-
yield "⌛️"
593-
594-
asyncio.create_task(self.connection.send(fragments()))
595-
await asyncio.sleep(MS)
600+
asyncio.create_task(self.connection.send(["⏳", "⌛️"]))
601+
await asyncio.sleep(0) # let the event loop start the task
596602
await self.assertFrameSent(
597603
Frame(Opcode.TEXT, "⏳".encode(), fin=False),
598604
)
599605

600606
asyncio.create_task(self.connection.send("✅"))
601-
await asyncio.sleep(MS)
607+
await asyncio.sleep(0) # let the event loop start the task
602608
await self.assertNoFrameSent()
603609

604610
self.connection.resume_writing()
605-
await asyncio.sleep(MS)
611+
await asyncio.sleep(0) # run the event loop
612+
await asyncio.sleep(0) # three times in order
613+
await asyncio.sleep(0) # to send three frames
606614
await self.assertFramesSent(
607615
[
608616
Frame(Opcode.CONT, "⌛️".encode(), fin=False),
@@ -611,29 +619,30 @@ async def fragments():
611619
]
612620
)
613621

614-
async def test_send_during_send_async(self):
615-
"""send waits for a previous call to send to complete."""
622+
async def test_send_while_send_async_blocked(self):
623+
"""send waits for a blocked call to send to complete."""
616624
# This test fails if the guard with send_in_progress is removed
617625
# from send() in the case when message is an AsyncIterable.
618-
gate = asyncio.get_running_loop().create_future()
626+
self.connection.pause_writing()
619627

620628
async def fragments():
621629
yield "⏳"
622-
await gate
623630
yield "⌛️"
624631

625632
asyncio.create_task(self.connection.send(fragments()))
626-
await asyncio.sleep(MS)
633+
await asyncio.sleep(0) # let the event loop start the task
627634
await self.assertFrameSent(
628635
Frame(Opcode.TEXT, "⏳".encode(), fin=False),
629636
)
630637

631638
asyncio.create_task(self.connection.send("✅"))
632-
await asyncio.sleep(MS)
639+
await asyncio.sleep(0) # let the event loop start the task
633640
await self.assertNoFrameSent()
634641

635-
gate.set_result(None)
636-
await asyncio.sleep(MS)
642+
self.connection.resume_writing()
643+
await asyncio.sleep(0) # run the event loop
644+
await asyncio.sleep(0) # three times in order
645+
await asyncio.sleep(0) # to send three frames
637646
await self.assertFramesSent(
638647
[
639648
Frame(Opcode.CONT, "⌛️".encode(), fin=False),
@@ -837,13 +846,9 @@ async def test_close_preserves_queued_messages(self):
837846
await self.connection.close()
838847

839848
self.assertEqual(await self.connection.recv(), "😀")
840-
with self.assertRaises(ConnectionClosedOK) as raised:
849+
with self.assertRaises(ConnectionClosedOK):
841850
await self.connection.recv()
842851

843-
exc = raised.exception
844-
self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)")
845-
self.assertIsNone(exc.__cause__)
846-
847852
async def test_close_idempotency(self):
848853
"""close does nothing if the connection is already closed."""
849854
await self.connection.close()
@@ -854,35 +859,42 @@ async def test_close_idempotency(self):
854859

855860
async def test_close_during_recv(self):
856861
"""close aborts recv when called concurrently with recv."""
857-
recv_task = asyncio.create_task(self.connection.recv())
858-
await asyncio.sleep(MS)
859-
await self.connection.close()
862+
863+
async def closer():
864+
await asyncio.sleep(MS)
865+
await self.connection.close()
866+
867+
asyncio.create_task(closer())
868+
860869
with self.assertRaises(ConnectionClosedOK) as raised:
861-
await recv_task
870+
await self.connection.recv()
862871

863872
exc = raised.exception
864873
self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)")
865874
self.assertIsNone(exc.__cause__)
866875

867876
async def test_close_during_send(self):
868877
"""close fails the connection when called concurrently with send."""
869-
gate = asyncio.get_running_loop().create_future()
878+
close_gate = asyncio.Event()
879+
exit_gate = asyncio.Event()
880+
881+
async def closer():
882+
await close_gate.wait()
883+
await self.connection.close()
884+
exit_gate.set()
870885

871886
async def fragments():
872887
yield "⏳"
873-
await gate
888+
close_gate.set()
889+
await exit_gate.wait()
874890
yield "⌛️"
875891

876-
send_task = asyncio.create_task(self.connection.send(fragments()))
877-
await asyncio.sleep(MS)
878-
879-
asyncio.create_task(self.connection.close())
880-
await asyncio.sleep(MS)
881-
882-
gate.set_result(None)
892+
asyncio.create_task(closer())
883893

884-
with self.assertRaises(ConnectionClosedError) as raised:
885-
await send_task
894+
iterator = fragments()
895+
async with contextlib.aclosing(iterator):
896+
with self.assertRaises(ConnectionClosedError) as raised:
897+
await self.connection.send(iterator)
886898

887899
exc = raised.exception
888900
self.assertEqual(
@@ -1050,7 +1062,7 @@ async def test_keepalive_times_out(self, getrandbits):
10501062
# 4.x ms: a pong frame is dropped.
10511063
await asyncio.sleep(5 * MS)
10521064
# 6 ms: no pong frame is received; the connection is closed.
1053-
await asyncio.sleep(2 * MS)
1065+
await asyncio.sleep(3 * MS)
10541066
# 7 ms: check that the connection is closed.
10551067
self.assertEqual(self.connection.state, State.CLOSED)
10561068

@@ -1066,14 +1078,15 @@ async def test_keepalive_ignores_timeout(self, getrandbits):
10661078
# 4.x ms: a pong frame is dropped.
10671079
await asyncio.sleep(5 * MS)
10681080
# 6 ms: no pong frame is received; the connection remains open.
1069-
await asyncio.sleep(2 * MS)
1081+
await asyncio.sleep(3 * MS)
10701082
# 7 ms: check that the connection is still open.
10711083
self.assertEqual(self.connection.state, State.OPEN)
10721084

10731085
async def test_keepalive_terminates_while_sleeping(self):
10741086
"""keepalive task terminates while waiting to send a ping."""
10751087
self.connection.ping_interval = 3 * MS
10761088
self.connection.start_keepalive()
1089+
self.assertFalse(self.connection.keepalive_task.done())
10771090
await asyncio.sleep(MS)
10781091
self.assertFalse(self.connection.keepalive_task.done())
10791092
await self.connection.close()
@@ -1231,9 +1244,7 @@ async def test_writing_in_data_received_fails(self):
12311244
# The connection closed exception reports the injected fault.
12321245
with self.assertRaises(ConnectionClosedError) as raised:
12331246
await self.connection.recv()
1234-
cause = raised.exception.__cause__
1235-
self.assertEqual(str(cause), "Cannot call write() after write_eof()")
1236-
self.assertIsInstance(cause, RuntimeError)
1247+
self.assertIsInstance(raised.exception.__cause__, RuntimeError)
12371248

12381249
async def test_writing_in_send_context_fails(self):
12391250
"""Error when sending outgoing frame is correctly reported."""
@@ -1244,9 +1255,7 @@ async def test_writing_in_send_context_fails(self):
12441255
# The connection closed exception reports the injected fault.
12451256
with self.assertRaises(ConnectionClosedError) as raised:
12461257
await self.connection.pong()
1247-
cause = raised.exception.__cause__
1248-
self.assertEqual(str(cause), "Cannot call write() after write_eof()")
1249-
self.assertIsInstance(cause, RuntimeError)
1258+
self.assertIsInstance(raised.exception.__cause__, RuntimeError)
12501259

12511260
# Test safety nets — catching all exceptions in case of bugs.
12521261

@@ -1336,11 +1345,11 @@ async def test_broadcast_skips_closing_connection(self):
13361345

13371346
async def test_broadcast_skips_connection_with_send_blocked(self):
13381347
"""broadcast logs a warning when a connection is blocked in send."""
1339-
gate = asyncio.get_running_loop().create_future()
1348+
gate = asyncio.Event()
13401349

13411350
async def fragments():
13421351
yield "⏳"
1343-
await gate
1352+
await gate.wait()
13441353

13451354
send_task = asyncio.create_task(self.connection.send(fragments()))
13461355
await asyncio.sleep(MS)
@@ -1354,7 +1363,7 @@ async def fragments():
13541363
["skipped broadcast: sending a fragmented message"],
13551364
)
13561365

1357-
gate.set_result(None)
1366+
gate.set()
13581367
await send_task
13591368

13601369
@unittest.skipIf(
@@ -1363,11 +1372,11 @@ async def fragments():
13631372
)
13641373
async def test_broadcast_reports_connection_with_send_blocked(self):
13651374
"""broadcast raises exceptions for connections blocked in send."""
1366-
gate = asyncio.get_running_loop().create_future()
1375+
gate = asyncio.Event()
13671376

13681377
async def fragments():
13691378
yield "⏳"
1370-
await gate
1379+
await gate.wait()
13711380

13721381
send_task = asyncio.create_task(self.connection.send(fragments()))
13731382
await asyncio.sleep(MS)
@@ -1381,7 +1390,7 @@ async def fragments():
13811390
self.assertEqual(str(exc), "sending a fragmented message")
13821391
self.assertIsInstance(exc, ConcurrencyError)
13831392

1384-
gate.set_result(None)
1393+
gate.set()
13851394
await send_task
13861395

13871396
async def test_broadcast_skips_connection_failing_to_send(self):

0 commit comments

Comments
 (0)