Skip to content

Commit 717c06d

Browse files
committed
swarm
1 parent d23b838 commit 717c06d

File tree

7 files changed

+199
-28
lines changed

7 files changed

+199
-28
lines changed
Lines changed: 127 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,144 @@
11
from restack_ai.function import function, log
2-
from pydantic import BaseModel
3-
from typing import List, Optional
2+
from pydantic import BaseModel, validator
3+
from typing import List, Optional, Dict
44
import inspect
55

6-
class TemperatureInput(BaseModel):
7-
"""The city and state, e.g. San Francisco, CA"""
8-
location: str
6+
from enum import Enum
97

10-
class HumidityInput(BaseModel):
11-
"""The city and state, e.g. San Francisco, CA"""
12-
location: str
8+
class USTopCities(str, Enum):
9+
NEW_YORK = "New York, NY"
10+
LOS_ANGELES = "Los Angeles, CA"
11+
CHICAGO = "Chicago, IL"
12+
HOUSTON = "Houston, TX"
13+
PHOENIX = "Phoenix, AZ"
14+
PHILADELPHIA = "Philadelphia, PA"
15+
SAN_ANTONIO = "San Antonio, TX"
16+
SAN_DIEGO = "San Diego, CA"
17+
DALLAS = "Dallas, TX"
18+
SAN_JOSE = "San Jose, CA"
19+
AUSTIN = "Austin, TX"
20+
JACKSONVILLE = "Jacksonville, FL"
21+
FORT_WORTH = "Fort Worth, TX"
22+
COLUMBUS = "Columbus, OH"
23+
SAN_FRANCISCO = "San Francisco, CA"
24+
CHARLOTTE = "Charlotte, NC"
25+
INDIANAPOLIS = "Indianapolis, IN"
26+
SEATTLE = "Seattle, WA"
27+
DENVER = "Denver, CO"
28+
WASHINGTON_DC = "Washington, DC"
29+
BOSTON = "Boston, MA"
30+
EL_PASO = "El Paso, TX"
31+
DETROIT = "Detroit, MI"
32+
NASHVILLE = "Nashville, TN"
33+
PORTLAND = "Portland, OR"
34+
MEMPHIS = "Memphis, TN"
35+
OKLAHOMA_CITY = "Oklahoma City, OK"
36+
LAS_VEGAS = "Las Vegas, NV"
37+
LOUISVILLE = "Louisville, KY"
38+
BALTIMORE = "Baltimore, MD"
39+
MILWAUKEE = "Milwaukee, WI"
40+
ALBUQUERQUE = "Albuquerque, NM"
41+
TUCSON = "Tucson, AZ"
42+
FRESNO = "Fresno, CA"
43+
SACRAMENTO = "Sacramento, CA"
44+
MESA = "Mesa, AZ"
45+
KANSAS_CITY = "Kansas City, MO"
46+
ATLANTA = "Atlanta, GA"
47+
MIAMI = "Miami, FL"
48+
COLORADO_SPRINGS = "Colorado Springs, CO"
49+
RALEIGH = "Raleigh, NC"
50+
OMAHA = "Omaha, NE"
51+
LONG_BEACH = "Long Beach, CA"
52+
VIRGINIA_BEACH = "Virginia Beach, VA"
53+
OAKLAND = "Oakland, CA"
54+
MINNEAPOLIS = "Minneapolis, MN"
55+
TULSA = "Tulsa, OK"
56+
ARLINGTON = "Arlington, TX"
57+
TAMPA = "Tampa, FL"
58+
NEW_ORLEANS = "New Orleans, LA"
1359

