Skip to content

Commit 229a0c0

Browse files
authored
Merge pull request #15 from fetchai/feat/start-session-executing-function
session: execute function
2 parents 631cf1e + a08f118 commit 229a0c0

File tree

8 files changed

+225
-13
lines changed

8 files changed

+225
-13
lines changed

README.md

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,23 +86,42 @@ session = await ai_engine.create_session(function_group=public_group.uuid)
8686
```python
8787
await session.start(objective)
8888
```
89-
90-
89+
9190
#### Querying new messages
9291

93-
You might want to query new messages regularly ...
94-
95-
96-
92+
You might want to query new messages regularly ...
9793
```python
98-
9994
while True:
10095
messages: list[ApiBaseMessage] = await session.get_messages()
10196
# throttling
102-
sleep(3)
97+
sleep(3)
10398
```
104-
10599

100+
#### Execution a function on demand.
101+
This is the first message that should be sent to the AI Engine for execution the function/s of your choice.
102+
The main difference in here it is the AI Engine won't search, therefore decide for you, what is the apt function to fulfill your needs.
103+
104+
It contains the list of function-ids you want to execute and a function group (for secondary function picks).
105+
106+
Currently only supported by Next Generation personality.
107+
Don't use this if you already sent 'start' message.
108+
109+
```python
110+
# init the AI Engine client
111+
from ai_engine_sdk import AiEngine
112+
ai_engine: AiEngine = AiEngine(api_key)
113+
# Create (do not start) a Session
114+
session = await ai_engine.create_session(function_group=function_group.uuid)
115+
116+
# Execute function. You will receive no response.
117+
await session.execute_function(function_ids=[function_uuid], objective="", context="")
118+
119+
# In order to get some feedback, gather the messages as regular.
120+
while True:
121+
messages: list[ApiBaseMessage] = await session.get_messages()
122+
# throttling
123+
sleep(3)
124+
```
106125
#### Checking the type of the new message
107126

108127
There are 5 different types of messages which are generated by the AI Engine and the SDK implements methods for checking the type of the respective new <code>Message</code>:

ai_engine_sdk/api_models/api_models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class ApiMessagePayloadTypes(str, Enum):
88
START = "start"
99
USER_JSON = "user_json"
1010
USER_MESSAGE = "user_message"
11-
11+
EXECUTE_FUNCTIONS = "execute_functions"
1212

1313
class ApiMessagePayload(BaseModel):
1414
session_id: str
@@ -55,6 +55,15 @@ class ApiUserMessageMessage(ApiMessagePayload):
5555
user_message: str
5656

5757

58+
class ApiUserMessageExecuteFunctions(ApiMessagePayload):
59+
type: Literal[ApiMessagePayloadTypes.EXECUTE_FUNCTIONS] = ApiMessagePayloadTypes.EXECUTE_FUNCTIONS
60+
61+
functions: list[str]
62+
objective: str
63+
context: str
64+
65+
66+
5867
# -----------
5968

6069
# class ApiNewSessionResponse(BaseModel):

ai_engine_sdk/client.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .api_models.api_models import (
2828
ApiNewSessionRequest,
2929
is_api_context_json,
30-
ApiStartMessage, ApiMessagePayload, ApiUserJsonMessage, ApiUserMessageMessage
30+
ApiStartMessage, ApiMessagePayload, ApiUserJsonMessage, ApiUserMessageMessage, ApiUserMessageExecuteFunctions
3131
)
3232
from .api_models.parsing_utils import get_indexed_task_options_from_raw_api_response
3333
from .llm_models import (
@@ -352,12 +352,22 @@ async def delete(self):
352352
endpoint=f"/v1beta1/engine/chat/sessions/{self.session_id}"
353353
)
354354

355+
async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None):
356+
await self._submit_message(
357+
payload=ApiUserMessageExecuteFunctions.model_validate({
358+
"functions": function_ids,
359+
"objective": objective,
360+
"context": context or "",
361+
'session_id': self.session_id,
362+
})
363+
)
355364

356365
class AiEngine:
357366
def __init__(self, api_key: str, options: Optional[dict] = None):
358367
self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url
359368
self._api_key = api_key
360369

