Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 131 additions & 3 deletions src/team_workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ impl TeamWorkflow {
2. AGENT DESIGN: For each subtask:
- Assign clear name/description (e.g., "DataValidator")
- Select the most suitable model (consider capabilities/task requirements)
- Set appropriate temperature (0.0-1.0): lower for deterministic tasks, higher for creative tasks
- Set max_tokens (positive integer): based on expected output length
- Craft focused system prompts with:
* Clear role definition
* Expected output format
Expand All @@ -119,11 +121,13 @@ impl TeamWorkflow {

OUTPUT REQUIREMENTS:
Your orchestration plan MUST specify:
- workers[]: Array of agent configurations (name, description, model, system_prompt)
- workers[]: Array of agent configurations (name, description, model, system_prompt, temperature, max_tokens)
- connections[]: Array of "from→to" relationships
- starting_agent: Entry point
- final_agent: Output producer

CRITICAL: Each worker MUST include temperature (0.0-1.0) and max_tokens (positive integer) fields.

DESIGN PRINCIPLES:
1. SPECIALIZATION: Each agent should have a single, well-defined responsibility
2. BALANCE: Distribute workload evenly across agents
Expand All @@ -141,13 +145,17 @@ impl TeamWorkflow {
name: "DataCollector",
description: "Gathers raw market data from APIs",
model: "data-crawler",
system_prompt: "Collect...output as JSON with [timestamp, value] pairs"
system_prompt: "Collect...output as JSON with [timestamp, value] pairs",
temperature: 0.3,
max_tokens: 2000
}},
{{
name: "TrendAnalyzer",
description: "Identifies statistical patterns",
model: "stats-v3",
system_prompt: "Input raw data...output [trend_lines, anomalies]"
system_prompt: "Input raw data...output [trend_lines, anomalies]",
temperature: 0.7,
max_tokens: 3000
}}
]
2. connections: ["DataCollector→TrendAnalyzer"]
Expand Down Expand Up @@ -470,6 +478,126 @@ fn orchestrate(
Ok(orchestration_plan)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_orchestration_plan_parsing_with_temperature_and_max_tokens() {
let json_with_required_fields = r#"
{
"workers": [
{
"name": "TestAgent",
"description": "A test agent",
"model": "test-model",
"system_prompt": "You are a test agent",
"temperature": 0.7,
"max_tokens": 2000
}
],
"connections": [],
"starting_agents": ["TestAgent"],
"output_agents": ["TestAgent"]
}
"#;

let result = TeamWorkflow::parse_orchestration_plan(json_with_required_fields);
assert!(result.is_ok(), "Should parse successfully with temperature and max_tokens");

let plan = result.unwrap();
assert_eq!(plan.workers.len(), 1);
assert_eq!(plan.workers[0].name, "TestAgent");
assert_eq!(plan.workers[0].temperature, 0.7);
assert_eq!(plan.workers[0].max_tokens, 2000);
}

#[test]
fn test_orchestration_plan_parsing_without_temperature_fails() {
let json_without_temperature = r#"
{
"workers": [
{
"name": "TestAgent",
"description": "A test agent",
"model": "test-model",
"system_prompt": "You are a test agent",
"max_tokens": 2000
}
],
"connections": [],
"starting_agents": ["TestAgent"],
"output_agents": ["TestAgent"]
}
"#;

let result = TeamWorkflow::parse_orchestration_plan(json_without_temperature);
assert!(result.is_err(), "Should fail to parse without temperature field");
}

#[test]
fn test_orchestration_plan_parsing_without_max_tokens_fails() {
let json_without_max_tokens = r#"
{
"workers": [
{
"name": "TestAgent",
"description": "A test agent",
"model": "test-model",
"system_prompt": "You are a test agent",
"temperature": 0.7
}
],
"connections": [],
"starting_agents": ["TestAgent"],
"output_agents": ["TestAgent"]
}
"#;

let result = TeamWorkflow::parse_orchestration_plan(json_without_max_tokens);
assert!(result.is_err(), "Should fail to parse without max_tokens field");
}

#[test]
fn test_worker_agent_json_schema_includes_required_fields() {
use schemars::schema_for;

let schema = schema_for!(WorkerAgent);
let schema_value = schema.as_value();

// Check that the schema includes temperature and max_tokens as required properties
let properties = schema_value
.get("properties")
.expect("Schema should have properties");

assert!(properties.get("temperature").is_some(), "Schema should include temperature field");
assert!(properties.get("max_tokens").is_some(), "Schema should include max_tokens field");
assert!(properties.get("name").is_some(), "Schema should include name field");
assert!(properties.get("description").is_some(), "Schema should include description field");
assert!(properties.get("model").is_some(), "Schema should include model field");
assert!(properties.get("system_prompt").is_some(), "Schema should include system_prompt field");

// Check that all fields are required
let required = schema_value
.get("required")
.expect("Schema should have required fields")
.as_array()
.expect("Required should be an array");

let required_strings: Vec<&str> = required
.iter()
.map(|v| v.as_str().unwrap())
.collect();

assert!(required_strings.contains(&"temperature"), "temperature should be required");
assert!(required_strings.contains(&"max_tokens"), "max_tokens should be required");
assert!(required_strings.contains(&"name"), "name should be required");
assert!(required_strings.contains(&"description"), "description should be required");
assert!(required_strings.contains(&"model"), "model should be required");
assert!(required_strings.contains(&"system_prompt"), "system_prompt should be required");
}
}

impl Display for ModelDescription {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
Expand Down