Fix and Improve Docstrings
- Fix typo in system prompt
psaegert committed Mar 10, 2024
1 parent d06ee74 commit d3208c7
Showing 13 changed files with 221 additions and 173 deletions.
9 changes: 6 additions & 3 deletions src/llmcoder/analyze/gpt_score_analyzer.py
@@ -9,7 +9,7 @@

 class GPTScoreAnalyzer(Analyzer):
     """
-    Create a new GPTScoreAnalyzer
+    Analyzer that scores code using GPT-3.5 with a scoring prompt.

     Parameters
     ----------
@@ -18,7 +18,7 @@ class GPTScoreAnalyzer(Analyzer):
     scoring_prompt : str
         The scoring prompt to use
     reduction : str | None, optional
-        The reduction method to use, by default "geo"
+        The reduction method to use, by default "geo" (geometric mean)
     verbose : bool, optional
         Whether to print verbose output, by default False
     """
@@ -30,7 +30,8 @@ def __init__(self, client: OpenAI | None = None, scoring_prompt: str | None = No
         self.verbose = verbose

     def score_prompt(self, code_list: list[str]) -> str:
-        """Concatenates the code snippets with the scoring prompt in the following format:
+        """
+        Concatenates the code snippets with the scoring prompt in the following format:

         Code snippet 1:
         ```python
@@ -156,6 +157,8 @@ def analyze(self, input: str, completion: str, context: dict[str, dict[str, floa
             The input code
         completion : str
             The completion to analyze
+        context : dict[str, dict[str, float | int | str]] | None, optional
+            Ignored. The context of previous analyzers of the completion, by default None.
         reduction : str | None, optional
             The reduction method to use, by default "geo"
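For context, a minimal usage sketch of the interface documented above. The import path is inferred from the file path, the default client presumably needs an OPENAI_API_KEY in the environment, and the shape of the returned results is an assumption:

```python
from llmcoder.analyze.gpt_score_analyzer import GPTScoreAnalyzer

# client=None lets the analyzer construct its own OpenAI client (assumption);
# scoring_prompt=None falls back to the default scoring prompt.
analyzer = GPTScoreAnalyzer(reduction="geo", verbose=True)  # "geo" = geometric mean

result = analyzer.analyze(
    input="def add(a: int, b: int) -> int:\n",  # the input code
    completion="    return a + b\n",            # the completion to score
)
print(result)
```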
15 changes: 6 additions & 9 deletions src/llmcoder/analyze/hallucination_analyzer.py
@@ -8,17 +8,14 @@
 class HallucinationAnalyzer(Analyzer):
     """
     Analyzer that checks mypy errors for hallucinations.
+
+    Parameters
+    ----------
+    verbose : bool
+        Whether to print debug messages.
     """

     def __init__(self, verbose: bool = False) -> None:
-        """
-        Initialize the SignatureAnalyzer.
-
-        Parameters
-        ----------
-        verbose : bool
-            Whether to print debug messages.
-        """
         super().__init__(verbose)

     def analyze(self, input: str, completion: str, context: dict[str, dict[str, float | int | str]] | None = None) -> dict:
@@ -31,7 +28,7 @@ def analyze(self, input: str, completion: str, context: dict[str, dict[str, floa
             The input code.
         completion : str
             The completion code.
-        context : dict[str, dict[str, float | int | str]] | None
+        context : dict[str, dict[str, float | int | str]] | None, optional
             The context from the previous analyzers.

         Returns
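Since context carries the results of earlier analyzers, a typical call chains this analyzer after MypyAnalyzer. A sketch, assuming results are keyed by analyzer name (the key "mypy_analyzer" is hypothetical):

```python
from llmcoder.analyze.hallucination_analyzer import HallucinationAnalyzer
from llmcoder.analyze.mypy_analyzer import MypyAnalyzer

input_code = "import os\n"
completion = "print(os.nonexistent_function())\n"

# Run mypy first; its errors are what the hallucination check inspects.
mypy_results = MypyAnalyzer(verbose=True).analyze(input_code, completion)

# Pass the earlier results on via the context parameter.
context = {"mypy_analyzer": mypy_results}
results = HallucinationAnalyzer(verbose=True).analyze(input_code, completion, context=context)
print(results)
```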
33 changes: 15 additions & 18 deletions src/llmcoder/analyze/mypy_analyzer.py
@@ -9,25 +9,22 @@
 class MypyAnalyzer(Analyzer):
     """
     Analyzer that runs mypy on the code with the completion and returns the result.
-    """
+
+    Parameters
+    ----------
+    verbose : bool, optional
+        Whether to print verbose output, by default False.
+    """
     def __init__(self, verbose: bool = False):
-        """
-        Initializes the analyzer.
-
-        Parameters
-        ----------
-        verbose : bool, optional
-            Whether to print verbose output, by default False.
-        """
         super().__init__(verbose=verbose)

-    def analyze(self,
-                input: str,
-                completion: str,
-                install_stubs: bool = True,
-                mypy_args: list[str] | None = None,
-                context: dict[str, dict[str, float | int | str]] | None = None) -> dict:
+    def analyze(
+            self,
+            input: str,
+            completion: str,
+            install_stubs: bool = True,
+            mypy_args: list[str] | None = None,
+            context: dict[str, dict[str, float | int | str]] | None = None) -> dict:

         """
         Analyzes the completion using mypy.
@@ -40,10 +37,10 @@ def analyze(self,
             The completion to analyze.
         install_stubs : bool, optional
             Whether to install missing stubs, by default True.
-        mypy_args : list[str], optional
+        mypy_args : list[str] | None, optional
             Additional arguments to pass to mypy, by default None.
-        context : dict[str, dict[str, float | int | str]], optional
-            The context of the completion, by default None.
+        context : dict[str, dict[str, float | int | str]] | None, optional
+            Ignored. The context of previous analyzers of the completion, by default None.

         Returns
         -------
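The full signature of analyze is visible in this hunk, so a usage sketch is straightforward; only the keys of the returned dict are assumed:

```python
from llmcoder.analyze.mypy_analyzer import MypyAnalyzer

analyzer = MypyAnalyzer(verbose=True)
result = analyzer.analyze(
    input="def greet(name: str) -> str:\n",
    completion="    return 42\n",   # mypy should flag this: int returned, str declared
    install_stubs=False,            # do not try to install missing type stubs
    mypy_args=["--strict"],         # extra flags forwarded to mypy
)
print(result)  # expected to contain the mypy error messages for the completion
```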
17 changes: 7 additions & 10 deletions src/llmcoder/analyze/signature_analyzer.py
@@ -15,17 +15,14 @@
 class SignatureAnalyzer(Analyzer):
     """
     Analyzer that fetches the signatures and documentations of functions and classes in the code.
+
+    Parameters
+    ----------
+    verbose : bool
+        Whether to print debug messages.
     """

     def __init__(self, verbose: bool = False) -> None:
-        """
-        Initialize the SignatureAnalyzer.
-
-        Parameters
-        ----------
-        verbose : bool
-            Whether to print debug messages.
-        """
         super().__init__(verbose)

     def get_imports(self, path: str, query: str | list[str] | None = None) -> Generator:
@@ -309,8 +306,8 @@ def analyze(self, input: str, completion: str, context: dict[str, dict[str, floa
             The input code.
         completion : str
             The completion code.
-        context : dict[str, dict[str, float | int | str]] | None
-            The context from the previous analyzers.
+        context : dict[str, dict[str, float | int | str]] | None, optional
+            The context of previous analyzers of the completion.

         Returns
         -------
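A sketch of the documented call. It is an assumption that the analyzer inspects the completion (and any context from earlier analyzers) to decide which signatures to fetch:

```python
from llmcoder.analyze.signature_analyzer import SignatureAnalyzer

input_code = "import numpy as np\n"
completion = "values = np.linspace(0, 1, num=5)\n"

analyzer = SignatureAnalyzer(verbose=True)
# context=None is allowed per the signature; passing mypy results instead would
# presumably narrow the search to names that actually caused errors (assumption).
results = analyzer.analyze(input_code, completion, context=None)
print(results)  # expected: signatures and documentation of the used callables
```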
24 changes: 24 additions & 0 deletions src/llmcoder/conversation/conversation.py
@@ -2,6 +2,22 @@


 class Conversation:
+    """
+    A class to represent a conversation, which contains a list of messages, a score, and a list of analyses.
+
+    Parameters
+    ----------
+    score : int
+        The score of the conversation
+    messages : list[dict[str, str]]
+        The list of messages in the conversation
+    analyses : list[dict[str, dict[str, float | int | str | bool]]] | None, optional
+        The list of analyses in the conversation, by default None
+    path : list[Any] | None, optional
+        The path of the conversation in the conversation tree, by default None
+    passing : bool, optional
+        Whether the conversation has passed all critical analyzers, by default False
+    """
     def __init__(
         self,
         score: int,
@@ -43,6 +59,14 @@ def add_to_path(self, choice: Any) -> "Conversation":
         return self

     def update_passing(self) -> "Conversation":
+        """
+        Update the passing status of the conversation based on the critical analyzers
+
+        Returns
+        -------
+        Conversation
+            The conversation with the updated passing status
+        """
         # Print how many critical analyzers have passed
         n_passed = sum(results['pass'] for results in self.analyses[-1].values() if (results['type'] == "critical" and type(results['pass']) is bool))
         n_total = len([results for results in self.analyses[-1].values() if results['type'] == "critical" and type(results['pass']) is bool])
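A sketch of the documented constructor and the two methods touched here. The import path is inferred from the file path, the message format follows the OpenAI chat convention, and the analysis payload keys ('type', 'pass') are taken from the update_passing body above:

```python
from llmcoder.conversation.conversation import Conversation

conversation = Conversation(
    score=0,
    messages=[{"role": "user", "content": "Complete: def add(a, b):"}],
    analyses=[{"mypy_analyzer": {"type": "critical", "pass": True}}],
)

conversation.add_to_path(0)    # record which branch of the conversation tree was taken
conversation.update_passing()  # recompute `passing` from the critical analyzers
print(conversation.passing)    # presumably True here: the one critical analyzer passed
```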
10 changes: 10 additions & 0 deletions src/llmcoder/conversation/priority_queue.py
@@ -6,6 +6,16 @@


 class PriorityQueue:
+    """
+    A priority queue for conversations, which sorts the conversations based on their scores.
+
+    Parameters
+    ----------
+    conversations : Conversation | list[Conversation] | None, optional
+        The conversations to be added to the priority queue, by default None
+    backtracking : bool, optional
+        Whether to allow re-considering previous conversations, by default True
+    """
     def __init__(self, conversations: Conversation | list[Conversation] | None = None, backtracking: bool = True):
         self.queue: list[Conversation] = []
         self.backtracking = backtracking
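A sketch of constructing the queue; only the constructor is shown in this diff, so the enqueueing of the initial conversations is assumed:

```python
from llmcoder.conversation.conversation import Conversation
from llmcoder.conversation.priority_queue import PriorityQueue

strong = Conversation(score=5, messages=[{"role": "assistant", "content": "candidate A"}])
weak = Conversation(score=1, messages=[{"role": "assistant", "content": "candidate B"}])

# backtracking=True allows previously seen conversations to be re-considered later.
queue = PriorityQueue([strong, weak], backtracking=True)
print(len(queue.queue))  # presumably 2, ordered by conversation score
```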
45 changes: 25 additions & 20 deletions src/llmcoder/data/preprocessor.py
@@ -61,6 +61,11 @@ def split_file(file_contents: str, min_pos: int = 1, max_pos: int = None) -> tup
         The minimum position to split the file at, by default 1
     max_pos : int, optional
         The maximum position to split the file at, by default None
+
+    Returns
+    -------
+    tuple[str, str]
+        A tuple containing the first and second part of the file.
     """
     if max_pos is None:
         max_pos = len(file_contents) - 1
@@ -122,27 +127,27 @@ def sample_files_from_dir(repo_dir: str, n_samples: int = 4, file_extensions: li


 class Preprocessor:
-    def __init__(self, dataset_name: str, tokenizer: str = "p50k_base", scraped_files_dir: str | None = None, save_pairs_dir: str | None = None, save_data_dir: str | None = None, system_prompt: str | None = None, disallowed_special_tokens: list[str] | None = None) -> None:
-        """
-        A preprocessor for the fine-tuning data which samples files from scraped repositories, splits them into two parts and saves them in a format that can be used for fine-tuning.
+    """
+    A preprocessor for the fine-tuning data which samples files from scraped repositories, splits them into two parts and saves them in a format that can be used for fine-tuning.

-        Parameters
-        ----------
-        dataset_name : str
-            The name of the dataset.
-        tokenizer : str, optional
-            The tokenizer to use, by default "p50k_base" for gpt-3.5-turbo
-        scraped_files_dir : str
-            The directory to store the scraped files in, defaults to 'scraped_repos'.
-        save_pairs_dir : str
-            The directory to store the sampled files in, defaults to 'pairs'.
-        save_data_dir : str
-            The directory to store the preprocessed data in, defaults to 'github_mix'.
-        system_prompt : str
-            The system prompt to use, defaults to the default system prompt.
-        disallowed_special_tokens : list[str]
-            A list of disallowed special tokens, defaults to the default disallowed special tokens.
-        """
+    Parameters
+    ----------
+    dataset_name : str
+        The name of the dataset.
+    tokenizer : str, optional
+        The tokenizer to use, by default "p50k_base" for gpt-3.5-turbo
+    scraped_files_dir : str
+        The directory to store the scraped files in, defaults to 'scraped_repos'.
+    save_pairs_dir : str
+        The directory to store the sampled files in, defaults to 'pairs'.
+    save_data_dir : str
+        The directory to store the preprocessed data in, defaults to 'github_mix'.
+    system_prompt : str
+        The system prompt to use, defaults to the default system prompt.
+    disallowed_special_tokens : list[str]
+        A list of disallowed special tokens, defaults to the default disallowed special tokens.
+    """
+    def __init__(self, dataset_name: str, tokenizer: str = "p50k_base", scraped_files_dir: str | None = None, save_pairs_dir: str | None = None, save_data_dir: str | None = None, system_prompt: str | None = None, disallowed_special_tokens: list[str] | None = None) -> None:
         self.name = dataset_name

         self.enc = tiktoken.get_encoding(tokenizer)
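A sketch of the two documented entry points, split_file and the Preprocessor constructor. Where exactly split_file cuts between min_pos and max_pos is not shown in this diff (presumably a sampled position):

```python
from llmcoder.data.preprocessor import Preprocessor, split_file

# Split a file's contents into an (input, completion) style pair.
first_part, second_part = split_file("import os\n\nprint(os.getcwd())\n", min_pos=1)
print(repr(first_part), repr(second_part))

# Leaving the directory arguments as None falls back to the documented
# defaults ('scraped_repos', 'pairs', 'github_mix').
preprocessor = Preprocessor(
    dataset_name="github_mix",
    tokenizer="p50k_base",  # tiktoken encoding used for gpt-3.5-turbo
)
```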
39 changes: 20 additions & 19 deletions src/llmcoder/data/scraper.py
@@ -14,20 +14,17 @@
 class GitHubScraper:
     """
     A class for scraping GitHub repositories and storing them in a flat structure.
+
+    Parameters
+    ----------
+    dataset_name : str
+        The name of the dataset to scrape repositories for.
+    access_token : str
+        A GitHub access token for authenticating with the GitHub API.
+    scraped_files_dir : str
+        The directory to store the scraped files in, defaults to 'scraped_repos'.
     """
     def __init__(self, dataset_name: str, access_token: str | None = None, scraped_files_dir: str | None = None) -> None:
-        """
-        Initialize the GitHubScraper class with a GitHub access token.
-
-        Parameters
-        ----------
-        dataset_name : str
-            The name of the dataset to scrape repositories for.
-        access_token : str
-            A GitHub access token for authenticating with the GitHub API.
-        scraped_files_dir : str
-            The directory to store the scraped files in, defaults to 'scraped_repos'.
-        """
         self.name = dataset_name

         self.access_token = access_token
@@ -47,9 +44,13 @@ def get_repos_with_query(self, query: str, num_repos: int = 1) -> list:
         ----------
         query : str
             A GitHub API query.
         num_repos : int
             The number of repositories to fetch.

+        Returns
+        -------
+        list
+            A list of repositories.
         """
         if self.access_token is not None:
             headers = {'Authorization': f'token {self.access_token}'}
@@ -150,13 +151,13 @@ def accumulate_repositories(self, repository_sets: list[list[str]] | None = None
         Parameters
         ----------
-        repository_sets : list[list[str]]
-            A list of lists of repository URLs to scrape. Each list represents a set of repositories to scrape relating to a specific topic.
+        repository_sets : list[list[str]] | None, optional
+            A list of lists of repository URLs to scrape. Each list represents a set of repositories to scrape relating to a specific topic. If None, a default set of repositories will be used, by default None.

         Returns
         -------
-        list[str]
-            A list of repository URLs to scrape.
+        list[tuple[str, str]]
+            A list of tuples of (repo_url, repo_name).
         """
         if repository_sets is None:
             # Get the top 10 Python repositories by stars
@@ -230,8 +231,8 @@ def scrape_repositories(self, repos: list[tuple[str, str]] | None = None, max_n_
         Parameters
         ----------
-        repos : list[tuple[str, str]]
-            A list of tuples of (repo_url, repo_name).
+        repos : list[tuple[str, str]] | None, optional
+            A list of tuples of (repo_url, repo_name). If None, a default set of repositories will be used, by default None.
         max_n_repositories : int
             The maximum number of repositories to scrape.
         verbose : bool
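A sketch of the scraping flow across the three documented methods. The environment variable name is hypothetical; a token is optional, but unauthenticated GitHub API requests are heavily rate-limited:

```python
import os

from llmcoder.data.scraper import GitHubScraper

scraper = GitHubScraper(
    dataset_name="github_mix",
    access_token=os.environ.get("GITHUB_TOKEN"),  # hypothetical env var name
)

# Query the GitHub API directly ...
top_repos = scraper.get_repos_with_query("language:python stars:>10000", num_repos=5)

# ... or accumulate the default repository sets as (repo_url, repo_name) tuples
# and scrape them into the flat scraped_files_dir ('scraped_repos' by default).
repos = scraper.accumulate_repositories(repository_sets=None)
scraper.scrape_repositories(repos=repos, max_n_repositories=10, verbose=True)
```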
(The diffs for the remaining 5 of the 13 changed files did not load in this view.)