14-
class AirQualityInput(BaseModel):
60+
class WeatherData(BaseModel):
61+
temperature: str
62+
humidity: str
63+
air_quality: str
64+
65+
class LocationInput(BaseModel):
1566
"""The city and state, e.g. San Francisco, CA"""
16-
location: str
67+
location: USTopCities
68+
69+
CITY_WEATHER_DATA: Dict[USTopCities, WeatherData] = {
70+
USTopCities.NEW_YORK: WeatherData(temperature="72°F", humidity="60%", air_quality="moderate"),
71+
USTopCities.LOS_ANGELES: WeatherData(temperature="75°F", humidity="65%", air_quality="good"),
72+
USTopCities.CHICAGO: WeatherData(temperature="68°F", humidity="55%", air_quality="good"),
73+
USTopCities.HOUSTON: WeatherData(temperature="82°F", humidity="75%", air_quality="moderate"),
74+
USTopCities.PHOENIX: WeatherData(temperature="95°F", humidity="25%", air_quality="good"),
75+
USTopCities.PHILADELPHIA: WeatherData(temperature="70°F", humidity="62%", air_quality="moderate"),
76+
USTopCities.SAN_ANTONIO: WeatherData(temperature="85°F", humidity="70%", air_quality="good"),
77+
USTopCities.SAN_DIEGO: WeatherData(temperature="72°F", humidity="68%", air_quality="good"),
78+
USTopCities.DALLAS: WeatherData(temperature="83°F", humidity="65%", air_quality="moderate"),
79+
USTopCities.SAN_JOSE: WeatherData(temperature="73°F", humidity="60%", air_quality="good"),
80+
USTopCities.AUSTIN: WeatherData(temperature="84°F", humidity="68%", air_quality="good"),
81+
USTopCities.JACKSONVILLE: WeatherData(temperature="80°F", humidity="75%", air_quality="moderate"),
82+
USTopCities.FORT_WORTH: WeatherData(temperature="83°F", humidity="65%", air_quality="moderate"),
83+
USTopCities.COLUMBUS: WeatherData(temperature="71°F", humidity="63%", air_quality="good"),
84+
USTopCities.SAN_FRANCISCO: WeatherData(temperature="65°F", humidity="75%", air_quality="good"),
85+
USTopCities.CHARLOTTE: WeatherData(temperature="76°F", humidity="65%", air_quality="good"),
86+
USTopCities.INDIANAPOLIS: WeatherData(temperature="72°F", humidity="64%", air_quality="moderate"),
87+
USTopCities.SEATTLE: WeatherData(temperature="62°F", humidity="80%", air_quality="good"),
88+
USTopCities.DENVER: WeatherData(temperature="70°F", humidity="45%", air_quality="good"),
89+
USTopCities.WASHINGTON_DC: WeatherData(temperature="74°F", humidity="65%", air_quality="moderate"),
90+
USTopCities.BOSTON: WeatherData(temperature="68°F", humidity="70%", air_quality="good"),
91+
USTopCities.EL_PASO: WeatherData(temperature="88°F", humidity="30%", air_quality="good"),
92+
USTopCities.DETROIT: WeatherData(temperature="69°F", humidity="65%", air_quality="moderate"),
93+
USTopCities.NASHVILLE: WeatherData(temperature="77°F", humidity="68%", air_quality="good"),
94+
USTopCities.PORTLAND: WeatherData(temperature="65°F", humidity="75%", air_quality="good"),
95+
USTopCities.MEMPHIS: WeatherData(temperature="79°F", humidity="70%", air_quality="moderate"),
96+
USTopCities.OKLAHOMA_CITY: WeatherData(temperature="80°F", humidity="60%", air_quality="good"),
97+
USTopCities.LAS_VEGAS: WeatherData(temperature="92°F", humidity="25%", air_quality="good"),
98+
USTopCities.LOUISVILLE: WeatherData(temperature="75°F", humidity="67%", air_quality="moderate"),
99+
USTopCities.BALTIMORE: WeatherData(temperature="73°F", humidity="65%", air_quality="moderate"),
100+
USTopCities.MILWAUKEE: WeatherData(temperature="65°F", humidity="70%", air_quality="good"),
101+
USTopCities.ALBUQUERQUE: WeatherData(temperature="82°F", humidity="35%", air_quality="good"),
102+
USTopCities.TUCSON: WeatherData(temperature="90°F", humidity="30%", air_quality="good"),
103+
USTopCities.FRESNO: WeatherData(temperature="85°F", humidity="45%", air_quality="moderate"),
104+
USTopCities.SACRAMENTO: WeatherData(temperature="80°F", humidity="55%", air_quality="moderate"),
105+
USTopCities.MESA: WeatherData(temperature="93°F", humidity="25%", air_quality="good"),
106+
USTopCities.KANSAS_CITY: WeatherData(temperature="75°F", humidity="65%", air_quality="good"),
107+
USTopCities.ATLANTA: WeatherData(temperature="78°F", humidity="70%", air_quality="moderate"),
108+
USTopCities.MIAMI: WeatherData(temperature="85°F", humidity="80%", air_quality="moderate"),
109+
USTopCities.COLORADO_SPRINGS: WeatherData(temperature="68°F", humidity="45%", air_quality="good"),
110+
USTopCities.RALEIGH: WeatherData(temperature="75°F", humidity="68%", air_quality="good"),
111+
USTopCities.OMAHA: WeatherData(temperature="73°F", humidity="65%", air_quality="good"),
112+
USTopCities.LONG_BEACH: WeatherData(temperature="74°F", humidity="70%", air_quality="moderate"),
113+
USTopCities.VIRGINIA_BEACH: WeatherData(temperature="76°F", humidity="75%", air_quality="good"),
114+
USTopCities.OAKLAND: WeatherData(temperature="68°F", humidity="75%", air_quality="good"),
115+
USTopCities.MINNEAPOLIS: WeatherData(temperature="65°F", humidity="65%", air_quality="good"),
116+
USTopCities.TULSA: WeatherData(temperature="78°F", humidity="65%", air_quality="good"),
117+
USTopCities.ARLINGTON: WeatherData(temperature="83°F", humidity="65%", air_quality="moderate"),
118+
USTopCities.TAMPA: WeatherData(temperature="83°F", humidity="75%", air_quality="moderate"),
119+
USTopCities.NEW_ORLEANS: WeatherData(temperature="82°F", humidity="80%", air_quality="moderate")
120+
}
17121

