From e0cde718c1f9e4f303fde05b8992382824a3e1c9 Mon Sep 17 00:00:00 2001 From: Martina Jireckova Date: Wed, 16 Apr 2025 08:21:45 +0000 Subject: [PATCH 1/4] WIP: List notifications tool --- pkg/github/notifications.go | 92 ++++++++++++++++++++++ pkg/github/notifications_test.go | 127 +++++++++++++++++++++++++++++++ pkg/github/server.go | 14 ++++ pkg/github/tools.go | 6 ++ 4 files changed, 239 insertions(+) create mode 100644 pkg/github/notifications.go create mode 100644 pkg/github/notifications_test.go diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go new file mode 100644 index 00000000..eb664b9b --- /dev/null +++ b/pkg/github/notifications.go @@ -0,0 +1,92 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ListNotifications creates a tool to list notifications for a GitHub user. +func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_notifications", + mcp.WithDescription(t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "List notifications for a GitHub user")), + mcp.WithNumber("page", + mcp.Description("Page number"), + ), + mcp.WithNumber("per_page", + mcp.Description("Number of records per page"), + ), + mcp.WithBoolean("all", + mcp.Description("Whether to fetch all notifications, including read ones"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + page, err := OptionalIntParamWithDefault(request, "page", 1) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + perPage, err := OptionalIntParamWithDefault(request, "per_page", 30) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + all, err := OptionalBoolParamWithDefault(request, "all", false) // Default to false unless specified + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if request.Params.Arguments["all"] == true { + all = true // Set to true if user explicitly asks for all notifications + } + + opts := &github.NotificationListOptions{ + ListOptions: github.ListOptions{ + Page: page, + PerPage: perPage, + }, + All: all, // Include all notifications, even those already read. + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + notifications, resp, err := client.Activity.ListNotifications(ctx, opts) + if err != nil { + return nil, fmt.Errorf("failed to list notifications: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list notifications: %s", string(body))), nil + } + + // Extract the notification title in addition to reason, url, and timestamp. + var extractedNotifications []map[string]interface{} + for _, notification := range notifications { + extractedNotifications = append(extractedNotifications, map[string]interface{}{ + "title": notification.GetSubject().GetTitle(), + "reason": notification.GetReason(), + "url": notification.GetURL(), + "timestamp": notification.GetUpdatedAt(), + }) + } + + r, err := json.Marshal(extractedNotifications) + if err != nil { + return nil, fmt.Errorf("failed to marshal notifications: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go new file mode 100644 index 00000000..2d663c7a --- /dev/null +++ b/pkg/github/notifications_test.go @@ -0,0 +1,127 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListNotifications(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := ListNotifications(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_notifications", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.Contains(t, tool.InputSchema.Properties, "per_page") + assert.Contains(t, tool.InputSchema.Properties, "all") + + // Setup mock notifications + mockNotifications := []*github.Notification{ + { + ID: github.String("1"), + Reason: github.String("mention"), + Subject: &github.NotificationSubject{ + Title: github.String("Test Notification 1"), + }, + UpdatedAt: &github.Timestamp{Time: time.Now()}, + URL: github.String("https://example.com/notifications/threads/1"), + }, + { + ID: github.String("2"), + Reason: github.String("team_mention"), + Subject: &github.NotificationSubject{ + Title: github.String("Test Notification 2"), + }, + UpdatedAt: &github.Timestamp{Time: time.Now()}, + URL: github.String("https://example.com/notifications/threads/1"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedResponse []*github.Notification + expectedErrMsg string + }{ + { + name: "list all notifications", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotifications, + mockNotifications, + ), + ), + requestArgs: map[string]interface{}{ + "all": true, + }, + expectError: false, + expectedResponse: mockNotifications, + }, + { + name: "list unread notifications", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetNotifications, + mockNotifications[:1], // Only the first notification + ), + ), + requestArgs: map[string]interface{}{ + "all": false, + }, + expectError: false, + expectedResponse: mockNotifications[:1], + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListNotifications(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedNotifications []*github.Notification + err = json.Unmarshal([]byte(textContent.Text), &returnedNotifications) + require.NoError(t, err) + assert.Equal(t, len(tc.expectedResponse), len(returnedNotifications)) + for i, notification := range returnedNotifications { + // Ensure all required fields are mocked + assert.NotNil(t, notification.Subject, "Subject should not be nil") + assert.NotNil(t, notification.Subject.Title, "Title should not be nil") + assert.NotNil(t, notification.Reason, "Reason should not be nil") + assert.NotNil(t, notification.URL, "URL should not be nil") + assert.NotNil(t, notification.UpdatedAt, "UpdatedAt should not be nil") + // assert.Equal(t, *tc.expectedResponse[i].ID, *notification.ID) + assert.Equal(t, *tc.expectedResponse[i].Reason, *notification.Reason) + // assert.Equal(t, *tc.expectedResponse[i].Subject.Title, *notification.Subject.Title) + } + }) + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index e4c24171..b2413f54 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -130,6 +130,20 @@ func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } +// OptionalBoolParamWithDefault is a helper function that retrieves a boolean parameter from the request. +// If the parameter is not present, it returns the provided default value. If the parameter is present, +// it validates its type and returns the value. +func OptionalBoolParamWithDefault(request mcp.CallToolRequest, s string, b bool) (bool, error) { + v, err := OptionalParam[bool](request, s) + if err != nil { + return false, err + } + if b == false { + return b, nil + } + return v, nil +} + // OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request // similar to optionalIntParam, but it also takes a default value. func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { diff --git a/pkg/github/tools.go b/pkg/github/tools.go index ce10c4ad..4ac5d3bc 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -76,6 +76,11 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") + notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools"). + AddReadTools( + toolsets.NewServerTool(ListNotifications(getClient, t)), + ) + // Add toolsets to the group tsg.AddToolset(repos) tsg.AddToolset(issues) @@ -83,6 +88,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(pullRequests) tsg.AddToolset(codeSecurity) tsg.AddToolset(experiments) + tsg.AddToolset(notifications) // Enable the requested features if err := tsg.EnableToolsets(passedToolsets); err != nil { From 2e52386ba5c8d9234de80c6caf6610b3d41bf8ef Mon Sep 17 00:00:00 2001 From: Martina Jireckova Date: Tue, 22 Apr 2025 13:01:04 +0000 Subject: [PATCH 2/4] Improve testing and mapping code in request and response --- README.md | 8 ++++++++ pkg/github/notifications.go | 17 +---------------- pkg/github/notifications_test.go | 26 ++++++++++---------------- 3 files changed, 19 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 6bfc6ab5..4d04f6fc 100644 --- a/README.md +++ b/README.md @@ -437,6 +437,14 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `state`: Alert state (string, optional) - `severity`: Alert severity (string, optional) +### Notifications + +- **list_notifications** - List notifications for a GitHub user + + - `page`: Page number (number, optional, default: 1) + - `per_page`: Number of records per page (number, optional, default: 30) + - `all`: Whether to fetch all notifications, including read ones (boolean, optional, default: false) + ## Resources ### Repository Content diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index eb664b9b..10a98be2 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -41,10 +41,6 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(err.Error()), nil } - if request.Params.Arguments["all"] == true { - all = true // Set to true if user explicitly asks for all notifications - } - opts := &github.NotificationListOptions{ ListOptions: github.ListOptions{ Page: page, @@ -71,18 +67,7 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(fmt.Sprintf("failed to list notifications: %s", string(body))), nil } - // Extract the notification title in addition to reason, url, and timestamp. - var extractedNotifications []map[string]interface{} - for _, notification := range notifications { - extractedNotifications = append(extractedNotifications, map[string]interface{}{ - "title": notification.GetSubject().GetTitle(), - "reason": notification.GetReason(), - "url": notification.GetURL(), - "timestamp": notification.GetUpdatedAt(), - }) - } - - r, err := json.Marshal(extractedNotifications) + r, err := json.Marshal(notifications) if err != nil { return nil, fmt.Errorf("failed to marshal notifications: %w", err) } diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index 2d663c7a..20c7967b 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -28,22 +28,22 @@ func Test_ListNotifications(t *testing.T) { // Setup mock notifications mockNotifications := []*github.Notification{ { - ID: github.String("1"), - Reason: github.String("mention"), + ID: github.Ptr("1"), + Reason: github.Ptr("mention"), Subject: &github.NotificationSubject{ - Title: github.String("Test Notification 1"), + Title: github.Ptr("Test Notification 1"), }, UpdatedAt: &github.Timestamp{Time: time.Now()}, - URL: github.String("https://example.com/notifications/threads/1"), + URL: github.Ptr("https://example.com/notifications/threads/1"), }, { - ID: github.String("2"), - Reason: github.String("team_mention"), + ID: github.Ptr("2"), + Reason: github.Ptr("team_mention"), Subject: &github.NotificationSubject{ - Title: github.String("Test Notification 2"), + Title: github.Ptr("Test Notification 2"), }, UpdatedAt: &github.Timestamp{Time: time.Now()}, - URL: github.String("https://example.com/notifications/threads/1"), + URL: github.Ptr("https://example.com/notifications/threads/1"), }, } @@ -112,15 +112,9 @@ func Test_ListNotifications(t *testing.T) { require.NoError(t, err) assert.Equal(t, len(tc.expectedResponse), len(returnedNotifications)) for i, notification := range returnedNotifications { - // Ensure all required fields are mocked - assert.NotNil(t, notification.Subject, "Subject should not be nil") - assert.NotNil(t, notification.Subject.Title, "Title should not be nil") - assert.NotNil(t, notification.Reason, "Reason should not be nil") - assert.NotNil(t, notification.URL, "URL should not be nil") - assert.NotNil(t, notification.UpdatedAt, "UpdatedAt should not be nil") - // assert.Equal(t, *tc.expectedResponse[i].ID, *notification.ID) + assert.Equal(t, *tc.expectedResponse[i].ID, *notification.ID) assert.Equal(t, *tc.expectedResponse[i].Reason, *notification.Reason) - // assert.Equal(t, *tc.expectedResponse[i].Subject.Title, *notification.Subject.Title) + assert.Equal(t, *tc.expectedResponse[i].Subject.Title, *notification.Subject.Title) } }) } From ea562ae82a21128777ee70e3d8e06f76d7836505 Mon Sep 17 00:00:00 2001 From: Martina Jireckova Date: Tue, 22 Apr 2025 13:31:28 +0000 Subject: [PATCH 3/4] Fix lint --- pkg/github/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/server.go b/pkg/github/server.go index b2413f54..91328499 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -138,7 +138,7 @@ func OptionalBoolParamWithDefault(request mcp.CallToolRequest, s string, b bool) if err != nil { return false, err } - if b == false { + if !b { return b, nil } return v, nil From 81f093e7be6c95f38530f7cf36188f425b4a33ca Mon Sep 17 00:00:00 2001 From: Martina Jireckova Date: Tue, 22 Apr 2025 15:36:43 +0000 Subject: [PATCH 4/4] Inline optional bool function for all notifications --- pkg/github/notifications.go | 6 +++--- pkg/github/server.go | 14 -------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 10a98be2..512452d2 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -36,9 +36,9 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu if err != nil { return mcp.NewToolResultError(err.Error()), nil } - all, err := OptionalBoolParamWithDefault(request, "all", false) // Default to false unless specified - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + all := false + if val, err := OptionalParam[bool](request, "all"); err == nil { + all = val } opts := &github.NotificationListOptions{ diff --git a/pkg/github/server.go b/pkg/github/server.go index 91328499..e4c24171 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -130,20 +130,6 @@ func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } -// OptionalBoolParamWithDefault is a helper function that retrieves a boolean parameter from the request. -// If the parameter is not present, it returns the provided default value. If the parameter is present, -// it validates its type and returns the value. -func OptionalBoolParamWithDefault(request mcp.CallToolRequest, s string, b bool) (bool, error) { - v, err := OptionalParam[bool](request, s) - if err != nil { - return false, err - } - if !b { - return b, nil - } - return v, nil -} - // OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request // similar to optionalIntParam, but it also takes a default value. func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {