Skip to content

Commit 4cbb504

Browse files
committed
Implement auth and tests for websockets
`WebsocketMixing` may be used with or without `JupyterHandler`; - if used with it, we want to have custom auth implementation because redirecting to login page does not make sense for a websocket's GET request - if these are used without `JupyterHandler `we want the auth rules to still apply, even though the `current_user` logic may differ slightly
1 parent 646739e commit 4cbb504

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

jupyter_server/base/websocket.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional, no_type_check
44
from urllib.parse import urlparse
55

6-
from tornado import ioloop
6+
from tornado import ioloop, web
77
from tornado.iostream import IOStream
88

99
# ping interval for keeping websockets alive (30 seconds)
@@ -82,6 +82,26 @@ def check_origin(self, origin: Optional[str] = None) -> bool:
8282
def clear_cookie(self, *args, **kwargs):
8383
"""meaningless for websockets"""
8484

85+
@no_type_check
86+
def _maybe_auth(self):
87+
"""Verify authentication if required"""
88+
if not self.settings.get("allow_unauthenticated_access", False):
89+
if not self.request.method:
90+
raise web.HTTPError(403)
91+
method = getattr(self, self.request.method.lower())
92+
if not getattr(method, "__allow_unauthenticated", False):
93+
# rather than re-using `web.authenticated` which also redirects
94+
# to login page on GET, just raise 403 if user is not known
95+
user = self.current_user
96+
if user is None:
97+
self.log.warning("Couldn't authenticate WebSocket connection")
98+
raise web.HTTPError(403)
99+
100+
def prepare(self, *args, **kwargs):
101+
"""Handle a get request."""
102+
self._maybe_auth()
103+
return super().prepare(*args, **kwargs)
104+
85105
@no_type_check
86106
def open(self, *args, **kwargs):
87107
"""Open the websocket."""

tests/base/test_websocket.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
from unittest.mock import MagicMock
55

66
import pytest
7+
from tornado.httpclient import HTTPClientError
78
from tornado.httpserver import HTTPRequest
89
from tornado.httputil import HTTPHeaders
910
from tornado.websocket import WebSocketClosedError, WebSocketHandler
1011

12+
from jupyter_server.auth.decorator import allow_unauthenticated
1113
from jupyter_server.base.websocket import WebSocketMixin
1214
from jupyter_server.serverapp import ServerApp
15+
from jupyter_server.utils import url_path_join
1316

1417

1518
class MockHandler(WebSocketMixin, WebSocketHandler):
@@ -60,3 +63,58 @@ async def test_ping_client_timeout(mixin):
6063
mixin.send_ping()
6164
with pytest.raises(WebSocketClosedError):
6265
mixin.write_message("hello")
66+
67+
68+
class NoAuthRulesWebsocketHandler(MockHandler):
69+
pass
70+
71+
72+
class PermissiveWebsocketHandler(MockHandler):
73+
@allow_unauthenticated
74+
def get(self, *args, **kwargs) -> None:
75+
return super().get(*args, **kwargs)
76+
77+
78+
@pytest.mark.parametrize(
79+
"jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": True}}]
80+
)
81+
async def test_websocket_auth_permissive(jp_serverapp, jp_ws_fetch):
82+
app: ServerApp = jp_serverapp
83+
app.web_app.add_handlers(
84+
".*$",
85+
[
86+
(url_path_join(app.base_url, "no-rules"), NoAuthRulesWebsocketHandler),
87+
(url_path_join(app.base_url, "permissive"), PermissiveWebsocketHandler),
88+
],
89+
)
90+
91+
# should always permit access when `@allow_unauthenticated` is used
92+
ws = await jp_ws_fetch("permissive", headers={"Authorization": ""})
93+
ws.close()
94+
95+
# should allow access when no authentication rules are set up
96+
ws = await jp_ws_fetch("no-rules", headers={"Authorization": ""})
97+
ws.close()
98+
99+
100+
@pytest.mark.parametrize(
101+
"jp_server_config", [{"ServerApp": {"allow_unauthenticated_access": False}}]
102+
)
103+
async def test_websocket_auth_required(jp_serverapp, jp_ws_fetch):
104+
app: ServerApp = jp_serverapp
105+
app.web_app.add_handlers(
106+
".*$",
107+
[
108+
(url_path_join(app.base_url, "no-rules"), NoAuthRulesWebsocketHandler),
109+
(url_path_join(app.base_url, "permissive"), PermissiveWebsocketHandler),
110+
],
111+
)
112+
113+
# should always permit access when `@allow_unauthenticated` is used
114+
ws = await jp_ws_fetch("permissive", headers={"Authorization": ""})
115+
ws.close()
116+
117+
# should forbid access when no authentication rules are set up
118+
with pytest.raises(HTTPClientError) as exception:
119+
ws = await jp_ws_fetch("no-rules", headers={"Authorization": ""})
120+
assert exception.value.code == 403

0 commit comments

Comments
 (0)