18122
@function.defn()
19-
async def get_current_temperature(input: TemperatureInput) -> str:
123+
async def get_current_temperature(input: LocationInput) -> str:
20124
description = "Get the current temperature for a specific location"
21125
log.info("get_current_temperature function started", location=input.location)
22-
return '75°F'
126+
weather_data = CITY_WEATHER_DATA.get(input.location, WeatherData(temperature="75°F", humidity="65%", air_quality="good"))
127+
return weather_data.temperature
23128

24129
@function.defn()
25-
async def get_humidity(input: HumidityInput) -> str:
130+
async def get_humidity(input: LocationInput) -> str:
26131
description = "Get the current humidity level for a specific location"
27132
log.info("get_humidity function started", location=input.location)
28-
return '65%'
133+
weather_data = CITY_WEATHER_DATA.get(input.location, WeatherData(temperature="75°F", humidity="65%", air_quality="good"))
134+
return weather_data.humidity
29135

30136
@function.defn()
31-
async def get_air_quality(input: AirQualityInput) -> str:
137+
async def get_air_quality(input: LocationInput) -> str:
32138
description = "Get the current air quality for a specific location"
33139
log.info("get_air_quality function started", location=input.location)
34-
return 'good'
140+
weather_data = CITY_WEATHER_DATA.get(input.location, WeatherData(temperature="75°F", humidity="65%", air_quality="good"))
141+
return weather_data.air_quality
35142

36143
def get_function_declarations():
37144
functions = []
@@ -45,12 +152,13 @@ def get_function_declarations():
45152
"parameters": {
46153
"type": "OBJECT",
47154
"properties": {
48-
field_name: {
155+
"location": {
49156
"type": "STRING",
50157
"description": input_type.__doc__,
51-
} for field_name in input_type.__fields__
158+
"enum": [city.value for city in USTopCities]
159+
}
52160
},
53-
"required": list(input_type.__fields__.keys())
161+
"required": ["location"]
54162
}
55163
})
56164
return functions

community/gemini/src/services.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from watchfiles import run_process
33
import webbrowser
44
import os
5+
from restack_ai.restack import ServiceOptions
56

67
from src.client import client
78
from src.functions.generate_content import gemini_generate_content
@@ -17,10 +18,33 @@
1718
from src.functions.multi_function_call_advanced import gemini_multi_function_call_advanced
1819
from src.functions.tools import get_current_temperature, get_humidity, get_air_quality
1920

21+
from src.workflows.swarm import GeminiSwarmWorkflow
22+
2023
async def main():
21-
await client.start_service(
22-
workflows= [GeminiGenerateContentWorkflow, GeminiFunctionCallWorkflow, GeminiMultiFunctionCallWorkflow, GeminiMultiFunctionCallAdvancedWorkflow],
23-
functions= [gemini_generate_content, gemini_function_call, gemini_multi_function_call, gemini_multi_function_call_advanced, get_current_temperature, get_humidity, get_air_quality]
24+
await asyncio.gather(
25+
client.start_service(
26+
workflows=[GeminiGenerateContentWorkflow, GeminiFunctionCallWorkflow, GeminiMultiFunctionCallWorkflow, GeminiMultiFunctionCallAdvancedWorkflow, GeminiSwarmWorkflow],
27+
functions=[],
28+
options=ServiceOptions(
29+
max_concurrent_workflow_runs=1000
30+
)
31+
),
32+
client.start_service(
33+
task_queue="tools",
34+
functions=[get_current_temperature, get_humidity, get_air_quality],
35+
options=ServiceOptions(
36+
rate_limit=10,
37+
max_concurrent_function_runs=10
38+
)
39+
),
40+
client.start_service(
41+
task_queue="gemini",
42+
functions=[gemini_generate_content, gemini_function_call, gemini_multi_function_call, gemini_multi_function_call_advanced],
43+
options=ServiceOptions(
44+
rate_limit=5,
45+
max_concurrent_function_runs=3
46+
)
47+
)
2448
)
2549

2650
def run_services():

community/gemini/src/workflows/function_call.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ async def run(self, input: WorkflowInputParams):
1919
start_to_close_timeout=timedelta(seconds=120),
2020
retry_policy=RetryPolicy(
2121
maximum_attempts=1
22-
)
22+
),
23+
task_queue="gemini"
2324
)
2425
log.info("GeminiFunctionCallWorkflow completed", result=result)
2526
return result

