Skip to content

Commit 5295a12

Browse files
authored
Merge pull request #20 from tanaygodse/digest-issue
['digest'] issue and ['args'] issue fix
2 parents 229a0c0 + 6e8c494 commit 5295a12

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

Diff for: ai_engine_sdk/api_models/agents_json_messages.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ class TaskSelectionMessage(AgentJsonMessage):
3737
text: str
3838
options: Dict[str, TaskOption]
3939

40-
4140
def get_options_keys(self) -> list[TaskOption]:
4241
return [option for option in self.options]
4342

@@ -50,7 +49,6 @@ class DataRequestMessage(AgentJsonMessage):
5049
class ConfirmationMessage(AgentJsonMessage):
5150
type: Literal[AgentJsonMessageTypes.CONFIRMATION] = AgentJsonMessageTypes.CONFIRMATION
5251
text: str
53-
model: str
5452
payload: Dict[str, Any]
5553

5654

@@ -70,14 +68,16 @@ def is_agent_json_confirmation_message(message_type: str) -> bool:
7068

7169
def is_task_selection_message(message_type: str) -> bool:
7270
union_of_type = TaskSelectionTypes
73-
allowed_values = [literal for lit in get_args(union_of_type) for literal in get_args(lit)]
71+
allowed_values = [literal for lit in get_args(
72+
union_of_type) for literal in get_args(lit)]
7473
return message_type.upper() in allowed_values
7574

7675

7776
def is_data_request_message(message_type: str) -> bool:
7877
union_of_type = DataRequestTypes
7978
if get_origin(union_of_type) is Union:
80-
allowed_values = [literal for lit in get_args(union_of_type) for literal in get_args(lit)]
79+
allowed_values = [literal for lit in get_args(
80+
union_of_type) for literal in get_args(lit)]
8181
elif get_origin(union_of_type) is Literal:
8282
allowed_values = get_args(union_of_type)
8383

Diff for: ai_engine_sdk/client.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
280280
'id': message['message_id'],
281281
'timestamp': message['timestamp'],
282282
'text': agent_json['text'],
283-
'options':indexed_task_options
283+
'options': indexed_task_options
284284
})
285285
)
286286
elif is_api_context_json(message_type=agent_json_type, agent_json_text=agent_json['text']):
@@ -289,8 +289,7 @@ async def get_messages(self) -> List[ApiBaseMessage]:
289289
'id': message['message_id'],
290290
'timestamp': message['timestamp'],
291291
'text': agent_json['text'],
292-
'model': agent_json['context_json']['digest'],
293-
'payload': agent_json['context_json']['args'],
292+
'payload': agent_json['context_json'],
294293
})
295294
)
296295
elif is_data_request_message(message_type=agent_json_type):
@@ -352,7 +351,7 @@ async def delete(self):
352351
endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}"
353352
)
354353

355-
async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None):
354+
async def execute_function(self, function_ids: list[str], objective: str, context: str | None = None):
356355
await self._submit_message(
357356
payload=ApiUserMessageExecuteFunctions.model_validate({
358357
"functions": function_ids,
@@ -367,7 +366,6 @@ def __init__(self, api_key: str, options: Optional[dict] = None):
367366
self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url
368367
self._api_key = api_key
369368

370-
371369
####
372370
# Function groups
373371
####
@@ -579,4 +577,4 @@ async def share_function_group(
579577
payload=payload
580578
)
581579
logger.debug(f"FG successfully shared: {function_group_id} with {target_user_email}")
582-
return raw_response
580+
return raw_response

0 commit comments

Comments
 (0)