Skip to content

Commit d4e3e26

Browse files
palkeoCoolCat467A5rocks
authored andcommitted
Add the ability to specify the buffer size. (#186)
Co-authored-by: CoolCat467 <[email protected]> Co-authored-by: A5rocks <[email protected]>
1 parent 49b93c1 commit d4e3e26

File tree

2 files changed

+56
-8
lines changed

2 files changed

+56
-8
lines changed

tests/test_connection.py

+1
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def test_client_open_url_options( # type: ignore[misc]
401401
'extra_headers': [(b'X-Test-Header', b'My test header')],
402402
'message_queue_size': 9,
403403
'max_message_size': 333,
404+
'receive_buffer_size': 999,
404405
'connect_timeout': 36,
405406
'disconnect_timeout': 37,
406407
}

trio_websocket/_impl.py

+55-8
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ async def open_websocket(
117117
extra_headers: Optional[list[tuple[bytes,bytes]]] = None,
118118
message_queue_size: int = MESSAGE_QUEUE_SIZE,
119119
max_message_size: int = MAX_MESSAGE_SIZE,
120+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
120121
connect_timeout: float = CONN_TIMEOUT,
121122
disconnect_timeout: float = CONN_TIMEOUT
122123
) -> AsyncGenerator[WebSocketConnection, None]:
@@ -144,6 +145,9 @@ async def open_websocket(
144145
:param int max_message_size: The maximum message size as measured by
145146
``len()``. If a message is received that is larger than this size,
146147
then the connection is closed with code 1009 (Message Too Big).
148+
:param Optional[int] receive_buffer_size: The buffer size we use to
149+
receive messages internally. None to let trio choose. Defaults
150+
to 4 KiB.
147151
:param float connect_timeout: The number of seconds to wait for the
148152
connection before timing out.
149153
:param float disconnect_timeout: The number of seconds to wait when closing
@@ -182,7 +186,8 @@ async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection:
182186
resource, use_ssl=use_ssl, subprotocols=subprotocols,
183187
extra_headers=extra_headers,
184188
message_queue_size=message_queue_size,
185-
max_message_size=max_message_size)
189+
max_message_size=max_message_size,
190+
receive_buffer_size=receive_buffer_size)
186191
except trio.TooSlowError:
187192
raise ConnectionTimeout from None
188193
except OSError as e:
@@ -311,6 +316,7 @@ async def connect_websocket(
311316
extra_headers: list[tuple[bytes, bytes]] | None = None,
312317
message_queue_size: int = MESSAGE_QUEUE_SIZE,
313318
max_message_size: int = MAX_MESSAGE_SIZE,
319+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
314320
) -> WebSocketConnection:
315321
'''
316322
Return an open WebSocket client connection to a host.
@@ -339,6 +345,9 @@ async def connect_websocket(
339345
:param int max_message_size: The maximum message size as measured by
340346
``len()``. If a message is received that is larger than this size,
341347
then the connection is closed with code 1009 (Message Too Big).
348+
:param Optional[int] receive_buffer_size: The buffer size we use to
349+
receive messages internally. None to let trio choose. Defaults
350+
to 4 KiB.
342351
:rtype: WebSocketConnection
343352
'''
344353
if use_ssl is True:
@@ -368,7 +377,8 @@ async def connect_websocket(
368377
path=resource,
369378
client_subprotocols=subprotocols, client_extra_headers=extra_headers,
370379
message_queue_size=message_queue_size,
371-
max_message_size=max_message_size)
380+
max_message_size=max_message_size,
381+
receive_buffer_size=receive_buffer_size)
372382
nursery.start_soon(connection._reader_task)
373383
await connection._open_handshake.wait()
374384
return connection
@@ -384,6 +394,7 @@ def open_websocket_url(
384394
max_message_size: int = MAX_MESSAGE_SIZE,
385395
connect_timeout: float = CONN_TIMEOUT,
386396
disconnect_timeout: float = CONN_TIMEOUT,
397+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
387398
) -> AbstractAsyncContextManager[WebSocketConnection]:
388399
'''
389400
Open a WebSocket client connection to a URL.
@@ -407,6 +418,9 @@ def open_websocket_url(
407418
:param int max_message_size: The maximum message size as measured by
408419
``len()``. If a message is received that is larger than this size,
409420
then the connection is closed with code 1009 (Message Too Big).
421+
:param Optional[int] receive_buffer_size: The buffer size we use to
422+
receive messages internally. None to let trio choose. Defaults
423+
to 4 KiB.
410424
:param float connect_timeout: The number of seconds to wait for the
411425
connection before timing out.
412426
:param float disconnect_timeout: The number of seconds to wait when closing
@@ -420,6 +434,7 @@ def open_websocket_url(
420434
subprotocols=subprotocols, extra_headers=extra_headers,
421435
message_queue_size=message_queue_size,
422436
max_message_size=max_message_size,
437+
receive_buffer_size=receive_buffer_size,
423438
connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout)
424439

425440

@@ -432,6 +447,7 @@ async def connect_websocket_url(
432447
extra_headers: list[tuple[bytes, bytes]] | None = None,
433448
message_queue_size: int = MESSAGE_QUEUE_SIZE,
434449
max_message_size: int = MAX_MESSAGE_SIZE,
450+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
435451
) -> WebSocketConnection:
436452
'''
437453
Return an open WebSocket client connection to a URL.
@@ -457,13 +473,17 @@ async def connect_websocket_url(
457473
:param int max_message_size: The maximum message size as measured by
458474
``len()``. If a message is received that is larger than this size,
459475
then the connection is closed with code 1009 (Message Too Big).
476+
:param Optional[int] receive_buffer_size: The buffer size we use to
477+
receive messages internally. None to let trio choose. Defaults
478+
to 4 KiB.
460479
:rtype: WebSocketConnection
461480
'''
462481
host, port, resource, return_ssl_context = _url_to_host(url, ssl_context)
463482
return await connect_websocket(nursery, host, port, resource,
464483
use_ssl=return_ssl_context, subprotocols=subprotocols,
465484
extra_headers=extra_headers, message_queue_size=message_queue_size,
466-
max_message_size=max_message_size)
485+
max_message_size=max_message_size,
486+
receive_buffer_size=receive_buffer_size)
467487

468488

469489
def _url_to_host(
@@ -520,6 +540,7 @@ async def wrap_client_stream(
520540
extra_headers: list[tuple[bytes, bytes]] | None = None,
521541
message_queue_size: int = MESSAGE_QUEUE_SIZE,
522542
max_message_size: int = MAX_MESSAGE_SIZE,
543+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
523544
) -> WebSocketConnection:
524545
'''
525546
Wrap an arbitrary stream in a WebSocket connection.
@@ -544,14 +565,18 @@ async def wrap_client_stream(
544565
:param int max_message_size: The maximum message size as measured by
545566
``len()``. If a message is received that is larger than this size,
546567
then the connection is closed with code 1009 (Message Too Big).
568+
:param Optional[int] receive_buffer_size: The buffer size we use to
569+
receive messages internally. None to let trio choose. Defaults
570+
to 4 KiB.
547571
:rtype: WebSocketConnection
548572
'''
549573
connection = WebSocketConnection(stream,
550574
WSConnection(ConnectionType.CLIENT),
551575
host=host, path=resource,
552576
client_subprotocols=subprotocols, client_extra_headers=extra_headers,
553577
message_queue_size=message_queue_size,
554-
max_message_size=max_message_size)
578+
max_message_size=max_message_size,
579+
receive_buffer_size=receive_buffer_size)
555580
nursery.start_soon(connection._reader_task)
556581
await connection._open_handshake.wait()
557582
return connection
@@ -562,6 +587,7 @@ async def wrap_server_stream(
562587
stream: trio.abc.Stream,
563588
message_queue_size: int = MESSAGE_QUEUE_SIZE,
564589
max_message_size: int = MAX_MESSAGE_SIZE,
590+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
565591
) -> WebSocketRequest:
566592
'''
567593
Wrap an arbitrary stream in a server-side WebSocket.
@@ -576,19 +602,24 @@ async def wrap_server_stream(
576602
:param int max_message_size: The maximum message size as measured by
577603
``len()``. If a message is received that is larger than this size,
578604
then the connection is closed with code 1009 (Message Too Big).
605+
:param Optional[int] receive_buffer_size: The buffer size we use to
606+
receive messages internally. None to let trio choose. Defaults
607+
to 4 KiB.
579608
:type stream: trio.abc.Stream
580609
:rtype: WebSocketRequest
581610
'''
582611
connection = WebSocketConnection(
583612
stream,
584613
WSConnection(ConnectionType.SERVER),
585614
message_queue_size=message_queue_size,
586-
max_message_size=max_message_size)
615+
max_message_size=max_message_size,
616+
receive_buffer_size=receive_buffer_size)
587617
nursery.start_soon(connection._reader_task)
588618
request = await connection._get_request()
589619
return request
590620

591621

622+
592623
async def serve_websocket(
593624
handler: Callable[[WebSocketRequest], Awaitable[None]],
594625
host: str | bytes | None,
@@ -598,6 +629,7 @@ async def serve_websocket(
598629
handler_nursery: trio.Nursery | None = None,
599630
message_queue_size: int = MESSAGE_QUEUE_SIZE,
600631
max_message_size: int = MAX_MESSAGE_SIZE,
632+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
601633
connect_timeout: float = CONN_TIMEOUT,
602634
disconnect_timeout: float = CONN_TIMEOUT,
603635
task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED,
@@ -630,6 +662,9 @@ async def serve_websocket(
630662
:param int max_message_size: The maximum message size as measured by
631663
``len()``. If a message is received that is larger than this size,
632664
then the connection is closed with code 1009 (Message Too Big).
665+
:param Optional[int] receive_buffer_size: The buffer size we use to
666+
receive messages internally. None to let trio choose. Defaults
667+
to 4 KiB.
633668
:param float connect_timeout: The number of seconds to wait for a client
634669
to finish connection handshake before timing out.
635670
:param float disconnect_timeout: The number of seconds to wait for a client
@@ -658,6 +693,7 @@ async def serve_websocket(
658693
handler_nursery=handler_nursery,
659694
message_queue_size=message_queue_size,
660695
max_message_size=max_message_size,
696+
receive_buffer_size=receive_buffer_size,
661697
connect_timeout=connect_timeout,
662698
disconnect_timeout=disconnect_timeout,
663699
)
@@ -957,7 +993,8 @@ def __init__(
957993
client_subprotocols: Iterable[str] | None = None,
958994
client_extra_headers: list[tuple[bytes, bytes]] | None = None,
959995
message_queue_size: int = MESSAGE_QUEUE_SIZE,
960-
max_message_size: int = MAX_MESSAGE_SIZE
996+
max_message_size: int = MAX_MESSAGE_SIZE,
997+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
961998
) -> None:
962999
'''
9631000
Constructor.
@@ -984,6 +1021,9 @@ def __init__(
9841021
:param int max_message_size: The maximum message size as measured by
9851022
``len()``. If a message is received that is larger than this size,
9861023
then the connection is closed with code 1009 (Message Too Big).
1024+
:param Optional[int] receive_buffer_size: The buffer size we use to
1025+
receive messages internally. None to let trio choose. Defaults
1026+
to 4 KiB.
9871027
'''
9881028
# NOTE: The implementation uses _close_reason for more than an advisory
9891029
# purpose. It's critical internal state, indicating when the
@@ -996,6 +1036,7 @@ def __init__(
9961036
self._message_size = 0
9971037
self._message_parts: List[Union[bytes, str]] = []
9981038
self._max_message_size = max_message_size
1039+
self._receive_buffer_size: Optional[int] = receive_buffer_size
9991040
self._reader_running = True
10001041
if ws_connection.client:
10011042
assert host is not None
@@ -1528,7 +1569,7 @@ async def _reader_task(self) -> None:
15281569

15291570
# Get network data.
15301571
try:
1531-
data = await self._stream.receive_some(RECEIVE_BYTES)
1572+
data = await self._stream.receive_some(self._receive_buffer_size)
15321573
except (trio.BrokenResourceError, trio.ClosedResourceError):
15331574
await self._abort_web_socket()
15341575
break
@@ -1619,6 +1660,7 @@ def __init__(
16191660
handler_nursery: trio.Nursery | None = None,
16201661
message_queue_size: int = MESSAGE_QUEUE_SIZE,
16211662
max_message_size: int = MAX_MESSAGE_SIZE,
1663+
receive_buffer_size: Union[None, int] = RECEIVE_BYTES,
16221664
connect_timeout: float = CONN_TIMEOUT,
16231665
disconnect_timeout: float = CONN_TIMEOUT,
16241666
) -> None:
@@ -1637,6 +1679,9 @@ def __init__(
16371679
:param handler_nursery: An optional nursery to spawn connection tasks
16381680
inside of. If ``None``, then a new nursery will be created
16391681
internally.
1682+
:param Optional[int] receive_buffer_size: The buffer size we use to
1683+
receive messages internally. None to let trio choose. Defaults
1684+
to 4 KiB.
16401685
:param float connect_timeout: The number of seconds to wait for a client
16411686
to finish connection handshake before timing out.
16421687
:param float disconnect_timeout: The number of seconds to wait for a client
@@ -1649,6 +1694,7 @@ def __init__(
16491694
self._listeners = listeners
16501695
self._message_queue_size = message_queue_size
16511696
self._max_message_size = max_message_size
1697+
self._receive_buffer_size = receive_buffer_size
16521698
self._connect_timeout = connect_timeout
16531699
self._disconnect_timeout = disconnect_timeout
16541700

@@ -1741,7 +1787,8 @@ async def _handle_connection(self, stream: trio.abc.Stream) -> None:
17411787
connection = WebSocketConnection(stream,
17421788
WSConnection(ConnectionType.SERVER),
17431789
message_queue_size=self._message_queue_size,
1744-
max_message_size=self._max_message_size)
1790+
max_message_size=self._max_message_size,
1791+
receive_buffer_size=self._receive_buffer_size)
17451792
nursery.start_soon(connection._reader_task)
17461793
with trio.move_on_after(self._connect_timeout) as connect_scope:
17471794
request = await connection._get_request()

0 commit comments

Comments
 (0)