diff --git a/src/team_workflow.rs b/src/team_workflow.rs index f8e5946..77f4517 100644 --- a/src/team_workflow.rs +++ b/src/team_workflow.rs @@ -108,6 +108,8 @@ impl TeamWorkflow { 2. AGENT DESIGN: For each subtask: - Assign clear name/description (e.g., "DataValidator") - Select the most suitable model (consider capabilities/task requirements) + - Set appropriate temperature (0.0-1.0): lower for deterministic tasks, higher for creative tasks + - Set max_tokens (positive integer): based on expected output length - Craft focused system prompts with: * Clear role definition * Expected output format @@ -119,11 +121,13 @@ impl TeamWorkflow { OUTPUT REQUIREMENTS: Your orchestration plan MUST specify: - - workers[]: Array of agent configurations (name, description, model, system_prompt) + - workers[]: Array of agent configurations (name, description, model, system_prompt, temperature, max_tokens) - connections[]: Array of "from→to" relationships - starting_agent: Entry point - final_agent: Output producer + CRITICAL: Each worker MUST include temperature (0.0-1.0) and max_tokens (positive integer) fields. + DESIGN PRINCIPLES: 1. SPECIALIZATION: Each agent should have a single, well-defined responsibility 2. BALANCE: Distribute workload evenly across agents @@ -141,13 +145,17 @@ impl TeamWorkflow { name: "DataCollector", description: "Gathers raw market data from APIs", model: "data-crawler", - system_prompt: "Collect...output as JSON with [timestamp, value] pairs" + system_prompt: "Collect...output as JSON with [timestamp, value] pairs", + temperature: 0.3, + max_tokens: 2000 }}, {{ name: "TrendAnalyzer", description: "Identifies statistical patterns", model: "stats-v3", - system_prompt: "Input raw data...output [trend_lines, anomalies]" + system_prompt: "Input raw data...output [trend_lines, anomalies]", + temperature: 0.7, + max_tokens: 3000 }} ] 2. connections: ["DataCollector→TrendAnalyzer"] @@ -470,6 +478,126 @@ fn orchestrate( Ok(orchestration_plan) } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_orchestration_plan_parsing_with_temperature_and_max_tokens() { + let json_with_required_fields = r#" + { + "workers": [ + { + "name": "TestAgent", + "description": "A test agent", + "model": "test-model", + "system_prompt": "You are a test agent", + "temperature": 0.7, + "max_tokens": 2000 + } + ], + "connections": [], + "starting_agents": ["TestAgent"], + "output_agents": ["TestAgent"] + } + "#; + + let result = TeamWorkflow::parse_orchestration_plan(json_with_required_fields); + assert!(result.is_ok(), "Should parse successfully with temperature and max_tokens"); + + let plan = result.unwrap(); + assert_eq!(plan.workers.len(), 1); + assert_eq!(plan.workers[0].name, "TestAgent"); + assert_eq!(plan.workers[0].temperature, 0.7); + assert_eq!(plan.workers[0].max_tokens, 2000); + } + + #[test] + fn test_orchestration_plan_parsing_without_temperature_fails() { + let json_without_temperature = r#" + { + "workers": [ + { + "name": "TestAgent", + "description": "A test agent", + "model": "test-model", + "system_prompt": "You are a test agent", + "max_tokens": 2000 + } + ], + "connections": [], + "starting_agents": ["TestAgent"], + "output_agents": ["TestAgent"] + } + "#; + + let result = TeamWorkflow::parse_orchestration_plan(json_without_temperature); + assert!(result.is_err(), "Should fail to parse without temperature field"); + } + + #[test] + fn test_orchestration_plan_parsing_without_max_tokens_fails() { + let json_without_max_tokens = r#" + { + "workers": [ + { + "name": "TestAgent", + "description": "A test agent", + "model": "test-model", + "system_prompt": "You are a test agent", + "temperature": 0.7 + } + ], + "connections": [], + "starting_agents": ["TestAgent"], + "output_agents": ["TestAgent"] + } + "#; + + let result = TeamWorkflow::parse_orchestration_plan(json_without_max_tokens); + assert!(result.is_err(), "Should fail to parse without max_tokens field"); + } + + #[test] + fn test_worker_agent_json_schema_includes_required_fields() { + use schemars::schema_for; + + let schema = schema_for!(WorkerAgent); + let schema_value = schema.as_value(); + + // Check that the schema includes temperature and max_tokens as required properties + let properties = schema_value + .get("properties") + .expect("Schema should have properties"); + + assert!(properties.get("temperature").is_some(), "Schema should include temperature field"); + assert!(properties.get("max_tokens").is_some(), "Schema should include max_tokens field"); + assert!(properties.get("name").is_some(), "Schema should include name field"); + assert!(properties.get("description").is_some(), "Schema should include description field"); + assert!(properties.get("model").is_some(), "Schema should include model field"); + assert!(properties.get("system_prompt").is_some(), "Schema should include system_prompt field"); + + // Check that all fields are required + let required = schema_value + .get("required") + .expect("Schema should have required fields") + .as_array() + .expect("Required should be an array"); + + let required_strings: Vec<&str> = required + .iter() + .map(|v| v.as_str().unwrap()) + .collect(); + + assert!(required_strings.contains(&"temperature"), "temperature should be required"); + assert!(required_strings.contains(&"max_tokens"), "max_tokens should be required"); + assert!(required_strings.contains(&"name"), "name should be required"); + assert!(required_strings.contains(&"description"), "description should be required"); + assert!(required_strings.contains(&"model"), "model should be required"); + assert!(required_strings.contains(&"system_prompt"), "system_prompt should be required"); + } +} + impl Display for ModelDescription { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(