|
| 1 | +""" |
| 2 | +TARS v1 Progressive Context Manager |
| 3 | +================================== |
| 4 | +
|
| 5 | +Manages context loading and optimization for better AI responses: |
| 6 | +- Progressive context loading (summary → details) |
| 7 | +- Smart context filtering and prioritization |
| 8 | +- Context window optimization |
| 9 | +- Hierarchical information structure |
| 10 | +""" |
| 11 | + |
| 12 | +import logging |
| 13 | +from typing import Dict, Any, List, Optional, Tuple |
| 14 | +from dataclasses import dataclass |
| 15 | +import re |
| 16 | +from pathlib import Path |
| 17 | + |
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
@dataclass
class ContextItem:
    """A single candidate piece of context plus the metadata used to rank it.

    Instances are created by ProgressiveContextManager.add_context_item and
    ranked by a blend of caller-assigned priority and query relevance.
    """
    # Raw text of this context fragment.
    content: str
    # Caller-assigned importance, expected in [0.0, 1.0].
    priority: float
    # Free-form category tag, e.g. 'summary', 'detail', 'code', 'documentation'.
    context_type: str
    # Identifier of where the content came from (e.g. a repository id).
    source: str
    # Rough token count (whitespace word count * ~1.3) used for budgeting.
    token_estimate: int
    # Query-dependent score in [0.0, 1.0]; 0.0 until the manager scores it.
    relevance_score: float = 0.0
| 29 | + |
| 30 | + |
@dataclass
class ContextLevel:
    """One tier of the progressive context hierarchy (summary/details/specifics)."""
    # Machine name used to route items to this level (e.g. "summary").
    level_name: str
    # Human-readable heading rendered into the built context string.
    description: str
    # Items queued for this level; sorted by relevance before inclusion.
    items: List[ContextItem]
    # Token budget for this level (a fixed share of the available tokens).
    token_limit: int
    # Minimum items expected; a placeholder line is emitted if none qualify.
    min_items: int = 1
    # Hard cap on items included from this level.
    max_items: int = 10
