73
73
from resotocore .task .subscribers import SubscriptionHandler
74
74
from resotocore .task .task_handler import TaskHandlerService
75
75
from resotocore .types import Json , JsonElement
76
- from resotocore .util import uuid_str , force_gen , rnd_str , if_set , duration , utc_str , parse_utc
76
+ from resotocore .util import uuid_str , force_gen , rnd_str , if_set , duration , utc_str , parse_utc , async_noop
77
77
from resotocore .web .certificate_handler import CertificateHandler
78
78
from resotocore .web .content_renderer import result_binary_gen , single_result
79
79
from resotocore .web .directives import (
92
92
WorkerTaskResult ,
93
93
WorkerTaskInProgress ,
94
94
)
95
- from resotolib .asynchronous .web .auth import auth_handler
95
+ from resotolib .asynchronous .web .auth import auth_handler , set_valid_jwt , raw_jwt_from_auth_message
96
96
from resotolib .asynchronous .web .ws_handler import accept_websocket , clean_ws_handler
97
97
from resotolib .jwt import encode_jwt
98
98
@@ -106,7 +106,10 @@ def section_of(request: Request) -> Optional[str]:
106
106
return section
107
107
108
108
109
+ # No Authorization required for following paths
109
110
AlwaysAllowed = {"/" , "/metrics" , "/api-doc.*" , "/system/.*" , "/ui.*" , "/ca/cert" , "/notebook.*" }
111
+ # Authorization is not required, but implemented as part of the request handler
112
+ DeferredCheck = {"/events" }
110
113
111
114
112
115
class Api :
@@ -141,7 +144,7 @@ def __init__(
141
144
# note on order: the middleware is passed in the order provided.
142
145
middlewares = [
143
146
metrics_handler ,
144
- auth_handler (config .args .psk , AlwaysAllowed ),
147
+ auth_handler (config .args .psk , AlwaysAllowed | DeferredCheck ),
145
148
cors_handler ,
146
149
error_handler (config , event_sender ),
147
150
default_middleware (self ),
@@ -222,7 +225,6 @@ def __add_routes(self, prefix: str) -> None:
222
225
web .post (prefix + "/analytics" , self .send_analytics_events ),
223
226
# Worker operations
224
227
web .get (prefix + "/work/queue" , self .handle_work_tasks ),
225
- web .get (prefix + "/work/create" , self .create_work ),
226
228
web .get (prefix + "/work/list" , self .list_work ),
227
229
# Serve static filed
228
230
web .get (prefix , self .forward ("/ui/index.html" )),
@@ -456,6 +458,15 @@ async def listen_to_events(
456
458
event_types : List [str ],
457
459
initial_messages : Optional [Sequence [Message ]] = None ,
458
460
) -> WebSocketResponse :
461
+ handler : Callable [[str ], Awaitable [None ]] = async_noop
462
+
463
+ async def authorize_request (msg : str ) -> None :
464
+ nonlocal handler
465
+ if (r := raw_jwt_from_auth_message (msg )) and set_valid_jwt (request , r , self .config .args .psk ) is not None :
466
+ handler = handle_message
467
+ else :
468
+ raise ValueError ("No Authorization header provided and no valid auth message sent" )
469
+
459
470
async def handle_message (msg : str ) -> None :
460
471
js = json .loads (msg )
461
472
if "data" in js :
@@ -475,32 +486,38 @@ async def handle_message(msg: str) -> None:
475
486
else :
476
487
await self .message_bus .emit (message )
477
488
489
+ handler = authorize_request if request .get ("authorized" , False ) is False else handle_message
478
490
return await accept_websocket (
479
491
request ,
480
- handle_incoming = handle_message ,
492
+ handle_incoming = lambda x : handler ( x ), # pylint: disable=unnecessary-lambda # it is required!
481
493
outgoing_context = partial (self .message_bus .subscribe , listener_id , event_types ),
482
494
websocket_handler = self .websocket_handler ,
483
495
initial_messages = initial_messages ,
484
496
)
485
497
486
498
async def handle_work_tasks (self , request : Request ) -> WebSocketResponse :
487
499
worker_id = WorkerId (uuid_str ())
488
- initialized = False
489
500
worker_descriptions : Future [List [WorkerTaskDescription ]] = asyncio .get_event_loop ().create_future ()
501
+ handler : Callable [[str ], Awaitable [None ]] = async_noop
502
+
503
+ async def authorize_request (msg : str ) -> None :
504
+ nonlocal handler
505
+ if (r := raw_jwt_from_auth_message (msg )) and set_valid_jwt (request , r , self .config .args .psk ) is not None :
506
+ handler = handle_connect
507
+ else :
508
+ raise ValueError ("No Authorization header provided and no valid auth message sent" )
490
509
491
510
async def handle_connect (msg : str ) -> None :
492
- nonlocal initialized
511
+ nonlocal handler
493
512
cmds = from_js (json .loads (msg ), List [WorkerCustomCommand ])
494
- print ("connected: " , cmds )
495
-
496
513
description = [WorkerTaskDescription (cmd .name , cmd .filter ) for cmd in cmds ]
497
514
# set the future and allow attaching the worker to the task queue
498
515
worker_descriptions .set_result (description )
499
516
# register the descriptions as custom command on the CLI
500
517
for cmd in cmds :
501
518
self .cli .register_worker_custom_command (cmd )
502
- # mark the worker as initialized
503
- initialized = True
519
+ # the connect process is done, define the final handler
520
+ handler = handle_message
504
521
505
522
async def handle_message (msg : str ) -> None :
506
523
tr = from_js (json .loads (msg ), WorkerTaskResult )
@@ -523,26 +540,17 @@ async def connect_to_task_queue() -> AsyncIterator[Queue[WorkerTask]]:
523
540
async with self .worker_task_queue .attach (worker_id , descriptions ) as queue :
524
541
yield queue
525
542
543
+ handler = authorize_request if request .get ("authorized" , False ) is False else handle_connect
526
544
# noinspection PyTypeChecker
527
545
return await accept_websocket (
528
546
request ,
529
- handle_incoming = lambda msg : handle_connect ( msg ) if not initialized else handle_message ( msg ),
547
+ handle_incoming = lambda x : handler ( x ), # pylint: disable=unnecessary-lambda # it is required!
530
548
outgoing_context = connect_to_task_queue ,
531
549
websocket_handler = self .websocket_handler ,
532
550
outgoing_fn = task_json ,
533
551
)
534
552
535
- async def create_work (self , request : Request ) -> StreamResponse :
536
- attrs = {k : v for k , v in request .query .items () if k != "task" }
537
- future = asyncio .get_event_loop ().create_future ()
538
- task = WorkerTask (
539
- TaskId (uuid_str ()), "test" , attrs , {"some" : "data" , "foo" : "bla" }, future , timedelta (seconds = 3 )
540
- )
541
- await self .worker_task_queue .add_task (task )
542
- await future
543
- return web .HTTPOk ()
544
-
545
- async def list_work (self , request : Request ) -> StreamResponse :
553
+ async def list_work (self , _ : Request ) -> StreamResponse :
546
554
def wt_to_js (ip : WorkerTaskInProgress ) -> Json :
547
555
return {
548
556
"task" : ip .task .to_json (),
@@ -870,7 +878,7 @@ async def execute(self, request: Request) -> StreamResponse:
870
878
temp = tempfile .mkdtemp ()
871
879
temp_dir = temp
872
880
files = {}
873
- # for now we assume that all multi-parts are file uploads
881
+ # for now, we assume that all multi-parts are file uploads
874
882
async for part in MultipartReader (request .headers , request .content ):
875
883
name = part .name
876
884
if not name :
0 commit comments