Showing 4 changed files with 166 additions and 2 deletions.
@@ -0,0 +1,74 @@
import copy
from typing import Tuple

from aider.coders.base_coder import Coder


class IterateCoder(Coder):
    """Perform a coding task on multiple files in batches that fit the
    context and output token limits, without sending them all at once."""

    coder: Coder = None
    original_kwargs: dict = None
    edit_format = "iterate"

    def __init__(self, main_model, io, **kwargs):
        super().__init__(main_model, io, **kwargs)

    def run_one(self, user_message, preproc):
        if self.coder is None:
            self.coder = Coder.create(
                main_model=self.main_model,
                edit_format=self.main_model.edit_format,
                from_coder=self,
                **self.original_kwargs,
            )
        # Collect (path, is_editable, token_count) for every file in the chat.
        remaining_files_with_type_length: list[Tuple[str, bool, int]] = []
        for f in self.abs_fnames:
            remaining_files_with_type_length.append(
                (f, True, self.main_model.token_count(self.io.read_text(f)))
            )
        for f in self.abs_read_only_fnames:
            remaining_files_with_type_length.append(
                (f, False, self.main_model.token_count(self.io.read_text(f)))
            )
        max_tokens = self.main_model.info.get("max_tokens")
        max_context = self.main_model.info["max_input_tokens"]
        max_output = self.main_model.info["max_output_tokens"]
        repo_token_count = self.main_model.get_repo_map_tokens()
        history_token_count = sum(
            tup[0]
            for tup in self.summarizer.tokenize(
                [msg["content"] for msg in self.done_messages]
            )
        )
        # Fit input files + chat history + repo map + files_to_send into the
        # context limit, and files_to_send into the output limit. Output files
        # are assumed to be at least as large as the input files.
        prev_io = self.io.yes
        self.io.yes = True
        for files_to_send_with_types in self.file_cruncher(
            max_context=max_context,
            max_output=max_tokens if max_tokens is not None else max_output,
            context_tokens=repo_token_count + history_token_count,
            remaining_files=remaining_files_with_type_length,
        ):
            # Reset the inner coder's history to its state at the start of
            # the /iterate command, so every batch sees the same context.
            self.coder.done_messages = copy.deepcopy(self.done_messages)
            self.coder.cur_messages = []
            self.coder.abs_fnames = {f[0] for f in files_to_send_with_types if f[1]}
            self.coder.abs_read_only_fnames = {
                f[0] for f in files_to_send_with_types if not f[1]
            }
            self.coder.run_one(user_message, preproc)
        self.io.yes = prev_io

    class file_cruncher:
        """Greedy batcher: yields the largest prefix of the remaining files
        (smallest first) that fits both the context and output budgets."""

        context_tokens: int
        max_context: int
        max_output: int
        remaining_files: list[Tuple[str, bool, int]]
        PADDING: int = 50

        def __init__(self, max_context: int, max_output: int, context_tokens,
                     remaining_files: list[Tuple[str, bool, int]]):
            self.context_tokens = context_tokens
            self.max_context = max_context
            self.max_output = max_output
            self.remaining_files = sorted(remaining_files, key=lambda x: x[2])

        def __iter__(self):
            return self

        def __next__(self):
            if len(self.remaining_files) == 0:
                raise StopIteration
            files_to_send: list[Tuple[str, bool]] = []
            i: int = 0
            total_context = 0
            total_output = 0
            for file_name, type_, length in self.remaining_files:
                # Each file costs its own tokens as input plus an estimated
                # length + PADDING tokens for the rewritten output.
                if (
                    length + (length + self.PADDING) + self.context_tokens + total_context
                    >= self.max_context
                    or length + self.PADDING + total_output >= self.max_output
                ):
                    break
                total_context += length + length + self.PADDING
                total_output += length + self.PADDING
                files_to_send.append((file_name, type_))
                i += 1
            if i == 0:
                # No file fits the limits: send the smallest one anyway and
                # let the user deal with the overflow.
                f, t, _ = self.remaining_files[0]
                files_to_send.append((f, t))
                i = 1
            self.remaining_files = self.remaining_files[i:]
            return files_to_send
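To make the batching rule concrete, here is a minimal standalone sketch (not part of the commit) that replays file_cruncher's greedy loop with invented token counts; only PADDING and the arithmetic come from the code above. Each file charges its own tokens as input plus an estimated length + PADDING tokens for the rewritten output, and the batch closes as soon as either budget would overflow.

# Illustrative only: toy replay of the greedy batching arithmetic.
PADDING = 50                           # same constant as in the class above
max_context, max_output = 4000, 1000   # invented budgets
context_tokens = 500                   # invented repo-map + history cost

# (name, is_editable, token_count), sorted smallest-first as in __init__
remaining = [("a.py", True, 100), ("b.py", True, 300), ("c.py", False, 900)]

batch, total_context, total_output = [], 0, 0
for name, editable, length in remaining:
    if (length + (length + PADDING) + context_tokens + total_context >= max_context
            or length + PADDING + total_output >= max_output):
        break
    total_context += length + length + PADDING
    total_output += length + PADDING
    batch.append((name, editable))

print(batch)  # [('a.py', True), ('b.py', True)]

Here c.py is deferred because its estimated 950 output tokens would overflow the 1,000-token output budget once a.py and b.py are batched; on the next call it fits on its own and is sent alone. The i == 0 fallback in __next__ fires only when even the smallest remaining file cannot fit, in which case it is sent anyway.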
@@ -0,0 +1,87 @@
import re
import unittest
from pathlib import Path
from unittest.mock import patch

from aider.coders import Coder
from aider.io import InputOutput
from aider.models import Model
from aider.utils import GitTemporaryDirectory


class TestIterateCoder(unittest.TestCase):
    def setUp(self):
        self.GPT35 = Model("gpt-3.5-turbo")
        self.io = InputOutput(yes=True)

        # Get all Python files in the aider/coders directory.
        coders_dir = Path(__file__).parent.parent.parent / "aider" / "coders"
        self.files = [str(f) for f in coders_dir.glob("*.py") if f.is_file()]

        # Create a coder with all of the files.
        self.coder = Coder.create(
            main_model=self.GPT35,
            io=self.io,
            fnames=self.files,
            edit_format="iterate",
        )

    def test_iterate_resets_history_and_processes_all_files(self):
        """Tests that:
        - Every request retains the chat history from before the /iterate
          command, but not the history of other iterations.
        - The added files and the pre-command history are unmodified.
        - Every file is processed, with no duplicates, even when a single
          file by itself exceeds the limits.
        """
        processed_files: list[str] = []
        original_context: list[dict[str, str]]
        prev_file_names = None  # file names seen in the previous batch

        # Track the messages sent to the LLM and the files processed.
        def mock_send(self, messages, model=None, functions=None):
            nonlocal original_context, processed_files, prev_file_names
            for original_message in original_context:
                assert original_message in messages, (
                    "Chat history from before the command is not retained."
                )
            # Parse the file names announced for this batch.
            files_message = [
                msg["content"]
                for msg in messages
                if "*added these files to the chat*" in msg["content"]
            ][0]
            file_names = re.findall(r".*\n(\S+\.py)\n```.*", files_message)
            for f_name in file_names:
                assert prev_file_names is None or f_name not in prev_file_names, (
                    "Files from previous iterations haven't been cleaned up."
                )
            prev_file_names = file_names
            processed_files.extend(file_names)
            # Return a minimal response.
            self.partial_response_content = "Done."
            self.partial_response_function_call = dict()

        with GitTemporaryDirectory():
            # Mock the send method.
            with patch.object(Coder, "send", new_callable=lambda: mock_send):
                self.coder.coder = Coder.create(
                    main_model=self.coder.main_model,
                    edit_format=self.coder.main_model.edit_format,
                    from_coder=self.coder,
                    **self.coder.original_kwargs,
                )

                # Add initial conversation history.
                original_context = self.coder.done_messages = [
                    {"role": "user", "content": "Initial conversation"},
                    {"role": "assistant", "content": "OK"},
                ]

                # Run the iterate command.
                self.coder.run(with_message="Process all files")

                # Verify all files were processed.
                input_basenames = {Path(f).name for f in self.files}
                processed_basenames = {Path(f).name for f in processed_files}
                missing = input_basenames - processed_basenames
                assert not missing, f"Files not processed: {missing}"

                # Verify history preservation and structure.
                assert len(self.coder.done_messages) == 2, (
                    "Original chat history was modified"
                )
                # Verify the final file state.
                assert len(self.coder.abs_fnames) == len(self.files), (
                    "Not all files remained in chat"
                )


if __name__ == "__main__":
    unittest.main()
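For completeness, a hedged sketch of driving the new edit format programmatically; it mirrors the Coder.create and run calls the test above uses, and the file names are placeholders.

from aider.coders import Coder
from aider.io import InputOutput
from aider.models import Model

# Placeholder files: any set too large to send in a single request.
coder = Coder.create(
    main_model=Model("gpt-3.5-turbo"),
    io=InputOutput(yes=True),
    fnames=["big_module_a.py", "big_module_b.py"],
    edit_format="iterate",
)
coder.run(with_message="Add type hints to every function")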