@@ -113,7 +113,6 @@ class Session:
113
113
_messages (List[ApiBaseMessage]): A list to store messages associated with the session.
114
114
_message_ids (set[str]): A set to store unique message IDs to prevent duplication.
115
115
"""
116
-
117
116
def __init__ (self , api_base_url : str , api_key : str , session_id : str , function_group : str ):
118
117
"""
119
118
Initializes a new session with the given parameters.
@@ -274,8 +273,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
274
273
agent_json : dict = message ['agent_json' ]
275
274
agent_json_type : str = agent_json ['type' ].upper ()
276
275
if is_task_selection_message (message_type = agent_json_type ):
277
- indexed_task_options : dict = get_indexed_task_options_from_raw_api_response (
278
- raw_api_response = message )
276
+ indexed_task_options : dict = get_indexed_task_options_from_raw_api_response (raw_api_response = message )
279
277
newMessages .append (
280
278
TaskSelectionMessage .model_validate ({
281
279
'type' : agent_json_type ,
@@ -363,17 +361,14 @@ async def execute_function(self, function_ids: list[str], objective: str, contex
363
361
})
364
362
)
365
363
366
-
367
364
class AiEngine :
368
365
def __init__ (self , api_key : str , options : Optional [dict ] = None ):
369
- self ._api_base_url = options .get (
370
- 'api_base_url' ) if options and 'api_base_url' in options else default_api_base_url
366
+ self ._api_base_url = options .get ('api_base_url' ) if options and 'api_base_url' in options else default_api_base_url
371
367
self ._api_key = api_key
372
368
373
369
####
374
370
# Function groups
375
371
####
376
-
377
372
async def get_function_groups (self ) -> List [FunctionGroup ]:
378
373
logger .debug ("get_function_groups" )
379
374
publicGroups , privateGroups = await asyncio .gather (
@@ -466,7 +461,6 @@ async def get_function_group_by_function(self, function_id: str):
466
461
###
467
462
# Functions
468
463
###
469
-
470
464
async def get_functions_by_function_group (self , function_group_id : str ) -> list [FunctionGroupFunctions ]:
471
465
raw_response : dict = await make_api_request (
472
466
api_base_url = self ._api_base_url ,
@@ -478,14 +472,14 @@ async def get_functions_by_function_group(self, function_group_id: str) -> list[
478
472
if "functions" in raw_response :
479
473
list (
480
474
map (
481
- lambda function_name : FunctionGroupFunctions .model_validate (
482
- {"name" : function_name }),
475
+ lambda function_name : FunctionGroupFunctions .model_validate ({"name" : function_name }),
483
476
raw_response ["functions" ]
484
477
)
485
478
)
486
479
487
480
return result
488
481
482
+
489
483
async def get_functions (self ) -> list [Function ]:
490
484
raw_response : dict = await make_api_request (
491
485
api_base_url = self ._api_base_url ,
@@ -502,10 +496,8 @@ async def get_functions(self) -> list[Function]:
502
496
####
503
497
# Model
504
498
####
505
-
506
499
async def get_models (self ) -> List [Model ]:
507
- pending_credits = [self .get_model_credits (
508
- model_id ) for model_id in DefaultModelIds ]
500
+ pending_credits = [self .get_model_credits (model_id ) for model_id in DefaultModelIds ]
509
501
510
502
models = [Model (
511
503
id = model_id ,
@@ -552,8 +544,7 @@ async def create_session(self, function_group: str, opts: Optional[dict] = None)
552
544
email = opts .get ('email' ) if opts else "" ,
553
545
functionGroup = function_group ,
554
546
preferencesEnabled = False ,
555
- requestModel = opts .get (
556
- 'model' ) if opts and 'model' in opts else DefaultModelId
547
+ requestModel = opts .get ('model' ) if opts and 'model' in opts else DefaultModelId
557
548
)
558
549
response = await make_api_request (
559
550
api_base_url = self ._api_base_url ,
@@ -586,4 +577,4 @@ async def share_function_group(
586
577
payload = payload
587
578
)
588
579
logger .debug (f"FG successfully shared: { function_group_id } with { target_user_email } " )
589
- return raw_response
580
+ return raw_response
0 commit comments