Skip to content

Commit 26100a9

Browse files
rsrbkWendong-Fan
andauthored
feat: Retrive recent pull requests for GithubToolkit (#582)
Co-authored-by: Wendong-Fan <[email protected]> Co-authored-by: Wendong <[email protected]>
1 parent d0e3722 commit 26100a9

File tree

3 files changed

+210
-11
lines changed

3 files changed

+210
-11
lines changed

camel/toolkits/github_toolkit.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616
from dataclasses import dataclass
17+
from datetime import datetime, timedelta
1718
from typing import List, Optional
1819

1920
from camel.functions import OpenAIFunction
@@ -57,8 +58,8 @@ def __init__(
5758
self.file_path = file_path
5859
self.file_content = file_content
5960

60-
def summary(self) -> str:
61-
r"""Returns a summary of the issue.
61+
def __str__(self) -> str:
62+
r"""Returns a string representation of the issue.
6263
6364
Returns:
6465
str: A string containing the title, body, number, file path, and
@@ -73,6 +74,48 @@ def summary(self) -> str:
7374
)
7475

7576

77+
@dataclass
78+
class GithubPullRequestDiff:
79+
r"""Represents a single diff of a pull request on Github.
80+
81+
Attributes:
82+
filename (str): The name of the file that was changed.
83+
patch (str): The diff patch for the file.
84+
"""
85+
86+
filename: str
87+
patch: str
88+
89+
def __str__(self) -> str:
90+
r"""Returns a string representation of this diff."""
91+
return f"Filename: {self.filename}\nPatch: {self.patch}"
92+
93+
94+
@dataclass
95+
class GithubPullRequest:
96+
r"""Represents a pull request on Github.
97+
98+
Attributes:
99+
title (str): The title of the GitHub pull request.
100+
body (str): The body/content of the GitHub pull request.
101+
diffs (List[GithubPullRequestDiff]): A list of diffs for the pull
102+
request.
103+
"""
104+
105+
title: str
106+
body: str
107+
diffs: List[GithubPullRequestDiff]
108+
109+
def __str__(self) -> str:
110+
r"""Returns a string representation of the pull request."""
111+
diff_summaries = '\n'.join(str(diff) for diff in self.diffs)
112+
return (
113+
f"Title: {self.title}\n"
114+
f"Body: {self.body}\n"
115+
f"Diffs: {diff_summaries}\n"
116+
)
117+
118+
76119
class GithubToolkit(BaseToolkit):
77120
r"""A class representing a toolkit for interacting with GitHub
78121
repositories.
@@ -106,7 +149,7 @@ def __init__(
106149
except ImportError:
107150
raise ImportError(
108151
"Please install `github` first. You can install it by running "
109-
"`pip install wikipedia`."
152+
"`pip install pygithub`."
110153
)
111154
self.github = Github(auth=Auth.Token(access_token))
112155
self.repo = self.github.get_repo(repo_name)
@@ -123,6 +166,7 @@ def get_tools(self) -> List[OpenAIFunction]:
123166
OpenAIFunction(self.retrieve_issue_list),
124167
OpenAIFunction(self.retrieve_issue),
125168
OpenAIFunction(self.create_pull_request),
169+
OpenAIFunction(self.retrieve_pull_requests),
126170
]
127171

128172
def get_github_access_token(self) -> str:
@@ -181,9 +225,47 @@ def retrieve_issue(self, issue_number: int) -> Optional[str]:
181225
issues = self.retrieve_issue_list()
182226
for issue in issues:
183227
if issue.number == issue_number:
184-
return issue.summary()
228+
return str(issue)
185229
return None
186230

231+
def retrieve_pull_requests(
232+
self, days: int, state: str, max_prs: int
233+
) -> List[str]:
234+
r"""Retrieves a summary of merged pull requests from the repository.
235+
The summary will be provided for the last specified number of days.
236+
237+
Args:
238+
days (int): The number of days to retrieve merged pull requests
239+
for.
240+
state (str): A specific state of PRs to retrieve. Can be open or
241+
closed.
242+
max_prs (int): The maximum number of PRs to retrieve.
243+
244+
Returns:
245+
List[str]: A list of merged pull request summaries.
246+
"""
247+
pull_requests = self.repo.get_pulls(state=state)
248+
merged_prs = []
249+
earliest_date: datetime = datetime.utcnow() - timedelta(days=days)
250+
251+
for pr in pull_requests[:max_prs]:
252+
if (
253+
pr.merged
254+
and pr.merged_at is not None
255+
and pr.merged_at.timestamp() > earliest_date.timestamp()
256+
):
257+
pr_details = GithubPullRequest(pr.title, pr.body, [])
258+
259+
# Get files changed in the PR
260+
files = pr.get_files()
261+
262+
for file in files:
263+
diff = GithubPullRequestDiff(file.filename, file.patch)
264+
pr_details.diffs.append(diff)
265+
266+
merged_prs.append(str(pr_details))
267+
return merged_prs
268+
187269
def create_pull_request(
188270
self,
189271
file_path: str,

examples/function_call/github_examples.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,74 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
import argparse
15+
1416
from colorama import Fore
1517

1618
from camel.agents import ChatAgent
1719
from camel.configs import ChatGPTConfig
20+
from camel.functions import OpenAIFunction
1821
from camel.messages import BaseMessage
1922
from camel.toolkits import GithubToolkit
2023
from camel.utils import print_text_animated
2124

2225

26+
def write_weekly_pr_summary(repo_name, model=None):
27+
prompt = """
28+
You need to write a summary of the pull requests that were merged in the
29+
last week.
30+
You can use the provided github function retrieve_pull_requests to
31+
retrieve the list of pull requests that were merged in the last week.
32+
The maximum amount of PRs to analyze is 3.
33+
You have to pass the number of days as the first parameter to
34+
retrieve_pull_requests and state='closed' as the second parameter.
35+
The function will return a list of pull requests with the following
36+
properties: title, body, and diffs.
37+
Diffs is a list of dictionaries with the following properties: filename,
38+
diff.
39+
You will have to look closely at each diff to understand the changes that
40+
were made in each pull request.
41+
Output a twitter post that describes recent changes in the project and
42+
thanks the contributors.
43+
44+
Here is an example of a summary for one pull request:
45+
📢 We've improved function calling in the 🐪 CAMEL-AI framework!
46+
This update enhances the handling of various docstring styles and supports
47+
enum types, ensuring more accurate and reliable function calls.
48+
Thanks to our contributor Jiahui Zhang for making this possible.
49+
"""
50+
print(Fore.YELLOW + f"Final prompt:\n{prompt}\n")
51+
52+
toolkit = GithubToolkit(repo_name=repo_name)
53+
assistant_sys_msg = BaseMessage.make_assistant_message(
54+
role_name="Marketing Manager",
55+
content=f"""
56+
You are an experienced marketing manager responsible for posting
57+
weekly updates about the status
58+
of an open source project {repo_name} on the project's blog.
59+
""",
60+
)
61+
assistant_model_config = ChatGPTConfig(
62+
tools=[OpenAIFunction(toolkit.retrieve_pull_requests)],
63+
temperature=0.0,
64+
)
65+
agent = ChatAgent(
66+
assistant_sys_msg,
67+
model_type=model,
68+
model_config=assistant_model_config,
69+
tools=[OpenAIFunction(toolkit.retrieve_pull_requests)],
70+
)
71+
agent.reset()
72+
73+
user_msg = BaseMessage.make_user_message(role_name="User", content=prompt)
74+
assistant_response = agent.step(user_msg)
75+
76+
if len(assistant_response.msgs) > 0:
77+
print_text_animated(
78+
Fore.GREEN + f"Agent response:\n{assistant_response.msg.content}\n"
79+
)
80+
81+
2382
def solve_issue(
2483
repo_name,
2584
issue_number,
@@ -70,8 +129,12 @@ def solve_issue(
70129

71130

72131
def main(model=None) -> None:
73-
repo_name = "camel-ai/test-github-agent"
74-
solve_issue(repo_name=repo_name, issue_number=1, model=model)
132+
parser = argparse.ArgumentParser(description='Enter repo name.')
133+
parser.add_argument('repo_name', type=str, help='Name of the repository')
134+
args = parser.parse_args()
135+
136+
repo_name = args.repo_name
137+
write_weekly_pr_summary(repo_name=repo_name, model=model)
75138

76139

77140
if __name__ == "__main__":

test/toolkits/test_github_toolkit.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
1414

15+
from datetime import datetime
1516
from unittest.mock import MagicMock, patch
1617

1718
from github import Auth, Github
1819
from github.ContentFile import ContentFile
1920

20-
from camel.toolkits.github_toolkit import GithubIssue, GithubToolkit
21+
from camel.toolkits.github_toolkit import (
22+
GithubIssue,
23+
GithubPullRequest,
24+
GithubPullRequestDiff,
25+
GithubToolkit,
26+
)
2127

2228

2329
@patch.object(Github, '__init__', lambda self, *args, **kwargs: None)
@@ -120,9 +126,9 @@ def test_retrieve_issue(monkeypatch):
120126
file_path="path/to/file",
121127
file_content="This is the content of the file",
122128
)
123-
assert (
124-
issue == expected_issue.summary()
125-
), f"Expected {expected_issue.summary()}, but got {issue}"
129+
assert issue == str(
130+
expected_issue
131+
), f"Expected {expected_issue}, but got {issue}"
126132

127133

128134
@patch.object(Github, 'get_repo', return_value=MagicMock())
@@ -165,6 +171,54 @@ def test_create_pull_request(monkeypatch):
165171
), f"Expected {expected_response}, but got {pr}"
166172

167173

174+
@patch.object(Github, 'get_repo', return_value=MagicMock())
175+
@patch.object(Auth.Token, '__init__', lambda self, *args, **kwargs: None)
176+
def test_retrieve_pull_requests(monkeypatch):
177+
# Call the constructor of the GithubToolkit class
178+
github_toolkit = GithubToolkit("repo_name", "token")
179+
180+
# Create a mock file
181+
mock_file = MagicMock()
182+
mock_file.filename = "path/to/file"
183+
mock_file.diff = "This is the diff of the file"
184+
185+
# Create a mock pull request
186+
mock_pull_request = MagicMock()
187+
mock_pull_request.title = "Test PR"
188+
mock_pull_request.body = "This is a test issue"
189+
mock_pull_request.merged_at = datetime.utcnow()
190+
191+
# Create a mock file
192+
mock_file = MagicMock()
193+
mock_file.filename = "path/to/file"
194+
mock_file.patch = "This is the diff of the file"
195+
196+
# Mock the get_files method of the mock_pull_request instance to return a
197+
# list containing the mock file object
198+
mock_pull_request.get_files.return_value = [mock_file]
199+
200+
# Mock the get_issues method of the mock repo instance to return a list
201+
# containing the mock issue object
202+
github_toolkit.repo.get_pulls.return_value = [mock_pull_request]
203+
204+
pull_requests = github_toolkit.retrieve_pull_requests(
205+
days=7, state='closed', max_prs=3
206+
)
207+
# Assert the returned issue list
208+
expected_pull_request = GithubPullRequest(
209+
title="Test PR",
210+
body="This is a test issue",
211+
diffs=[
212+
GithubPullRequestDiff(
213+
filename="path/to/file", patch="This is the diff of the file"
214+
)
215+
],
216+
)
217+
assert pull_requests == [
218+
str(expected_pull_request)
219+
], f"Expected {expected_pull_request}, but got {pull_requests}"
220+
221+
168222
def test_github_issue():
169223
# Create a GithubIssue object
170224
issue = GithubIssue(
@@ -183,7 +237,7 @@ def test_github_issue():
183237
assert issue.file_content == "This is the content of the file"
184238

185239
# Test the summary method
186-
summary = issue.summary()
240+
summary = str(issue)
187241
expected_summary = (
188242
f"Title: {issue.title}\n"
189243
f"Body: {issue.body}\n"

0 commit comments

Comments
 (0)