@@ -200,6 +200,7 @@ async def mock_gateway_request(url, **kwargs):
200
200
201
201
202
202
mocked_gateway = patch ("jupyter_server.gateway.managers.gateway_request" , mock_gateway_request )
203
+ mock_gateway_ws_url = "ws://mock-gateway-server:8889"
203
204
mock_gateway_url = "http://mock-gateway-server:8889"
204
205
mock_http_user = "alice"
205
206
@@ -733,6 +734,41 @@ async def test_websocket_connection_closed(init_gateway, jp_serverapp, jp_fetch,
733
734
pytest .fail (f"Logs contain an error: { message } " )
734
735
735
736
737
+ @patch ("tornado.websocket.websocket_connect" , mock_websocket_connect ())
738
+ async def test_websocket_connection_with_session_id (init_gateway , jp_serverapp , jp_fetch , caplog ):
739
+ # Create the session and kernel and get the kernel manager...
740
+ kernel_id = await create_kernel (jp_fetch , "kspec_foo" )
741
+ km : GatewayKernelManager = jp_serverapp .kernel_manager .get_kernel (kernel_id )
742
+
743
+ # Create the KernelWebsocketHandler...
744
+ request = HTTPServerRequest ("foo" , "GET" )
745
+ request .connection = MagicMock ()
746
+ handler = KernelWebsocketHandler (jp_serverapp .web_app , request )
747
+ # Create the GatewayWebSocketConnection and attach it to the handler...
748
+ with mocked_gateway :
749
+ conn = GatewayWebSocketConnection (parent = km , websocket_handler = handler )
750
+ handler .connection = conn
751
+ await conn .connect ()
752
+ assert conn .session_id != None
753
+ expected_ws_url = (
754
+ f"{ mock_gateway_ws_url } /api/kernels/{ kernel_id } /channels?session_id={ conn .session_id } "
755
+ )
756
+ assert (
757
+ expected_ws_url in caplog .text
758
+ ), "WebSocket URL does not contain the expected session_id."
759
+
760
+ # Processing websocket messages happens in separate coroutines and any
761
+ # errors in that process will show up in logs, but not bubble up to the
762
+ # caller.
763
+ #
764
+ # To check for these, we wait for the server to stop and then check the
765
+ # logs for errors.
766
+ await jp_serverapp ._cleanup ()
767
+ for _ , level , message in caplog .record_tuples :
768
+ if level >= logging .ERROR :
769
+ pytest .fail (f"Logs contain an error: { message } " )
770
+
771
+
736
772
#
737
773
# Test methods below...
738
774
#
0 commit comments