From fd31b2abc51f9b28be2481df7055162dda20f826 Mon Sep 17 00:00:00 2001 From: darkhaniop <72267237+darkhaniop@users.noreply.github.com> Date: Sun, 6 Apr 2025 10:55:52 +0900 Subject: [PATCH 1/3] Sync connection tracking and graceful shutdown * Implement sync server connection tracking. * Add ServerConnection.close() call for exising connections on server shutdown. This is useful for cleanly terminating/restarting the server process. Issue #1488 --- src/websockets/sync/server.py | 56 +++++++++++++++++++++++++++-- tests/sync/test_server.py | 67 +++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 2 deletions(-) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index efb40a7f4..e766de493 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -238,12 +238,20 @@ def __init__( socket: socket.socket, handler: Callable[[socket.socket, Any], None], logger: LoggerLike | None = None, + *, + connections: set[ServerConnection] | None = None, ) -> None: self.socket = socket self.handler = handler if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger + + # _connections tracks active connections + if connections is None: + connections = set() + self._connections = connections + if sys.platform != "win32": self.shutdown_watcher, self.shutdown_notifier = os.pipe() @@ -285,15 +293,36 @@ def serve_forever(self) -> None: thread = threading.Thread(target=self.handler, args=(sock, addr)) thread.start() - def shutdown(self) -> None: + def shutdown( + self, *, code: CloseCode = CloseCode.NORMAL_CLOSURE, reason: str = "" + ) -> None: """ See :meth:`socketserver.BaseServer.shutdown`. + Shuts down the server and closes existing connections. Optional arguments + ``code`` and ``reason`` can be used to provide additional information to + the clients, e.g.,:: + + server.shutdown(reason="scheduled_maintenance") + + Args: + code: Closing code, defaults to ``CloseCode.NORMAL_CLOSURE``. + reason: Closing reason, default to empty string. + """ self.socket.close() if sys.platform != "win32": os.write(self.shutdown_notifier, b"x") + # Close all connections + conns = list(self._connections) + for conn in conns: + try: + conn.close(code=code, reason=reason) + except Exception as exc: + debug_msg = f"Could not close {conn.id}: {exc}" + self.logger.debug(debug_msg, exc_info=exc) + def fileno(self) -> int: """ See :meth:`socketserver.BaseServer.fileno`. @@ -516,6 +545,24 @@ def handler(websocket): do_handshake_on_connect=False, ) + # Stores active ServerConnection instances, used by the server to handle graceful + # shutdown in Server.shutdown() + connections: set[ServerConnection] = set() + + def on_connection_created(connection: ServerConnection) -> None: + # Invoked from conn_handler() to add a new ServerConnection instance to + # Server._connections + connections.add(connection) + + def on_connection_closed(connection: ServerConnection) -> None: + # Invoked from conn_handler() to remove a closed ServerConnection instance from + # Server._connections. Keeping only active references in the set is important + # for avoiding memory leaks. + try: + connections.remove(connection) + except KeyError: # pragma: no cover + pass + # Define request handler def conn_handler(sock: socket.socket, addr: Any) -> None: @@ -581,6 +628,7 @@ def protocol_select_subprotocol( close_timeout=close_timeout, max_queue=max_queue, ) + on_connection_created(connection) except Exception: sock.close() return @@ -595,11 +643,13 @@ def protocol_select_subprotocol( ) except TimeoutError: connection.close_socket() + on_connection_closed(connection) connection.recv_events_thread.join() return except Exception: connection.logger.error("opening handshake failed", exc_info=True) connection.close_socket() + on_connection_closed(connection) connection.recv_events_thread.join() return @@ -610,8 +660,10 @@ def protocol_select_subprotocol( except Exception: connection.logger.error("connection handler failed", exc_info=True) connection.close(CloseCode.INTERNAL_ERROR) + on_connection_closed(connection) else: connection.close() + on_connection_closed(connection) except Exception: # pragma: no cover # Don't leak sockets on unexpected errors. @@ -619,7 +671,7 @@ def protocol_select_subprotocol( # Initialize server - return Server(sock, conn_handler, logger) + return Server(sock, conn_handler, logger, connections=connections) def unix_serve( diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index d04d1859a..13543b666 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -3,6 +3,7 @@ import http import logging import socket +import threading import time import unittest @@ -12,6 +13,7 @@ InvalidStatus, NegotiationError, ) +from websockets import CloseCode, State from websockets.http11 import Request, Response from websockets.sync.client import connect, unix_connect from websockets.sync.server import * @@ -338,6 +340,71 @@ def test_junk_handshake(self): ["invalid HTTP request line: HELO relay.invalid"], ) + def test_initialize_server_without_tracking_connections(self): + """Call Server() constructor without 'connections' arg.""" + with socket.create_server(("localhost", 0)) as sock: + server = Server(socket=sock, handler=handler) + self.assertIsInstance( + server._connections, set, "Server._connections property not initialized" + ) + + def test_connections_is_empty_after_disconnects(self): + """Clients are added to Server._connections, and removed when disconnected.""" + with run_server() as server: + connections: set[ServerConnection] = server._connections + with connect(get_uri(server)) as client: + self.assertEqual(len(connections), 1) + time.sleep(0.5) + self.assertEqual(len(connections), 0) + + def test_shutdown_calls_close_for_all_connections(self): + """Graceful shutdown with broken ServerConnection.close() implementations.""" + CLIENTS_TO_LAUNCH = 3 + + connections_attempted = 0 + + class ServerConnectionWithBrokenClose(ServerConnection): + close_method_called = False + + def close(self, code=CloseCode.NORMAL_CLOSURE, reason=""): + """Custom close method that intentionally fails.""" + + # Do not increment the counter when calling .close() multiple times + if self.close_method_called: + return + self.close_method_called = True + + nonlocal connections_attempted + connections_attempted += 1 + raise Exception("broken close method") + + clients: set[threading.Thread] = set() + with run_server(create_connection=ServerConnectionWithBrokenClose) as server: + + def client(): + with connect(get_uri(server)) as client: + time.sleep(1) + + for i in range(CLIENTS_TO_LAUNCH): + client_thread = threading.Thread(target=client) + client_thread.start() + clients.add(client_thread) + time.sleep(0.2) + self.assertEqual( + len(server._connections), + CLIENTS_TO_LAUNCH, + "not all clients connected to the server yet, increase sleep duration", + ) + server.shutdown() + while len(clients) > 0: + client = clients.pop() + client.join() + self.assertEqual( + connections_attempted, + CLIENTS_TO_LAUNCH, + "server did not call ServerConnection.close() on all connections", + ) + class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): From 663b561d94e48e4fae53b5f4ffe3424028df2cc3 Mon Sep 17 00:00:00 2001 From: darkhaniop <72267237+darkhaniop@users.noreply.github.com> Date: Sun, 6 Apr 2025 21:01:55 +0900 Subject: [PATCH 2/3] Reorder imports to address code quality (tox) Running tox revealed that the "Sync connection tracking and graceful shutdown" patch introduced "misordered import statements" warning and unused assignments (client) in "with connect(...) as client:" in the added tests. This commit addresses these code quality messages. --- tests/sync/test_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 13543b666..53b3eed14 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -7,13 +7,13 @@ import time import unittest +from websockets import CloseCode from websockets.exceptions import ( ConnectionClosedError, ConnectionClosedOK, InvalidStatus, NegotiationError, ) -from websockets import CloseCode, State from websockets.http11 import Request, Response from websockets.sync.client import connect, unix_connect from websockets.sync.server import * @@ -352,7 +352,7 @@ def test_connections_is_empty_after_disconnects(self): """Clients are added to Server._connections, and removed when disconnected.""" with run_server() as server: connections: set[ServerConnection] = server._connections - with connect(get_uri(server)) as client: + with connect(get_uri(server)): self.assertEqual(len(connections), 1) time.sleep(0.5) self.assertEqual(len(connections), 0) @@ -382,7 +382,7 @@ def close(self, code=CloseCode.NORMAL_CLOSURE, reason=""): with run_server(create_connection=ServerConnectionWithBrokenClose) as server: def client(): - with connect(get_uri(server)) as client: + with connect(get_uri(server)): time.sleep(1) for i in range(CLIENTS_TO_LAUNCH): From f1f409fb7b46b2d4ba5dcb498de254a3a754de2b Mon Sep 17 00:00:00 2001 From: darkhaniop <72267237+darkhaniop@users.noreply.github.com> Date: Mon, 7 Apr 2025 09:57:03 +0900 Subject: [PATCH 3/3] Revert sync.Server.shutdown() signature Also, update docstrings to explain the added `connections` arg in the `sync.Server()` constructor. PR #1615 --- src/websockets/sync/server.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index e766de493..4c8373930 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -230,6 +230,10 @@ class Server: logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. + connections: A set of open :class:`ServerConnection` instances + maintained by `handler`. When omitted, e.g., if the handler does + not maintain such a set, this defaults to an empty set and the + server will not attempt to close connections on shutdown. """ @@ -293,21 +297,11 @@ def serve_forever(self) -> None: thread = threading.Thread(target=self.handler, args=(sock, addr)) thread.start() - def shutdown( - self, *, code: CloseCode = CloseCode.NORMAL_CLOSURE, reason: str = "" - ) -> None: + def shutdown(self) -> None: """ See :meth:`socketserver.BaseServer.shutdown`. - Shuts down the server and closes existing connections. Optional arguments - ``code`` and ``reason`` can be used to provide additional information to - the clients, e.g.,:: - - server.shutdown(reason="scheduled_maintenance") - - Args: - code: Closing code, defaults to ``CloseCode.NORMAL_CLOSURE``. - reason: Closing reason, default to empty string. + Shuts down the server and closes existing connections. """ self.socket.close() @@ -318,7 +312,7 @@ def shutdown( conns = list(self._connections) for conn in conns: try: - conn.close(code=code, reason=reason) + conn.close() except Exception as exc: debug_msg = f"Could not close {conn.id}: {exc}" self.logger.debug(debug_msg, exc_info=exc)