Skip to content

Commit b961d4e

Browse files
Pass session_id during Websocket connect (#1440)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0940671 commit b961d4e

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

Diff for: jupyter_server/gateway/connections.py

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ async def connect(self):
4747
url_escape(self.kernel_id),
4848
"channels",
4949
)
50+
if self.session_id:
51+
ws_url += f"?session_id={url_escape(self.session_id)}"
5052
self.log.info(f"Connecting to {ws_url}")
5153
kwargs: dict[str, Any] = {}
5254
kwargs = GatewayClient.instance().load_connection_args(**kwargs)

Diff for: tests/test_gateway.py

+36
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ async def mock_gateway_request(url, **kwargs):
200200

201201

202202
mocked_gateway = patch("jupyter_server.gateway.managers.gateway_request", mock_gateway_request)
203+
mock_gateway_ws_url = "ws://mock-gateway-server:8889"
203204
mock_gateway_url = "http://mock-gateway-server:8889"
204205
mock_http_user = "alice"
205206

@@ -733,6 +734,41 @@ async def test_websocket_connection_closed(init_gateway, jp_serverapp, jp_fetch,
733734
pytest.fail(f"Logs contain an error: {message}")
734735

735736

737+
@patch("tornado.websocket.websocket_connect", mock_websocket_connect())
738+
async def test_websocket_connection_with_session_id(init_gateway, jp_serverapp, jp_fetch, caplog):
739+
# Create the session and kernel and get the kernel manager...
740+
kernel_id = await create_kernel(jp_fetch, "kspec_foo")
741+
km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id)
742+
743+
# Create the KernelWebsocketHandler...
744+
request = HTTPServerRequest("foo", "GET")
745+
request.connection = MagicMock()
746+
handler = KernelWebsocketHandler(jp_serverapp.web_app, request)
747+
# Create the GatewayWebSocketConnection and attach it to the handler...
748+
with mocked_gateway:
749+
conn = GatewayWebSocketConnection(parent=km, websocket_handler=handler)
750+
handler.connection = conn
751+
await conn.connect()
752+
assert conn.session_id != None
753+
expected_ws_url = (
754+
f"{mock_gateway_ws_url}/api/kernels/{kernel_id}/channels?session_id={conn.session_id}"
755+
)
756+
assert (
757+
expected_ws_url in caplog.text
758+
), "WebSocket URL does not contain the expected session_id."
759+
760+
# Processing websocket messages happens in separate coroutines and any
761+
# errors in that process will show up in logs, but not bubble up to the
762+
# caller.
763+
#
764+
# To check for these, we wait for the server to stop and then check the
765+
# logs for errors.
766+
await jp_serverapp._cleanup()
767+
for _, level, message in caplog.record_tuples:
768+
if level >= logging.ERROR:
769+
pytest.fail(f"Logs contain an error: {message}")
770+
771+
736772
#
737773
# Test methods below...
738774
#

0 commit comments

Comments
 (0)