Skip to content

Commit e044fa8

Browse files
authored
[resotocore][feat] Allow authorization message as first ws message (#1308)
* [resotocore][feat] Allow authorization message as first ws message * define handler as variable and avoid lookup * oops * add comment * define separate groups to make the intent more clear
1 parent 589b0d2 commit e044fa8

File tree

5 files changed

+108
-48
lines changed

5 files changed

+108
-48
lines changed

resotocore/resotocore/static/api-doc.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1441,12 +1441,19 @@ paths:
14411441
/events:
14421442
get:
14431443
summary: "[WebSocket] Register as event listener and receive all events."
1444-
description:
1444+
description: |
14451445
## WebSocket Endpoint
14461446
The client needs to send all the required headers for a ws connection
14471447
and has to handle the websocket protocol.<br/>
14481448
**Note this can not be tested from within swagger!**
14491449
1450+
## Authorization
1451+
In case Resoto has a PSK infrastructure in place, the client needs to send a JWT token via the `Authorization` header
1452+
or via the `resoto_authorization` cookie.
1453+
It is also possible to omit header or cookie and send an Authorization message as first message on the websocket.
1454+
Example
1455+
{ "kind": "authorization", "jwt": "Bearer <jwt>" }
1456+
14501457
parameters:
14511458
- name: show
14521459
in: query

resotocore/resotocore/util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
Periodic = periodic.Periodic
4646

4747

48+
# noinspection PyUnusedLocal
49+
async def async_noop(*args: Any, **kwargs: Any) -> None:
50+
pass
51+
52+
4853
def identity(o: AnyT) -> AnyT:
4954
return o
5055

resotocore/resotocore/web/api.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
from resotocore.task.subscribers import SubscriptionHandler
7474
from resotocore.task.task_handler import TaskHandlerService
7575
from resotocore.types import Json, JsonElement
76-
from resotocore.util import uuid_str, force_gen, rnd_str, if_set, duration, utc_str, parse_utc
76+
from resotocore.util import uuid_str, force_gen, rnd_str, if_set, duration, utc_str, parse_utc, async_noop
7777
from resotocore.web.certificate_handler import CertificateHandler
7878
from resotocore.web.content_renderer import result_binary_gen, single_result
7979
from resotocore.web.directives import (
@@ -92,7 +92,7 @@
9292
WorkerTaskResult,
9393
WorkerTaskInProgress,
9494
)
95-
from resotolib.asynchronous.web.auth import auth_handler
95+
from resotolib.asynchronous.web.auth import auth_handler, set_valid_jwt, raw_jwt_from_auth_message
9696
from resotolib.asynchronous.web.ws_handler import accept_websocket, clean_ws_handler
9797
from resotolib.jwt import encode_jwt
9898

@@ -106,7 +106,10 @@ def section_of(request: Request) -> Optional[str]:
106106
return section
107107

108108

109+
# No Authorization required for following paths
109110
AlwaysAllowed = {"/", "/metrics", "/api-doc.*", "/system/.*", "/ui.*", "/ca/cert", "/notebook.*"}
111+
# Authorization is not required, but implemented as part of the request handler
112+
DeferredCheck = {"/events"}
110113

111114

112115
class Api:
@@ -141,7 +144,7 @@ def __init__(
141144
# note on order: the middleware is passed in the order provided.
142145
middlewares=[
143146
metrics_handler,
144-
auth_handler(config.args.psk, AlwaysAllowed),
147+
auth_handler(config.args.psk, AlwaysAllowed | DeferredCheck),
145148
cors_handler,
146149
error_handler(config, event_sender),
147150
default_middleware(self),
@@ -222,7 +225,6 @@ def __add_routes(self, prefix: str) -> None:
222225
web.post(prefix + "/analytics", self.send_analytics_events),
223226
# Worker operations
224227
web.get(prefix + "/work/queue", self.handle_work_tasks),
225-
web.get(prefix + "/work/create", self.create_work),
226228
web.get(prefix + "/work/list", self.list_work),
227229
# Serve static filed
228230
web.get(prefix, self.forward("/ui/index.html")),
@@ -456,6 +458,15 @@ async def listen_to_events(
456458
event_types: List[str],
457459
initial_messages: Optional[Sequence[Message]] = None,
458460
) -> WebSocketResponse:
461+
handler: Callable[[str], Awaitable[None]] = async_noop
462+
463+
async def authorize_request(msg: str) -> None:
464+
nonlocal handler
465+
if (r := raw_jwt_from_auth_message(msg)) and set_valid_jwt(request, r, self.config.args.psk) is not None:
466+
handler = handle_message
467+
else:
468+
raise ValueError("No Authorization header provided and no valid auth message sent")
469+
459470
async def handle_message(msg: str) -> None:
460471
js = json.loads(msg)
461472
if "data" in js:
@@ -475,32 +486,38 @@ async def handle_message(msg: str) -> None:
475486
else:
476487
await self.message_bus.emit(message)
477488

489+
handler = authorize_request if request.get("authorized", False) is False else handle_message
478490
return await accept_websocket(
479491
request,
480-
handle_incoming=handle_message,
492+
handle_incoming=lambda x: handler(x), # pylint: disable=unnecessary-lambda # it is required!
481493
outgoing_context=partial(self.message_bus.subscribe, listener_id, event_types),
482494
websocket_handler=self.websocket_handler,
483495
initial_messages=initial_messages,
484496
)
485497

486498
async def handle_work_tasks(self, request: Request) -> WebSocketResponse:
487499
worker_id = WorkerId(uuid_str())
488-
initialized = False
489500
worker_descriptions: Future[List[WorkerTaskDescription]] = asyncio.get_event_loop().create_future()
501+
handler: Callable[[str], Awaitable[None]] = async_noop
502+
503+
async def authorize_request(msg: str) -> None:
504+
nonlocal handler
505+
if (r := raw_jwt_from_auth_message(msg)) and set_valid_jwt(request, r, self.config.args.psk) is not None:
506+
handler = handle_connect
507+
else:
508+
raise ValueError("No Authorization header provided and no valid auth message sent")
490509

491510
async def handle_connect(msg: str) -> None:
492-
nonlocal initialized
511+
nonlocal handler
493512
cmds = from_js(json.loads(msg), List[WorkerCustomCommand])
494-
print("connected: ", cmds)
495-
496513
description = [WorkerTaskDescription(cmd.name, cmd.filter) for cmd in cmds]
497514
# set the future and allow attaching the worker to the task queue
498515
worker_descriptions.set_result(description)
499516
# register the descriptions as custom command on the CLI
500517
for cmd in cmds:
501518
self.cli.register_worker_custom_command(cmd)
502-
# mark the worker as initialized
503-
initialized = True
519+
# the connect process is done, define the final handler
520+
handler = handle_message
504521

505522
async def handle_message(msg: str) -> None:
506523
tr = from_js(json.loads(msg), WorkerTaskResult)
@@ -523,26 +540,17 @@ async def connect_to_task_queue() -> AsyncIterator[Queue[WorkerTask]]:
523540
async with self.worker_task_queue.attach(worker_id, descriptions) as queue:
524541
yield queue
525542

543+
handler = authorize_request if request.get("authorized", False) is False else handle_connect
526544
# noinspection PyTypeChecker
527545
return await accept_websocket(
528546
request,
529-
handle_incoming=lambda msg: handle_connect(msg) if not initialized else handle_message(msg),
547+
handle_incoming=lambda x: handler(x), # pylint: disable=unnecessary-lambda # it is required!
530548
outgoing_context=connect_to_task_queue,
531549
websocket_handler=self.websocket_handler,
532550
outgoing_fn=task_json,
533551
)
534552

535-
async def create_work(self, request: Request) -> StreamResponse:
536-
attrs = {k: v for k, v in request.query.items() if k != "task"}
537-
future = asyncio.get_event_loop().create_future()
538-
task = WorkerTask(
539-
TaskId(uuid_str()), "test", attrs, {"some": "data", "foo": "bla"}, future, timedelta(seconds=3)
540-
)
541-
await self.worker_task_queue.add_task(task)
542-
await future
543-
return web.HTTPOk()
544-
545-
async def list_work(self, request: Request) -> StreamResponse:
553+
async def list_work(self, _: Request) -> StreamResponse:
546554
def wt_to_js(ip: WorkerTaskInProgress) -> Json:
547555
return {
548556
"task": ip.task.to_json(),
@@ -870,7 +878,7 @@ async def execute(self, request: Request) -> StreamResponse:
870878
temp = tempfile.mkdtemp()
871879
temp_dir = temp
872880
files = {}
873-
# for now we assume that all multi-parts are file uploads
881+
# for now, we assume that all multi-parts are file uploads
874882
async for part in MultipartReader(request.headers, request.content):
875883
name = part.name
876884
if not name:

resotolib/resotolib/asynchronous/web/auth.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import logging
23
import re
34
from contextvars import ContextVar
@@ -25,11 +26,39 @@ async def jwt_from_context() -> JWT:
2526
return __JWT_Context.get()
2627

2728

29+
def raw_jwt_from_auth_message(msg: str) -> Optional[str]:
30+
"""
31+
Expected message: json object with type kind="authorization" and a jwt field
32+
{ "kind": "authorization", "jwt": "Bearer <jwt>" }
33+
"""
34+
try:
35+
js = json.loads(msg)
36+
assert js.get("kind") == "authorization"
37+
return js.get("jwt")
38+
except Exception:
39+
return None
40+
41+
2842
@middleware
2943
async def no_check(request: Request, handler: RequestHandler) -> StreamResponse:
44+
# all requests are authorized automatically
45+
request["authorized"] = True
3046
return await handler(request)
3147

3248

49+
def set_valid_jwt(request: Request, jwt_raw: str, psk: str) -> Optional[JWT]:
50+
try:
51+
# note: the expiration is already checked by this function
52+
jwt = ck_jwt.decode_jwt_from_header_value(jwt_raw, psk)
53+
except PyJWTError:
54+
return None
55+
if jwt:
56+
request["jwt"] = jwt
57+
request["authorized"] = True
58+
__JWT_Context.set(jwt)
59+
return jwt
60+
61+
3362
def check_jwt(psk: str, always_allowed_paths: Set[str]) -> Middleware:
3463
def always_allowed(request: Request) -> bool:
3564
for path in always_allowed_paths:
@@ -40,9 +69,9 @@ def always_allowed(request: Request) -> bool:
4069
@middleware
4170
async def valid_jwt_handler(request: Request, handler: RequestHandler) -> StreamResponse:
4271
auth_header = request.headers.get("Authorization") or request.cookies.get("resoto_authorization")
43-
if always_allowed(request):
44-
return await handler(request)
45-
elif auth_header:
72+
authorized = False
73+
if auth_header:
74+
# make sure origin and host match, so the request is valid
4675
origin: Optional[str] = urlparse(request.headers.get("Origin")).hostname
4776
host: Optional[str] = request.headers.get("Host")
4877
if host is not None and origin is not None:
@@ -51,16 +80,13 @@ async def valid_jwt_handler(request: Request, handler: RequestHandler) -> Stream
5180
if origin.lower() != host.lower():
5281
log.warning(f"Origin {origin} is not allowed in request from {request.remote} to {request.path}")
5382
raise web.HTTPForbidden()
54-
try:
55-
# note: the expiration is already checked by this function
56-
jwt = ck_jwt.decode_jwt_from_header_value(auth_header, psk)
57-
except PyJWTError as ex:
58-
raise web.HTTPUnauthorized() from ex
59-
if jwt:
60-
__JWT_Context.set(jwt)
61-
return await handler(request)
62-
# if we come here, something is wrong: reject
63-
raise web.HTTPUnauthorized()
83+
84+
# try to authorize the request, even if it is one of the always allowed paths
85+
authorized = set_valid_jwt(request, auth_header, psk) is not None
86+
if authorized or always_allowed(request):
87+
return await handler(request)
88+
else:
89+
raise web.HTTPUnauthorized()
6490

6591
return valid_jwt_handler
6692

resotolib/resotolib/asynchronous/web/ws_handler.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ async def accept_websocket(
5353
await ws.prepare(request)
5454
wsid = str(uuid1())
5555

56+
# in case we wait for an initial authorization message, only wait for a limited amount of tine
57+
async def wait_for_authorization() -> None:
58+
counter = 10
59+
while request.get("authorized", False) is not True and counter >= 0:
60+
await asyncio.sleep(1)
61+
counter -= 1
62+
if counter <= 0:
63+
log.info(f"Wait for authorization: message listener {wsid}: Timeout. Hang up.")
64+
await clean_ws_handler(wsid, websocket_handler)
65+
5666
async def receive() -> None:
5767
try:
5868
async for msg in ws:
@@ -73,6 +83,14 @@ async def receive() -> None:
7383

7484
async def send(ctx: Callable[[], AsyncContextManager[Queue[T]]]) -> None:
7585
try:
86+
# wait for the request to become authorized, before we will send any message
87+
while request.get("authorized", False) is not True:
88+
await asyncio.sleep(1)
89+
# send all initial messages
90+
if initial_messages:
91+
for msg in initial_messages:
92+
await ws.send_str(outgoing_fn(msg) + "\n")
93+
# attach to the queue and wait for messages
7694
async with ctx() as events:
7795
while True:
7896
event = await events.get()
@@ -83,17 +101,13 @@ async def send(ctx: Callable[[], AsyncContextManager[Queue[T]]]) -> None:
83101
finally:
84102
await clean_ws_handler(wsid, websocket_handler)
85103

86-
receive_task = asyncio.create_task(receive())
87-
to_wait = (
88-
asyncio.gather(receive_task, asyncio.create_task(send(outgoing_context)))
89-
if outgoing_context is not None
90-
else receive_task
91-
)
92-
93-
if initial_messages:
94-
for msg in initial_messages:
95-
await ws.send_str(outgoing_fn(msg) + "\n")
104+
tasks = [asyncio.create_task(receive())]
105+
if outgoing_context is not None:
106+
tasks.append(asyncio.create_task(send(outgoing_context)))
107+
if request.get("authorized", False) is not True:
108+
tasks.append(asyncio.create_task(wait_for_authorization()))
96109

110+
to_wait = asyncio.gather(*tasks)
97111
websocket_handler[wsid] = (to_wait, ws)
98112
await to_wait
99113
return ws

0 commit comments

Comments
 (0)