| 40 | + |
| 41 | + |
class ProgressiveContextManager:
    """
    Manages progressive context loading for better AI responses.

    Uses a hierarchical approach:
    1. High-level summaries (always included)
    2. Medium-detail explanations (included if space allows)
    3. Detailed code/documentation (included only if highly relevant)
    """

    # Rough conversion factor from whitespace-delimited words to tokens;
    # previously the literal 1.3 was duplicated in two methods.
    _TOKENS_PER_WORD = 1.3

    def __init__(self, max_context_tokens: int = 4000):
        """Set up the token budget and the three-level context hierarchy.

        Args:
            max_context_tokens: Total context-window size. 30% is reserved
                for the model's response; the remainder is split 40/40/20
                across the summary/details/specifics levels.
        """
        self.max_context_tokens = max_context_tokens
        self.reserved_tokens = int(max_context_tokens * 0.3)  # Reserve 30% for response
        self.available_tokens = max_context_tokens - self.reserved_tokens

        # Define progressive levels (order matters: included summary-first).
        self.context_levels = [
            ContextLevel(
                level_name="summary",
                description="High-level summaries and overviews",
                items=[],
                token_limit=int(self.available_tokens * 0.4),  # 40% for summaries
                min_items=1,
                max_items=5
            ),
            ContextLevel(
                level_name="details",
                description="Medium-detail explanations and key information",
                items=[],
                token_limit=int(self.available_tokens * 0.4),  # 40% for details
                min_items=0,
                max_items=8
            ),
            ContextLevel(
                level_name="specifics",
                description="Detailed code, documentation, and specific examples",
                items=[],
                token_limit=int(self.available_tokens * 0.2),  # 20% for specifics
                min_items=0,
                max_items=3
            )
        ]

    def _estimate_tokens(self, text: str) -> int:
        """Rough token estimate: whitespace word count scaled by ~1.3."""
        return int(len(text.split()) * self._TOKENS_PER_WORD)

    def _level_by_name(self, name: str) -> Optional[ContextLevel]:
        """Return the configured ContextLevel named ``name``, or None."""
        for context_level in self.context_levels:
            if context_level.level_name == name:
                return context_level
        return None

    def add_context_item(
        self,
        content: str,
        context_type: str,
        source: str,
        priority: float = 0.5,
        level: str = "details"
    ) -> None:
        """Add a context item to the named level.

        BUG FIX: an unknown ``level`` name previously dropped the item
        silently (the loop simply found no match); it now falls back to
        the "details" level and logs a warning.
        """
        try:
            item = ContextItem(
                content=content,
                priority=priority,
                context_type=context_type,
                source=source,
                token_estimate=self._estimate_tokens(content)
            )

            target = self._level_by_name(level)
            if target is None:
                logger.warning(f"Unknown context level '{level}'; using 'details'")
                target = self._level_by_name("details")
            target.items.append(item)

        except Exception as e:
            # Best-effort API: context bookkeeping must never crash the caller.
            logger.warning(f"Error adding context item: {e}")

    def add_repository_context(self, repo_content: str, repository_id: str) -> None:
        """Add repository context at all three detail levels.

        ``repo_content`` is expected to be a text dump in which each file
        section starts with a "File: <path>" line — assumption inferred
        from the parsing below; TODO confirm against the dump producer.
        """
        try:
            lines = repo_content.split('\n')
            self._add_repo_summary(lines, repository_id)
            self._add_repo_structure(lines, repository_id)
            self._add_repo_snippets(lines, repository_id)
        except Exception as e:
            logger.warning(f"Error adding repository context: {e}")

    def _add_repo_summary(self, lines: List[str], repository_id: str) -> None:
        """Level 1: high-level summary (file count, languages, key files)."""
        file_count = sum(1 for line in lines if line.startswith('File: '))
        languages = set()
        key_files = []

        # Only the first 50 lines are scanned for the header-style metadata.
        for line in lines[:50]:
            if 'File: ' in line:
                file_path = line.replace('File: ', '').strip()
                if file_path:
                    ext = Path(file_path).suffix
                    if ext:
                        languages.add(ext)
                    # Heuristic: entry points, docs, and config count as "key" files.
                    if any(key in file_path.lower() for key in ['main', 'index', 'app', 'readme', 'config']):
                        key_files.append(file_path)

        summary = (
            f"Repository: {repository_id}\n"
            f"Files: {file_count}\n"
            f"Languages: {', '.join(sorted(languages)) if languages else 'Unknown'}\n"
            f"Key Files: {', '.join(key_files[:5])}"
        )
        self.add_context_item(
            content=summary,
            context_type="repository_summary",
            source=repository_id,
            priority=0.9,
            level="summary"
        )

    def _add_repo_structure(self, lines: List[str], repository_id: str) -> None:
        """Level 2: definition/import lines keyed by file (collect 15, keep 10)."""
        code_elements = []
        current_file = None

        for line in lines[:200]:
            if line.startswith('File: '):
                current_file = line.replace('File: ', '').strip()
            elif any(pattern in line for pattern in ['def ', 'class ', 'function ', 'import ', 'interface ']):
                if current_file:
                    code_elements.append(f"{current_file}: {line.strip()}")

            if len(code_elements) >= 15:
                break

        if code_elements:
            details = "Code Structure:\n" + '\n'.join(code_elements[:10])
            self.add_context_item(
                content=details,
                context_type="code_structure",
                source=repository_id,
                priority=0.7,
                level="details"
            )

    def _add_repo_snippets(self, lines: List[str], repository_id: str) -> None:
        """Level 3: up to two short snippets starting at def/class/function lines."""
        code_snippets = []
        current_snippet = []
        in_code_block = False

        for line in lines[:500]:
            if any(pattern in line for pattern in ['def ', 'class ', 'function ']):
                # A new definition starts a new snippet; flush any open one.
                if current_snippet:
                    code_snippets.append('\n'.join(current_snippet))
                    current_snippet = []
                current_snippet.append(line)
                in_code_block = True
            elif in_code_block and line.strip():
                current_snippet.append(line)
                if len(current_snippet) >= 8:  # Limit snippet size
                    code_snippets.append('\n'.join(current_snippet))
                    current_snippet = []
                    in_code_block = False
            elif in_code_block and not line.strip():
                # A blank line terminates the current snippet.
                if current_snippet:
                    code_snippets.append('\n'.join(current_snippet))
                    current_snippet = []
                in_code_block = False

            if len(code_snippets) >= 3:
                break

        if current_snippet:
            code_snippets.append('\n'.join(current_snippet))

        for i, snippet in enumerate(code_snippets[:2]):
            self.add_context_item(
                content=f"Code Example {i+1}:\n{snippet}",
                context_type="code_snippet",
                source=repository_id,
                priority=0.5,
                level="specifics"
            )

    def calculate_relevance_scores(self, query: str) -> None:
        """Score every stored item against ``query`` (word-overlap heuristic)."""
        try:
            query_words = set(query.lower().split())

            for level in self.context_levels:
                for item in level.items:
                    item.relevance_score = self._calculate_item_relevance(item, query_words)

        except Exception as e:
            logger.warning(f"Error calculating relevance scores: {e}")

    def _calculate_item_relevance(self, item: ContextItem, query_words: set) -> float:
        """Relevance in [0, 1]: query-word overlap + type boost, blended with priority."""
        content_words = set(item.content.lower().split())

        # Base relevance: fraction of query words that appear in the content.
        if query_words:
            overlap = len(query_words.intersection(content_words))
            relevance = overlap / len(query_words)
        else:
            # No query words to match — fall back to the item's own priority.
            relevance = item.priority

        # Small fixed boost depending on what kind of context this is.
        type_boosts = {
            'repository_summary': 0.2,
            'code_structure': 0.15,
            'code_snippet': 0.1,
            'documentation': 0.15,
            'error_handling': 0.1
        }
        relevance += type_boosts.get(item.context_type, 0.0)

        # Blend with the caller-assigned priority, then clamp to 1.0.
        relevance = (relevance + item.priority) / 2
        return min(relevance, 1.0)

    def build_progressive_context(self, query: str) -> Tuple[str, Dict[str, Any]]:
        """Build the progressive context string plus optimization metrics.

        Returns:
            ``(context_string, metrics)`` where ``metrics`` reports token
            usage, per-level item counts, and how many levels contributed
            relevant items. On failure, returns a minimal fallback string
            and an empty metrics dict.
        """
        try:
            self.calculate_relevance_scores(query)

            # Rank items within each level by the mean of relevance and priority.
            for level in self.context_levels:
                level.items.sort(key=lambda x: (x.relevance_score + x.priority) / 2, reverse=True)

            context_parts = [f"User Query: {query}\n"]
            total_tokens = len(query.split()) * self._TOKENS_PER_WORD
            # CONSISTENCY FIX: derive the counters from the configured levels
            # instead of a hard-coded key list that could desynchronize.
            items_included = {level.level_name: 0 for level in self.context_levels}

            # Progressive inclusion by level, summary first.
            for level in self.context_levels:
                level_tokens = 0
                level_items = 0

                context_parts.append(f"\n📋 {level.description.title()}:")

                for item in level.items:
                    # Include only if it fits the level budget, the global
                    # budget, and the level's item cap.
                    if (level_tokens + item.token_estimate <= level.token_limit and
                            total_tokens + item.token_estimate <= self.available_tokens and
                            level_items < level.max_items):

                        context_parts.append(f"• {item.content}")
                        level_tokens += item.token_estimate
                        total_tokens += item.token_estimate
                        level_items += 1
                        items_included[level.level_name] += 1

                    # Early exit once the minimum is met and we are near budget.
                    if level_items >= level.min_items and total_tokens > self.available_tokens * 0.8:
                        break

                if level_items == 0 and level.min_items > 0:
                    context_parts.append("• (No relevant information available)")

            # ROBUSTNESS FIX: avoid ZeroDivisionError when the token budget is 0.
            utilization = (total_tokens / self.available_tokens) * 100 if self.available_tokens else 0.0

            # Add optimization summary
            context_parts.append(f"\n📊 Context Optimization:")
            context_parts.append(f"• Tokens Used: {int(total_tokens)} / {self.available_tokens}")
            context_parts.append(f"• Utilization: {utilization:.1f}%")
            context_parts.append(f"• Items Included: {sum(items_included.values())}")

            context_parts.append("\n💡 Please provide a comprehensive response based on the above progressive context.")

            context_string = '\n'.join(context_parts)

            # Metrics for debugging
            metrics = {
                "total_tokens": int(total_tokens),
                "available_tokens": self.available_tokens,
                "utilization_percent": utilization,
                "items_included": items_included,
                "levels_used": len([l for l in self.context_levels if any(i.relevance_score > 0 for i in l.items)])
            }

            return context_string, metrics

        except Exception as e:
            logger.error(f"Error building progressive context: {e}")
            return f"User Query: {query}\n\nError building context: {str(e)}", {}

    def clear(self) -> None:
        """Remove all stored context items from every level."""
        for level in self.context_levels:
            level.items.clear()
0 commit comments