Skip to content

Commit dedb187

Browse files
authored
Merge pull request #4 from ravilushqa/enhance_description
added complete once if possible, added jira integration
2 parents 76855cf + bab9ada commit dedb187

File tree

4 files changed

+167
-26
lines changed

4 files changed

+167
-26
lines changed

cmd/description/main.go

+82-21
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/sashabaranov/go-openai"
1313

1414
ghClient "github.com/ravilushqa/gpt-pullrequest-updater/github"
15+
"github.com/ravilushqa/gpt-pullrequest-updater/jira"
1516
oAIClient "github.com/ravilushqa/gpt-pullrequest-updater/openai"
1617
)
1718

@@ -22,6 +23,7 @@ var opts struct {
2223
Repo string `long:"repo" env:"REPO" description:"GitHub repo" required:"true"`
2324
PRNumber int `long:"pr-number" env:"PR_NUMBER" description:"Pull request number" required:"true"`
2425
Test bool `long:"test" env:"TEST" description:"Test mode"`
26+
JiraURL string `long:"jira-url" env:"JIRA_URL" description:"Jira URL"`
2527
}
2628

2729
func main() {
@@ -35,6 +37,10 @@ func main() {
3537
os.Exit(0)
3638
}
3739

40+
if opts.Test {
41+
fmt.Println("Test mode")
42+
}
43+
3844
if err := run(ctx); err != nil {
3945
panic(err)
4046
}
@@ -54,55 +60,110 @@ func run(ctx context.Context) error {
5460
return fmt.Errorf("error getting commits: %w", err)
5561
}
5662

57-
var OverallDescribeCompletion string
58-
OverallDescribeCompletion += fmt.Sprintf("Pull request title: %s, body: %s\n\n", pr.GetTitle(), pr.GetBody())
63+
var sumDiffs int
5964
for _, file := range diff.Files {
65+
sumDiffs += len(*file.Patch)
66+
}
6067

68+
var completion string
69+
if sumDiffs < 4000 {
70+
completion, err = genCompletionOnce(ctx, openAIClient, diff)
71+
if err != nil {
72+
return fmt.Errorf("error generating completition once: %w", err)
73+
}
74+
} else {
75+
completion, err = genCompletionPerFile(ctx, openAIClient, diff, pr)
76+
if err != nil {
77+
return fmt.Errorf("error generating completition twice: %w", err)
78+
}
79+
}
80+
81+
if opts.JiraURL != "" {
82+
fmt.Println("Adding Jira ticket")
83+
id, err := jira.ExtractJiraTicketID(*pr.Title)
84+
if err != nil {
85+
fmt.Printf("Error extracting Jira ticket ID: %v \n", err)
86+
} else {
87+
completion = fmt.Sprintf("### JIRA ticket: [%s](%s) \n\n%s", id, jira.GenerateJiraTicketURL(opts.JiraURL, id), completion)
88+
}
89+
}
90+
91+
if opts.Test {
92+
fmt.Println(completion)
93+
return nil
94+
}
95+
96+
// Update the pull request description
97+
fmt.Println("Updating pull request")
98+
updatePr := &github.PullRequest{Body: github.String(completion)}
99+
if _, err = githubClient.UpdatePullRequest(ctx, opts.Owner, opts.Repo, opts.PRNumber, updatePr); err != nil {
100+
return fmt.Errorf("error updating pull request: %w", err)
101+
}
102+
103+
return nil
104+
}
105+
106+
func genCompletionOnce(ctx context.Context, client *oAIClient.Client, diff *github.CommitsComparison) (string, error) {
107+
fmt.Println("Generating completion once")
108+
messages := make([]openai.ChatCompletionMessage, 0, len(diff.Files))
109+
messages = append(messages, openai.ChatCompletionMessage{
110+
Role: openai.ChatMessageRoleUser,
111+
Content: oAIClient.PromptDescribeChanges,
112+
})
113+
for _, file := range diff.Files {
61114
if file.Patch == nil {
62115
continue
63116
}
64117

118+
messages = append(messages, openai.ChatCompletionMessage{
119+
Role: openai.ChatMessageRoleUser,
120+
Content: *file.Patch,
121+
})
122+
}
123+
124+
fmt.Println("Sending prompt to OpenAI")
125+
completion, err := client.ChatCompletion(ctx, messages)
126+
if err != nil {
127+
return "", fmt.Errorf("error completing prompt: %w", err)
128+
}
129+
130+
return completion, nil
131+
}
132+
133+
func genCompletionPerFile(ctx context.Context, client *oAIClient.Client, diff *github.CommitsComparison, pr *github.PullRequest) (string, error) {
134+
fmt.Println("Generating completion per file")
135+
OverallDescribeCompletion := fmt.Sprintf("Pull request title: %s, body: %s\n\n", pr.GetTitle(), pr.GetBody())
136+
137+
for i, file := range diff.Files {
65138
prompt := fmt.Sprintf(oAIClient.PromptDescribeChanges, *file.Patch)
66139

67140
if len(prompt) > 4096 {
68141
prompt = fmt.Sprintf("%s...", prompt[:4093])
69142
}
70143

71-
completion, err := openAIClient.ChatCompletion(ctx, []openai.ChatCompletionMessage{
144+
fmt.Printf("Sending prompt to OpenAI for file %d/%d\n", i+1, len(diff.Files))
145+
completion, err := client.ChatCompletion(ctx, []openai.ChatCompletionMessage{
72146
{
73147
Role: openai.ChatMessageRoleUser,
74148
Content: prompt,
75149
},
76150
})
77151
if err != nil {
78-
return fmt.Errorf("error getting review: %w", err)
152+
return "", fmt.Errorf("error getting review: %w", err)
79153
}
80154
OverallDescribeCompletion += fmt.Sprintf("File: %s \nDescription: %s \n\n", file.GetFilename(), completion)
81155
}
82156

83-
overallCompletion, err := openAIClient.ChatCompletion(ctx, []openai.ChatCompletionMessage{
157+
fmt.Println("Sending final prompt to OpenAI")
158+
overallCompletion, err := client.ChatCompletion(ctx, []openai.ChatCompletionMessage{
84159
{
85160
Role: openai.ChatMessageRoleUser,
86161
Content: fmt.Sprintf(oAIClient.PromptOverallDescribe, OverallDescribeCompletion),
87162
},
88163
})
89164
if err != nil {
90-
return fmt.Errorf("error getting overall review: %w", err)
91-
}
92-
93-
if opts.Test {
94-
fmt.Println(OverallDescribeCompletion)
95-
fmt.Println("=====================================")
96-
fmt.Println(overallCompletion)
97-
98-
return nil
165+
return "", fmt.Errorf("error getting overall review: %w", err)
99166
}
100167

101-
// Update the pull request description
102-
updatePr := &github.PullRequest{Body: github.String(overallCompletion)}
103-
if _, err = githubClient.UpdatePullRequest(ctx, opts.Owner, opts.Repo, opts.PRNumber, updatePr); err != nil {
104-
return fmt.Errorf("error updating pull request: %w", err)
105-
}
106-
107-
return nil
168+
return overallCompletion, nil
108169
}

cmd/review/main.go

+1-5
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,7 @@ func run(ctx context.Context) error {
5757

5858
var OverallReviewCompletion string
5959
for _, file := range diff.Files {
60-
if file.GetStatus() == "removed" || file.GetStatus() == "renamed" {
61-
continue
62-
}
63-
64-
if file.Patch == nil {
60+
if file.Patch == nil || file.GetStatus() == "removed" || file.GetStatus() == "renamed" {
6561
continue
6662
}
6763

jira/jira.go

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package jira
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
)
7+
8+
const ticketUrlFormat = "%s/browse/%s"
9+
10+
// ExtractJiraTicketID returns the first JIRA ticket ID found in the input string.
11+
func ExtractJiraTicketID(s string) (string, error) {
12+
// This regular expression pattern matches a JIRA ticket ID (e.g. PROJ-123).
13+
pattern := `([aA-zZ]+-\d+)`
14+
re, err := regexp.Compile(pattern)
15+
if err != nil {
16+
return "", fmt.Errorf("error compiling regex: %w", err)
17+
}
18+
19+
matches := re.FindStringSubmatch(s)
20+
if len(matches) == 0 {
21+
return "", fmt.Errorf("no JIRA ticket ID found in the input string")
22+
}
23+
24+
return matches[0], nil
25+
}
26+
27+
func GenerateJiraTicketURL(jiraURL, ticketID string) string {
28+
return fmt.Sprintf(ticketUrlFormat, jiraURL, ticketID)
29+
}

jira/jira_test.go

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package jira
2+
3+
import "testing"
4+
5+
func TestExtractJiraTicketID(t *testing.T) {
6+
testCases := []struct {
7+
name string
8+
input string
9+
expected string
10+
expectError bool
11+
}{
12+
{
13+
name: "Valid ticket ID",
14+
input: "This is a sample text with a JIRA ticket ID: PROJ-123, let's extract it.",
15+
expected: "PROJ-123",
16+
expectError: false,
17+
},
18+
{
19+
name: "No ticket ID",
20+
input: "This is a sample text without a JIRA ticket ID.",
21+
expectError: true,
22+
},
23+
{
24+
name: "Multiple ticket IDs",
25+
input: "This text has multiple JIRA ticket IDs: PROJ-123, TASK-456, and BUG-789.",
26+
expected: "PROJ-123",
27+
expectError: false,
28+
},
29+
{
30+
name: "Valid ticket ID. Lowercase.",
31+
input: "This is an invalid JIRA ticket ID: Proj-123.",
32+
expected: "Proj-123",
33+
expectError: false,
34+
},
35+
}
36+
37+
for _, tc := range testCases {
38+
t.Run(tc.name, func(t *testing.T) {
39+
result, err := ExtractJiraTicketID(tc.input)
40+
if tc.expectError {
41+
if err == nil {
42+
t.Errorf("expected an error, but got none")
43+
}
44+
} else {
45+
if err != nil {
46+
t.Errorf("unexpected error: %v", err)
47+
}
48+
49+
if result != tc.expected {
50+
t.Errorf("expected result '%s', but got '%s'", tc.expected, result)
51+
}
52+
}
53+
})
54+
}
55+
}

0 commit comments

Comments
 (0)