Skip to content

Commit d23b838

Browse files
committed
multi_function_advanced
1 parent 1768ed7 commit d23b838

File tree

4 files changed

+64
-77
lines changed

4 files changed

+64
-77
lines changed

community/gemini/src/functions/multi_function_call_advanced.py

Lines changed: 3 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List, Optional
66

77
import os
8+
from src.functions.tools import get_function_declarations
89

910
class ChatMessage(BaseModel):
1011
role: str
@@ -14,81 +15,13 @@ class FunctionInputParams(BaseModel):
1415
user_content: str
1516
chat_history: Optional[List[ChatMessage]] = None
1617

17-
class WeatherInput(BaseModel):
18-
location: str
19-
20-
class HumidityInput(BaseModel):
21-
location: str
22-
23-
class AirQualityInput(BaseModel):
24-
location: str
25-
26-
@function.defn()
27-
async def get_current_weather(input: WeatherInput) -> str:
28-
log.info("get_current_weather function started", location=input.location)
29-
return 'sunny'
30-
3118
@function.defn()
32-
async def get_humidity(input: HumidityInput) -> str:
33-
log.info("get_humidity function started", location=input.location)
34-
return '65%'
35-
36-
@function.defn()
37-
async def get_air_quality(input: AirQualityInput) -> str:
38-
log.info("get_air_quality function started", location=input.location)
39-
return 'good'
40-
41-
@function.defn()
42-
async def gemini_multi_function_call_advanced(input: FunctionInputParams) :
19+
async def gemini_multi_function_call_advanced(input: FunctionInputParams):
4320
try:
4421
log.info("gemini_multi_function_call_advanced function started", input=input)
4522
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
4623

47-
functions = [
48-
{
49-
"name": "get_current_weather",
50-
"description": "Get the current weather in a given location",
51-
"parameters": {
52-
"type": "OBJECT",
53-
"properties": {
54-
"location": {
55-
"type": "STRING",
56-
"description": "The city and state, e.g. San Francisco, CA",
57-
},
58-
},
59-
"required": ["location"],
60-
}
61-
},
62-
{
63-
"name": "get_humidity",
64-
"description": "Get the current humidity in a given location",
65-
"parameters": {
66-
"type": "OBJECT",
67-
"properties": {
68-
"location": {
69-
"type": "STRING",
70-
"description": "The city and state, e.g. San Francisco, CA",
71-
},
72-
},
73-
"required": ["location"],
74-
}
75-
},
76-
{
77-
"name": "get_air_quality",
78-
"description": "Get the current air quality in a given location",
79-
"parameters": {
80-
"type": "OBJECT",
81-
"properties": {
82-
"location": {
83-
"type": "STRING",
84-
"description": "The city and state, e.g. San Francisco, CA",
85-
},
86-
},
87-
"required": ["location"],
88-
}
89-
}
90-
]
91-
24+
functions = get_function_declarations()
9225
tools = [types.Tool(function_declarations=functions)]
9326

9427
response = client.models.generate_content(
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from restack_ai.function import function, log
2+
from pydantic import BaseModel
3+
from typing import List, Optional
4+
import inspect
5+
6+
class TemperatureInput(BaseModel):
7+
"""The city and state, e.g. San Francisco, CA"""
8+
location: str
9+
10+
class HumidityInput(BaseModel):
11+
"""The city and state, e.g. San Francisco, CA"""
12+
location: str
13+
14+
class AirQualityInput(BaseModel):
15+
"""The city and state, e.g. San Francisco, CA"""
16+
location: str
17+
18+
@function.defn()
19+
async def get_current_temperature(input: TemperatureInput) -> str:
20+
description = "Get the current temperature for a specific location"
21+
log.info("get_current_temperature function started", location=input.location)
22+
return '75°F'
23+
24+
@function.defn()
25+
async def get_humidity(input: HumidityInput) -> str:
26+
description = "Get the current humidity level for a specific location"
27+
log.info("get_humidity function started", location=input.location)
28+
return '65%'
29+
30+
@function.defn()
31+
async def get_air_quality(input: AirQualityInput) -> str:
32+
description = "Get the current air quality for a specific location"
33+
log.info("get_air_quality function started", location=input.location)
34+
return 'good'
35+
36+
def get_function_declarations():
37+
functions = []
38+
for func in [get_current_temperature, get_humidity, get_air_quality]:
39+
input_type = func.__annotations__['input']
40+
source = inspect.getsource(func)
41+
description = source.split('description = "')[1].split('"')[0]
42+
functions.append({
43+
"name": func.__name__,
44+
"description": description,
45+
"parameters": {
46+
"type": "OBJECT",
47+
"properties": {
48+
field_name: {
49+
"type": "STRING",
50+
"description": input_type.__doc__,
51+
} for field_name in input_type.__fields__
52+
},
53+
"required": list(input_type.__fields__.keys())
54+
}
55+
})
56+
return functions

community/gemini/src/services.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515

1616
from src.workflows.multi_function_call_advanced import GeminiMultiFunctionCallAdvancedWorkflow
1717
from src.functions.multi_function_call_advanced import gemini_multi_function_call_advanced
18-
from src.functions.multi_function_call_advanced import get_current_weather, get_humidity, get_air_quality
18+
from src.functions.tools import get_current_temperature, get_humidity, get_air_quality
1919

2020
async def main():
2121
await client.start_service(
2222
workflows= [GeminiGenerateContentWorkflow, GeminiFunctionCallWorkflow, GeminiMultiFunctionCallWorkflow, GeminiMultiFunctionCallAdvancedWorkflow],
23-
functions= [gemini_generate_content, gemini_function_call, gemini_multi_function_call, gemini_multi_function_call_advanced, get_current_weather, get_humidity, get_air_quality]
23+
functions= [gemini_generate_content, gemini_function_call, gemini_multi_function_call, gemini_multi_function_call_advanced, get_current_temperature, get_humidity, get_air_quality]
2424
)
2525

2626
def run_services():

community/gemini/src/workflows/multi_function_call_advanced.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@
88
with import_functions():
99
from src.functions.multi_function_call_advanced import (
1010
gemini_multi_function_call_advanced,
11-
get_current_weather,
12-
get_humidity,
13-
get_air_quality,
1411
FunctionInputParams,
1512
ChatMessage
1613
)
14+
from src.functions.tools import get_current_temperature, get_humidity, get_air_quality
1715

1816
class WorkflowInputParams(BaseModel):
19-
user_content: str = "what's the weather in San Francisco?"
17+
user_content: str = "What's the weather in San Francisco?"
2018

2119
@workflow.defn()
2220
class GeminiMultiFunctionCallAdvancedWorkflow:
@@ -56,7 +54,7 @@ async def run(self, input: WorkflowInputParams):
5654
func_call = part["functionCall"]
5755
function_name = func_call["name"]
5856

59-
if function_name in {"get_current_weather", "get_humidity", "get_air_quality"}:
57+
if function_name in {"get_current_temperature", "get_humidity", "get_air_quality"}:
6058
try:
6159
result = await workflow.step(
6260
globals()[function_name],

0 commit comments

Comments
 (0)