Skip to content

Commit 4265f4e

Browse files
committed
Ensure that identity provider is used for auth,
even in websockets (but only if those inherit from JupyterHandler, and if they do not fallback to previous implementation and warn).
1 parent 5e7615d commit 4265f4e

File tree

3 files changed

+78
-14
lines changed

3 files changed

+78
-14
lines changed

jupyter_server/base/handlers.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def check_host(self) -> bool:
589589
)
590590
return allow
591591

592-
async def prepare(self) -> Awaitable[None] | None: # type:ignore[override]
592+
async def prepare(self, *, _redirect_to_login=True) -> Awaitable[None] | None: # type:ignore[override]
593593
"""Prepare a response."""
594594
# Set the current Jupyter Handler context variable.
595595
CallContext.set(CallContext.JUPYTER_HANDLER, self)
@@ -636,9 +636,18 @@ async def prepare(self) -> Awaitable[None] | None: # type:ignore[override]
636636
raise HTTPError(403)
637637
method = getattr(self, self.request.method.lower())
638638
if not getattr(method, "__allow_unauthenticated", False):
639-
# reuse `web.authenticated` logic, which redirects to the login
640-
# page on GET and HEAD and otherwise raises 403
641-
return web.authenticated(lambda _: super().prepare)(self)
639+
if _redirect_to_login:
640+
# reuse `web.authenticated` logic, which redirects to the login
641+
# page on GET and HEAD and otherwise raises 403
642+
return web.authenticated(lambda _: super().prepare())(self)
643+
else:
644+
# raise 403 if user is not known without redirecting to login page
645+
user = self.current_user
646+
if user is None:
647+
self.log.warning(
648+
f"Couldn't authenticate {self.__class__.__name__} connection"
649+
)
650+
raise web.HTTPError(403)
642651

643652
return super().prepare()
644653

@@ -736,7 +745,7 @@ def write_error(self, status_code: int, **kwargs: Any) -> None:
736745
class APIHandler(JupyterHandler):
737746
"""Base class for API handlers"""
738747

739-
async def prepare(self) -> None:
748+
async def prepare(self) -> None: # type:ignore[override]
740749
"""Prepare an API response."""
741750
await super().prepare()
742751
if not self.check_origin():
@@ -848,7 +857,7 @@ def options(self, *args: Any, **kwargs: Any) -> None:
848857
class Template404(JupyterHandler):
849858
"""Render our 404 template"""
850859

851-
async def prepare(self) -> None:
860+
async def prepare(self) -> None: # type:ignore[override]
852861
"""Prepare a 404 response."""
853862
await super().prepare()
854863
raise web.HTTPError(404)

jupyter_server/base/websocket.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
"""Base websocket classes."""
22
import re
3+
import warnings
34
from typing import Optional, no_type_check
45
from urllib.parse import urlparse
56

67
from tornado import ioloop, web
78
from tornado.iostream import IOStream
89

10+
from jupyter_server.base.handlers import JupyterHandler
11+
from jupyter_server.utils import JupyterServerAuthWarning
12+
913
# ping interval for keeping websockets alive (30 seconds)
1014
WS_PING_INTERVAL = 30000
1115

@@ -84,7 +88,10 @@ def clear_cookie(self, *args, **kwargs):
8488

8589
@no_type_check
8690
def _maybe_auth(self):
87-
"""Verify authentication if required"""
91+
"""Verify authentication if required.
92+
93+
Only used when the websocket class does not inherit from JupyterHandler.
94+
"""
8895
if not self.settings.get("allow_unauthenticated_access", False):
8996
if not self.request.method:
9097
raise web.HTTPError(403)
@@ -100,8 +107,18 @@ def _maybe_auth(self):
100107
@no_type_check
101108
def prepare(self, *args, **kwargs):
102109
"""Handle a get request."""
103-
self._maybe_auth()
104-
return super().prepare(*args, **kwargs)
110+
if not isinstance(self, JupyterHandler):
111+
should_authenticate = not self.settings.get("allow_unauthenticated_access", False)
112+
if "identity_provider" in self.settings and should_authenticate:
113+
warnings.warn(
114+
"WebSocketMixin sub-class does not inherit from JupyterHandler"
115+
" preventing proper authentication using custom identity provider.",
116+
JupyterServerAuthWarning,
117+
stacklevel=2,
118+
)
119+
self._maybe_auth()
120+
return super().prepare(*args, **kwargs)
121+
return super().prepare(*args, **kwargs, _redirect_to_login=False)
105122

106123
@no_type_check
107124
def open(self, *args, **kwargs):

tests/base/test_websocket.py

+43-5
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
from jupyter_server.auth import IdentityProvider, User
1313
from jupyter_server.auth.decorator import allow_unauthenticated
14+
from jupyter_server.base.handlers import JupyterHandler
1415
from jupyter_server.base.websocket import WebSocketMixin
1516
from jupyter_server.serverapp import ServerApp
16-
from jupyter_server.utils import url_path_join
17+
from jupyter_server.utils import JupyterServerAuthWarning, url_path_join
1718

1819

1920
class MockHandler(WebSocketMixin, WebSocketHandler):
@@ -66,11 +67,15 @@ async def test_ping_client_timeout(mixin):
6667
mixin.write_message("hello")
6768

6869

69-
class NoAuthRulesWebsocketHandler(MockHandler):
70+
class MockJupyterHandler(MockHandler, JupyterHandler):
7071
pass
7172

7273

73-
class PermissiveWebsocketHandler(MockHandler):
74+
class NoAuthRulesWebsocketHandler(MockJupyterHandler):
75+
pass
76+
77+
78+
class PermissiveWebsocketHandler(MockJupyterHandler):
7479
@allow_unauthenticated
7580
def get(self, *args, **kwargs) -> None:
7681
return super().get(*args, **kwargs)
@@ -148,5 +153,38 @@ def fetch():
148153
iidp = IndiscriminateIdentityProvider()
149154
# should allow access with the user set be the identity provider
150155
with patch.dict(jp_serverapp.web_app.settings, {"identity_provider": iidp}):
151-
res = await fetch()
152-
assert res.code == 200
156+
ws = await fetch()
157+
ws.close()
158+
159+
160+
class PermissivePlainWebsocketHandler(MockHandler):
161+
# note: inherits from MockHandler not MockJupyterHandler
162+
@allow_unauthenticated
163+
def get(self, *args, **kwargs) -> None:
164+
return super().get(*args, **kwargs)
165+
166+
167+
@pytest.mark.parametrize(
168+
"jp_server_config",
169+
[
170+
{
171+
"ServerApp": {
172+
"allow_unauthenticated_access": False,
173+
"identity_provider": IndiscriminateIdentityProvider(),
174+
}
175+
}
176+
],
177+
)
178+
async def test_websocket_auth_warns_mixin_lacks_jupyter_handler(jp_serverapp, jp_ws_fetch):
179+
app: ServerApp = jp_serverapp
180+
app.web_app.add_handlers(
181+
".*$",
182+
[(url_path_join(app.base_url, "permissive"), PermissivePlainWebsocketHandler)],
183+
)
184+
185+
with pytest.warns(
186+
JupyterServerAuthWarning,
187+
match="WebSocketMixin sub-class does not inherit from JupyterHandler",
188+
):
189+
ws = await jp_ws_fetch("permissive", headers={"Authorization": ""})
190+
ws.close()

0 commit comments

Comments
 (0)