|
20 | 20 | from ..protocol import RecordingProtocol |
21 | 21 | from ..utils import MS |
22 | 22 | from .connection import InterceptingConnection |
| 23 | +from .utils import ThreadTestCase |
23 | 24 |
|
24 | 25 |
|
25 | 26 | # Connection implements symmetrical behavior between clients and servers. |
26 | 27 | # All tests run on the client side and the server side to validate this. |
27 | 28 |
|
28 | 29 |
|
29 | | -class ClientConnectionTests(unittest.TestCase): |
| 30 | +class ClientConnectionTests(ThreadTestCase): |
30 | 31 | LOCAL = CLIENT |
31 | 32 | REMOTE = SERVER |
32 | 33 |
|
@@ -196,38 +197,28 @@ def test_recv_non_utf8_text(self): |
196 | 197 |
|
197 | 198 | def test_recv_during_recv(self): |
198 | 199 | """recv raises ConcurrencyError when called concurrently.""" |
199 | | - recv_thread = threading.Thread(target=self.connection.recv) |
200 | | - recv_thread.start() |
201 | | - |
202 | | - with self.assertRaises(ConcurrencyError) as raised: |
203 | | - self.connection.recv() |
| 200 | + with self.run_in_thread(self.connection.recv): |
| 201 | + with self.assertRaises(ConcurrencyError) as raised: |
| 202 | + self.connection.recv() |
| 203 | + self.remote_connection.send("") |
204 | 204 | self.assertEqual( |
205 | 205 | str(raised.exception), |
206 | 206 | "cannot call recv while another thread " |
207 | 207 | "is already running recv or recv_streaming", |
208 | 208 | ) |
209 | 209 |
|
210 | | - self.remote_connection.send("") |
211 | | - recv_thread.join() |
212 | | - |
213 | 210 | def test_recv_during_recv_streaming(self): |
214 | 211 | """recv raises ConcurrencyError when called concurrently with recv_streaming.""" |
215 | | - recv_streaming_thread = threading.Thread( |
216 | | - target=lambda: list(self.connection.recv_streaming()) |
217 | | - ) |
218 | | - recv_streaming_thread.start() |
219 | | - |
220 | | - with self.assertRaises(ConcurrencyError) as raised: |
221 | | - self.connection.recv() |
| 212 | + with self.run_in_thread(lambda: list(self.connection.recv_streaming())): |
| 213 | + with self.assertRaises(ConcurrencyError) as raised: |
| 214 | + self.connection.recv() |
| 215 | + self.remote_connection.send("") |
222 | 216 | self.assertEqual( |
223 | 217 | str(raised.exception), |
224 | 218 | "cannot call recv while another thread " |
225 | 219 | "is already running recv or recv_streaming", |
226 | 220 | ) |
227 | 221 |
|
228 | | - self.remote_connection.send("") |
229 | | - recv_streaming_thread.join() |
230 | | - |
231 | 222 | # Test recv_streaming. |
232 | 223 |
|
233 | 224 | def test_recv_streaming_text(self): |
@@ -305,40 +296,30 @@ def test_recv_streaming_non_utf8_text(self): |
305 | 296 |
|
306 | 297 | def test_recv_streaming_during_recv(self): |
307 | 298 | """recv_streaming raises ConcurrencyError when called concurrently with recv.""" |
308 | | - recv_thread = threading.Thread(target=self.connection.recv) |
309 | | - recv_thread.start() |
310 | | - |
311 | | - with self.assertRaises(ConcurrencyError) as raised: |
312 | | - for _ in self.connection.recv_streaming(): |
313 | | - self.fail("did not raise") |
| 299 | + with self.run_in_thread(self.connection.recv): |
| 300 | + with self.assertRaises(ConcurrencyError) as raised: |
| 301 | + for _ in self.connection.recv_streaming(): |
| 302 | + self.fail("did not raise") |
| 303 | + self.remote_connection.send("") |
314 | 304 | self.assertEqual( |
315 | 305 | str(raised.exception), |
316 | 306 | "cannot call recv_streaming while another thread " |
317 | 307 | "is already running recv or recv_streaming", |
318 | 308 | ) |
319 | 309 |
|
320 | | - self.remote_connection.send("") |
321 | | - recv_thread.join() |
322 | | - |
323 | 310 | def test_recv_streaming_during_recv_streaming(self): |
324 | 311 | """recv_streaming raises ConcurrencyError when called concurrently.""" |
325 | | - recv_streaming_thread = threading.Thread( |
326 | | - target=lambda: list(self.connection.recv_streaming()) |
327 | | - ) |
328 | | - recv_streaming_thread.start() |
329 | | - |
330 | | - with self.assertRaises(ConcurrencyError) as raised: |
331 | | - for _ in self.connection.recv_streaming(): |
332 | | - self.fail("did not raise") |
| 312 | + with self.run_in_thread(lambda: list(self.connection.recv_streaming())): |
| 313 | + with self.assertRaises(ConcurrencyError) as raised: |
| 314 | + for _ in self.connection.recv_streaming(): |
| 315 | + self.fail("did not raise") |
| 316 | + self.remote_connection.send("") |
333 | 317 | self.assertEqual( |
334 | 318 | str(raised.exception), |
335 | 319 | r"cannot call recv_streaming while another thread " |
336 | 320 | r"is already running recv or recv_streaming", |
337 | 321 | ) |
338 | 322 |
|
339 | | - self.remote_connection.send("") |
340 | | - recv_streaming_thread.join() |
341 | | - |
342 | 323 | # Test send. |
343 | 324 |
|
344 | 325 | def test_send_text(self): |
@@ -411,43 +392,40 @@ def test_send_connection_closed_error(self): |
411 | 392 |
|
412 | 393 | def test_send_during_send(self): |
413 | 394 | """send raises ConcurrencyError when called concurrently.""" |
414 | | - recv_thread = threading.Thread(target=self.remote_connection.recv) |
415 | | - recv_thread.start() |
416 | | - |
417 | | - send_gate = threading.Event() |
418 | | - exit_gate = threading.Event() |
419 | | - |
420 | | - def fragments(): |
421 | | - yield "😀" |
422 | | - send_gate.set() |
423 | | - exit_gate.wait() |
424 | | - yield "😀" |
| 395 | + with self.run_in_thread(self.remote_connection.recv): |
| 396 | + send_gate = threading.Event() |
| 397 | + exit_gate = threading.Event() |
| 398 | + |
| 399 | + def fragments(): |
| 400 | + yield "😀" |
| 401 | + send_gate.set() |
| 402 | + exit_gate.wait() |
| 403 | + yield "😀" |
| 404 | + |
| 405 | + send_thread = threading.Thread( |
| 406 | + target=self.connection.send, |
| 407 | + args=(fragments(),), |
| 408 | + ) |
| 409 | + send_thread.start() |
| 410 | + |
| 411 | + send_gate.wait() |
| 412 | + # The check happens in four code paths, depending on the argument. |
| 413 | + for message in [ |
| 414 | + "😀", |
| 415 | + b"\x01\x02\xfe\xff", |
| 416 | + ["😀", "😀"], |
| 417 | + [b"\x01\x02", b"\xfe\xff"], |
| 418 | + ]: |
| 419 | + with self.subTest(message=message): |
| 420 | + with self.assertRaises(ConcurrencyError) as raised: |
| 421 | + self.connection.send(message) |
| 422 | + self.assertEqual( |
| 423 | + str(raised.exception), |
| 424 | + "cannot call send while another thread is already running send", |
| 425 | + ) |
425 | 426 |
|
426 | | - send_thread = threading.Thread( |
427 | | - target=self.connection.send, |
428 | | - args=(fragments(),), |
429 | | - ) |
430 | | - send_thread.start() |
431 | | - |
432 | | - send_gate.wait() |
433 | | - # The check happens in four code paths, depending on the argument. |
434 | | - for message in [ |
435 | | - "😀", |
436 | | - b"\x01\x02\xfe\xff", |
437 | | - ["😀", "😀"], |
438 | | - [b"\x01\x02", b"\xfe\xff"], |
439 | | - ]: |
440 | | - with self.subTest(message=message): |
441 | | - with self.assertRaises(ConcurrencyError) as raised: |
442 | | - self.connection.send(message) |
443 | | - self.assertEqual( |
444 | | - str(raised.exception), |
445 | | - "cannot call send while another thread is already running send", |
446 | | - ) |
447 | | - |
448 | | - exit_gate.set() |
449 | | - send_thread.join() |
450 | | - recv_thread.join() |
| 427 | + exit_gate.set() |
| 428 | + send_thread.join() |
451 | 429 |
|
452 | 430 | def test_send_empty_iterable(self): |
453 | 431 | """send does nothing when called with an empty iterable.""" |
@@ -571,45 +549,31 @@ def closer(): |
571 | 549 | with self.delay_frames_rcvd(4 * MS): |
572 | 550 | self.connection.close() |
573 | 551 |
|
574 | | - close_thread = threading.Thread(target=closer) |
575 | | - close_thread.start() |
576 | | - |
577 | | - # Let closer() initiate the closing handshake and send a close frame. |
578 | | - time.sleep(MS) |
579 | | - self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) |
| 552 | + with self.run_in_thread(closer): |
| 553 | + # run_in_thread() waits for MS, which lets closer() send a close frame. |
| 554 | + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) |
580 | 555 |
|
581 | | - # Connection isn't closed yet. |
582 | | - with self.assertRaises(TimeoutError): |
583 | | - self.connection.recv(timeout=MS) |
| 556 | + # Connection isn't closed yet. |
| 557 | + with self.assertRaises(TimeoutError): |
| 558 | + self.connection.recv(timeout=MS) |
584 | 559 |
|
585 | | - self.connection.close() |
586 | | - self.assertNoFrameSent() |
587 | | - |
588 | | - # Connection is closed now. |
589 | | - with self.assertRaises(ConnectionClosedOK): |
590 | | - self.connection.recv(timeout=MS) |
| 560 | + self.connection.close() |
| 561 | + self.assertNoFrameSent() |
591 | 562 |
|
592 | | - close_thread.join() |
| 563 | + # Connection is closed now. |
| 564 | + with self.assertRaises(ConnectionClosedOK): |
| 565 | + self.connection.recv(timeout=MS) |
593 | 566 |
|
594 | 567 | def test_close_during_recv(self): |
595 | 568 | """close aborts recv when called concurrently with recv.""" |
596 | | - |
597 | | - def closer(): |
598 | | - time.sleep(MS) |
599 | | - self.connection.close() |
600 | | - |
601 | | - close_thread = threading.Thread(target=closer) |
602 | | - close_thread.start() |
603 | | - |
604 | | - with self.assertRaises(ConnectionClosedOK) as raised: |
605 | | - self.connection.recv() |
| 569 | + with self.run_in_thread(self.connection.close): |
| 570 | + with self.assertRaises(ConnectionClosedOK) as raised: |
| 571 | + self.connection.recv() |
606 | 572 |
|
607 | 573 | exc = raised.exception |
608 | 574 | self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") |
609 | 575 | self.assertIsNone(exc.__cause__) |
610 | 576 |
|
611 | | - close_thread.join() |
612 | | - |
613 | 577 | def test_close_during_send(self): |
614 | 578 | """close fails the connection when called concurrently with send.""" |
615 | 579 | close_gate = threading.Event() |
|
0 commit comments