370+
361371
####
362372
# Function groups
363373
####
@@ -464,7 +474,7 @@ async def get_functions_by_function_group(self, function_group_id: str) -> list[
464474
if "functions" in raw_response:
465475
list(
466476
map(
467-
lambda function_name: FunctionGroupFunctions.parse_obj({"name": function_name}),
477+
lambda function_name: FunctionGroupFunctions.model_validate({"name": function_name}),
468478
raw_response["functions"]
469479
)
470480
)

examples/execute_function.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import argparse
2+
import asyncio
3+
import os
4+
from pprint import pprint
5+
6+
from faker.utils.decorators import lowercase
7+
8+
from ai_engine_sdk import AiEngine, FunctionGroup, ApiBaseMessage
9+
from ai_engine_sdk.client import Session
10+
from tests.conftest import function_groups
11+
12+
13+
async def main(
14+
target_environment: str,
15+
agentverse_api_key: str,
16+
function_uuid: str,
17+
function_group_uuid: str
18+
):
19+
# Request from cli args.
20+
options = {}
21+
if target_environment:
22+
options = {"api_base_url": target_environment}
23+
24+
ai_engine = AiEngine(api_key=agentverse_api_key, options=options)
25+
26+
session: Session = await ai_engine.create_session(function_group=function_group_uuid)
27+
await session.execute_function(function_ids=[function_uuid], objective="", context="")
28+
29+
try:
30+
empty_count = 0
31+
session_ended = False
32+
33+
print("Waiting for execution:")
34+
while empty_count < 100:
35+
messages: list[ApiBaseMessage] = await session.get_messages()
36+
if messages:
37+
pprint(messages)
38+
if any((msg.type.lower() == "stop" for msg in messages)):
39+
print("DONE")
40+
break
41+
if len(messages) % 10 == 0:
42+
print("Wait...")
43+
if len(messages) == 0:
44+
empty_count += 1
45+
else:
46+
empty_count = 0
47+
48+
49+
except Exception as ex:
50+
pprint(ex)
51+
raise
52+
53+
if __name__ == '__main__':
54+
from dotenv import load_dotenv
55+
load_dotenv()
56+
api_key = os.getenv("AV_API_KEY", "")
57+
58+
parser = argparse.ArgumentParser()
59+
parser.add_argument(
60+
"-e",
61+
"--target_environment",
62+
type=str,
63+
required=False,
64+
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production."
65+
)
66+
parser.add_argument(
67+
"-fg",
68+
"--function_group_uuid",
69+
type=str,
70+
required=True,
71+
)
72+
parser.add_argument(
73+
"-f",
74+
"--function_uuid",
75+
type=str,
76+
required=True,
77+
)
78+
args = parser.parse_args()
79+
80+
result = asyncio.run(
81+
main(
82+
agentverse_api_key=api_key,
83+
target_environment=args.target_environment,
84+
function_group_uuid=args.function_group_uuid,
85+
function_uuid=args.function_uuid
86+
)
87+
)
88+
pprint(result)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import argparse
2+
import asyncio
3+
import os
4+
from pprint import pprint
5+
6+
from ai_engine_sdk import FunctionGroup, AiEngine
7+
from tests.integration.test_ai_engine_client import api_key
8+
9+
10+
async def main(
11+
function_group_name: str,
12+
agentverse_api_key: str,
13+
target_environment: str | None = None,
14+
):
15+
# Request from cli args.
16+
options = {}
17+
if target_environment:
18+
options = {"api_base_url": target_environment}
19+
20+
ai_engine: AiEngine = AiEngine(api_key=agentverse_api_key, options=options)
21+
function_groups: list[FunctionGroup] = await ai_engine.get_function_groups()
22+
23+
target_function_group = next((g for g in function_groups if g.name == function_group_name), None)
24+
if target_function_group is None:
25+
raise Exception(f'Could not find "{target_function_group}" function group.')
26+
27+
return await ai_engine.get_functions_by_function_group(function_group_id=target_function_group.uuid)
28+
29+
30+
31+
if __name__ == "__main__":
32+
from dotenv import load_dotenv
33+
load_dotenv()
34+
api_key = os.getenv("AV_API_KEY", "")
35+
36+
# Parse CLI arguments
37+
parser = argparse.ArgumentParser()
38+
39+
parser.add_argument(
40+
"-e",
41+
"--target_environment",
42+
type=str,
43+
required=False,
44+
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production."
45+
)
46+
parser.add_argument(
47+
"-fgn",
48+
"--fg_name",
49+
type=str,
50+
required=True,
51+
)
52+
args = parser.parse_args()
53+
54+
target_environment = args.target_environment
55+
56+
res = asyncio.run(
57+
main(
58+
agentverse_api_key=api_key,
59+
function_group_name=args.fg_name,
60+
target_environment=args.target_environment
61+
)
62+
)
63+
pprint(res)

tests/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,17 @@ async def function_groups(ai_engine_client) -> list[FunctionGroup]:
3434
# session: Session = await ai_engine_client.create_session(
3535
# function_group=function_groups, opts={"model": "next-gen"}
3636
# )
37-
# return session
37+
# return session
38+
39+
40+
@pytest.fixture(scope="session")
41+
def valid_public_function_uuid() -> str:
42+
# TODO: Do it programmatically (when test fails bc of it will be good moment)
43+
# 'Cornerstone Software' from Public fg and staging
44+
return "312712ae-eb70-42f7-bb5a-ad21ce6d73c3"
45+
46+
47+
@pytest.fixture(scope="session")
48+
def public_function_group() -> FunctionGroup:
49+
# TODO: Do it programmatically (when test fails bc of it will be good moment)
50+
return FunctionGroup(uuid="e504eabb-4bc7-458d-aa8c-7c3748f8952c", name="Public", isPrivate=False)

tests/integration/test_ai_engine_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ async def test_create_session(self, ai_engine_client: AiEngine):
6060
# await ai_engine_client.delete_function_group()
6161

6262

63+
@pytest.mark.asyncio
64+
async def test_execute_function(self, ai_engine_client: AiEngine, public_function_group: FunctionGroup, valid_public_function_uuid: str):
65+
session: Session = await ai_engine_client.create_session(function_group=public_function_group.uuid)
66+
result = await session.execute_function(
67+
function_ids=[valid_public_function_uuid],
68+
objective="Test software",
69+
context=""
70+
)
71+
72+
6373
@pytest.mark.asyncio
6474
async def test_create_function_group_and_list_them(self, ai_engine_client: AiEngine):
6575
name = fake.company()

0 commit comments

Comments
 (0)