community/gemini/src/workflows/generate_content.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ async def run(self, input: WorkflowInputParams):
1919
start_to_close_timeout=timedelta(seconds=120),
2020
retry_policy=RetryPolicy(
2121
maximum_attempts=1
22-
)
22+
),
23+
task_queue="gemini"
2324
)
2425
log.info("GeminiGenerateContentWorkflow completed", result=result)
2526
return result

community/gemini/src/workflows/multi_function_call.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ async def run(self, input: WorkflowInputParams):
1919
start_to_close_timeout=timedelta(seconds=120),
2020
retry_policy=RetryPolicy(
2121
maximum_attempts=1
22-
)
22+
),
23+
task_queue="gemini"
2324
)
2425
log.info("GeminiMultiFunctionCallWorkflow completed", result=result)
2526
return result

community/gemini/src/workflows/multi_function_call_advanced.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from src.functions.tools import get_current_temperature, get_humidity, get_air_quality
1515

16-
class WorkflowInputParams(BaseModel):
16+
class MultiFunctionCallAdvancedInputParams(BaseModel):
1717
user_content: str = "What's the weather in San Francisco?"
1818

1919
@workflow.defn()
@@ -22,7 +22,7 @@ def __init__(self):
2222
self.chat_history = []
2323

2424
@workflow.run
25-
async def run(self, input: WorkflowInputParams):
25+
async def run(self, input: MultiFunctionCallAdvancedInputParams):
2626
log.info("GeminiMultiFunctionCallAdvancedWorkflow started", input=input)
2727

2828
current_content = input.user_content
@@ -37,7 +37,8 @@ async def run(self, input: WorkflowInputParams):
3737
chat_history=self.chat_history
3838
),
3939
start_to_close_timeout=timedelta(seconds=120),
40-
retry_policy=RetryPolicy(maximum_attempts=1)
40+
retry_policy=RetryPolicy(maximum_attempts=2),
41+
task_queue="gemini"
4142
)
4243

4344
if not result or not isinstance(result, dict):
@@ -59,6 +60,7 @@ async def run(self, input: WorkflowInputParams):
5960
result = await workflow.step(
6061
globals()[function_name],
6162
func_call["args"],
63+
task_queue="tools",
6264
retry_policy=RetryPolicy(maximum_attempts=1)
6365
)
6466
function_results.append(f"{function_name} result: {str(result)}")
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from restack_ai.workflow import workflow, import_functions, log, RetryPolicy, workflow_info
2+
from pydantic import BaseModel
3+
from datetime import timedelta
4+
import asyncio
5+
from src.functions.tools import USTopCities
6+
from src.workflows.multi_function_call_advanced import GeminiMultiFunctionCallAdvancedWorkflow, MultiFunctionCallAdvancedInputParams
7+
8+
class WorkflowInputParams(BaseModel):
9+
num_cities: int = 50
10+
11+
@workflow.defn()
12+
class GeminiSwarmWorkflow:
13+
@workflow.run
14+
async def run(self, input: WorkflowInputParams):
15+
parent_workflow_id = workflow_info().workflow_id
16+
17+
# Get all available cities from USTopCities enum
18+
all_cities = [city.value for city in USTopCities]
19+
20+
# Take the first n cities based on input
21+
selected_cities = all_cities[:input.num_cities]
22+
23+
results_tasks = await asyncio.gather(*[
24+
workflow.child_execute(
25+
GeminiMultiFunctionCallAdvancedWorkflow,
26+
input=MultiFunctionCallAdvancedInputParams(user_content=f"What's the weather in {city}?"),
27+
workflow_id=f"{parent_workflow_id}-child-{city.replace(', ', '-')}"
28+
) for city in selected_cities
29+
])
30+
31+
results = [{"city": city, "result": result} for city, result in zip(selected_cities, results_tasks)]
32+
33+
log.info("GeminiSwarmWorkflow completed", results=results)
34+
return results

0 commit comments

Comments
 (0)