From 36e6add0ddb84063f419e45ac349fb49ff606087 Mon Sep 17 00:00:00 2001 From: Adam Holt Date: Thu, 12 Jun 2025 11:29:08 +0200 Subject: [PATCH 1/3] Return concrete error types for API errors --- pkg/errors/error.go | 41 +++++++++++ pkg/github/code_scanning.go | 13 +++- pkg/github/context_tools.go | 9 ++- pkg/github/issues.go | 58 +++++++++++++--- pkg/github/notifications.go | 37 ++++++++-- pkg/github/pullrequests.go | 124 +++++++++++++++++++++++++++------ pkg/github/repositories.go | 127 ++++++++++++++++++++++++++++------ pkg/github/search.go | 19 ++++- pkg/github/secret_scanning.go | 13 +++- 9 files changed, 373 insertions(+), 68 deletions(-) create mode 100644 pkg/errors/error.go diff --git a/pkg/errors/error.go b/pkg/errors/error.go new file mode 100644 index 00000000..d5d93a28 --- /dev/null +++ b/pkg/errors/error.go @@ -0,0 +1,41 @@ +package errors + +import ( + "fmt" + + "github.com/google/go-github/v72/github" +) + +type GitHubAPIError struct { + Message string `json:"message"` + Response *github.Response `json:"-"` + Err error `json:"-"` +} + +func NewGitHubAPIError(message string, resp *github.Response, err error) *GitHubAPIError { + return &GitHubAPIError{ + Message: message, + Response: resp, + Err: err, + } +} + +func (e *GitHubAPIError) Error() string { + return fmt.Errorf("%s: %w", e.Message, e.Err).Error() +} + +type GitHubGraphQLError struct { + Message string `json:"message"` + Err error `json:"-"` +} + +func NewGitHubGraphQLError(message string, err error) *GitHubGraphQLError { + return &GitHubGraphQLError{ + Message: message, + Err: err, + } +} + +func (e *GitHubGraphQLError) Error() string { + return fmt.Errorf("%s: %w", e.Message, e.Err).Error() +} diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 98714b6c..e2110d3d 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -7,6 +7,7 @@ import ( "io" "net/http" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -54,7 +55,11 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { - return nil, fmt.Errorf("failed to get alert: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get alert", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -138,7 +143,11 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName}) if err != nil { - return nil, fmt.Errorf("failed to list alerts: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to list alerts", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 62a953de..b35a5c14 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -3,6 +3,7 @@ package github import ( "context" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -28,9 +29,13 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too return mcp.NewToolResultErrorFromErr("failed to get GitHub client", err), nil } - user, _, err := client.Users.Get(ctx, "") + user, res, err := client.Users.Get(ctx, "") if err != nil { - return mcp.NewToolResultErrorFromErr("failed to get user", err), nil + return nil, ghErrors.NewGitHubAPIError( + "failed to get user", + res, + err, + ) } return MarshalledTextResult(user), nil diff --git a/pkg/github/issues.go b/pkg/github/issues.go index b4c64c8d..d513fa45 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -9,6 +9,7 @@ import ( "strings" "time" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/go-viper/mapstructure/v2" "github.com/google/go-github/v72/github" @@ -58,7 +59,11 @@ func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool } issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { - return nil, fmt.Errorf("failed to get issue: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to get issue with number '%d'", issueNumber), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -132,7 +137,11 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc } createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) if err != nil { - return nil, fmt.Errorf("failed to create comment: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to create comment on issue '%d'", issueNumber), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -220,7 +229,11 @@ func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) ( } result, resp, err := client.Search.Issues(ctx, query, opts) if err != nil { - return nil, fmt.Errorf("failed to search issues: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to search issues", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -342,7 +355,11 @@ func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t } issue, resp, err := client.Issues.Create(ctx, owner, repo, issueRequest) if err != nil { - return nil, fmt.Errorf("failed to create issue: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create issue", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -464,7 +481,11 @@ func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (to } issues, resp, err := client.Issues.ListByRepo(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list issues: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to list issues", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -610,7 +631,11 @@ func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t } updatedIssue, resp, err := client.Issues.Edit(ctx, owner, repo, issueNumber, issueRequest) if err != nil { - return nil, fmt.Errorf("failed to update issue: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to update issue", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -693,7 +718,11 @@ func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFun } comments, resp, err := client.Issues.ListComments(ctx, owner, repo, issueNumber, opts) if err != nil { - return nil, fmt.Errorf("failed to get issue comments: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get issue comments", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -824,7 +853,10 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio var query suggestedActorsQuery err := client.Query(ctx, &query, variables) if err != nil { - return nil, err + return nil, ghErrors.NewGitHubGraphQLError( + "failed to list suggested actors", + err, + ) } // Iterate all the returned nodes looking for the copilot bot, which is supposed to have the @@ -870,7 +902,10 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio } if err := client.Query(ctx, &getIssueQuery, variables); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to get issue ID: %v", err)), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get issue ID", + err, + ) } // Finally, do the assignment. Just for reference, assigning copilot to an issue that it is already @@ -896,7 +931,10 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio }, nil, ); err != nil { - return nil, fmt.Errorf("failed to replace actors for assignable: %w", err) + return nil, ghErrors.NewGitHubGraphQLError( + "failed to replace actors for assignable", + err, + ) } return mcp.NewToolResultText("successfully assigned copilot to issue"), nil diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 677ee99f..9b81878f 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -118,7 +119,11 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu notifications, resp, err = client.Activity.ListNotifications(ctx, opts) } if err != nil { - return nil, fmt.Errorf("failed to get notifications: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to list notifications", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -187,7 +192,11 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper } if err != nil { - return nil, fmt.Errorf("failed to mark notification as %s: %w", state, err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to mark notification as %s", state), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -262,7 +271,11 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) } if err != nil { - return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to mark all notifications as read", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -304,7 +317,11 @@ func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHel thread, resp, err := client.Activity.GetThread(ctx, notificationID) if err != nil { - return nil, fmt.Errorf("failed to get notification details: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to get notification details for ID '%s'", notificationID), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -385,7 +402,11 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl } if apiErr != nil { - return nil, fmt.Errorf("failed to %s notification subscription: %w", action, apiErr) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to %s notification subscription", action), + resp, + apiErr, + ) } defer func() { _ = resp.Body.Close() }() @@ -474,7 +495,11 @@ func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translati } if apiErr != nil { - return nil, fmt.Errorf("failed to %s repository subscription: %w", action, apiErr) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to %s repository subscription", action), + resp, + apiErr, + ) } if resp != nil { defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index b16920aa..89a3c1bf 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -13,6 +13,7 @@ import ( "github.com/mark3labs/mcp-go/server" "github.com/shurcooL/githubv4" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" ) @@ -57,7 +58,11 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get pull request", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -172,7 +177,11 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) if err != nil { - return nil, fmt.Errorf("failed to create pull request: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create pull request", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -293,7 +302,11 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) if err != nil { - return nil, fmt.Errorf("failed to update pull request: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to update pull request", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -402,7 +415,11 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun } prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list pull requests: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to list pull requests", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -491,7 +508,11 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun } result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) if err != nil { - return nil, fmt.Errorf("failed to merge pull request: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to merge pull request", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -554,7 +575,11 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper opts := &github.ListOptions{} files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) if err != nil { - return nil, fmt.Errorf("failed to get pull request files: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get pull request files", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -616,7 +641,11 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get pull request", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -631,7 +660,11 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe // Get combined status for the head SHA status, resp, err := client.Repositories.GetCombinedStatus(ctx, owner, repo, *pr.Head.SHA, nil) if err != nil { - return nil, fmt.Errorf("failed to get combined status: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get combined status", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -709,7 +742,11 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { return mcp.NewToolResultText("Pull request branch update is in progress"), nil } - return nil, fmt.Errorf("failed to update pull request branch: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to update pull request branch", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -777,7 +814,11 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel } comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) if err != nil { - return nil, fmt.Errorf("failed to get pull request comments: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get pull request comments", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -839,7 +880,11 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { - return nil, fmt.Errorf("failed to get pull request reviews: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get pull request reviews", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -926,7 +971,10 @@ func CreateAndSubmitPullRequestReview(getGQLClient GetGQLClientFn, t translation "repo": githubv4.String(params.Repo), "prNum": githubv4.Int(params.PullNumber), }); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get pull request", + err, + ) } // Now we have the GQL ID, we can create a review @@ -1017,7 +1065,10 @@ func CreatePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. "repo": githubv4.String(params.Repo), "prNum": githubv4.Int(params.PullNumber), }); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get pull request", + err, + ) } // Now we have the GQL ID, we can create a pending review @@ -1135,7 +1186,10 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get current user", + err, + ) } var getLatestReviewForViewerQuery struct { @@ -1160,7 +1214,10 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get latest review for current user", + err, + ) } // Validate there is one review and the state is pending @@ -1266,7 +1323,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get current user", + err, + ) } var getLatestReviewForViewerQuery struct { @@ -1291,7 +1351,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get latest review for current user", + err, + ) } // Validate there is one review and the state is pending @@ -1324,7 +1387,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. }, nil, ); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to submit pull request review", + err, + ) } // Return nothing interesting, just indicate success for the time being. @@ -1381,7 +1447,10 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get current user", + err, + ) } var getLatestReviewForViewerQuery struct { @@ -1406,7 +1475,10 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubGraphQLError( + "failed to get latest review for current user", + err, + ) } // Validate there is one review and the state is pending @@ -1490,7 +1562,11 @@ func GetPullRequestDiff(getClient GetClientFn, t translations.TranslationHelperF github.RawOptions{Type: github.Diff}, ) if err != nil { - return mcp.NewToolResultError(err.Error()), nil + return nil, ghErrors.NewGitHubAPIError( + "failed to get pull request diff", + resp, + err, + ) } if resp.StatusCode != http.StatusOK { @@ -1563,7 +1639,11 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe }, ) if err != nil { - return nil, fmt.Errorf("failed to request copilot review: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to request copilot review", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index eafb71ac..52b2b4cc 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" @@ -68,7 +69,11 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too } commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) if err != nil { - return nil, fmt.Errorf("failed to get commit: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to get commit: %s", sha), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -150,7 +155,11 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t } commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list commits: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to list commits: %s", sha), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -217,7 +226,11 @@ func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) ( branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list branches: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to list branches", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -326,7 +339,11 @@ func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperF } fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) if err != nil { - return nil, fmt.Errorf("failed to create/update file: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create/update file", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -400,7 +417,11 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) if err != nil { - return nil, fmt.Errorf("failed to create repository: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create repository", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -641,7 +662,11 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { return mcp.NewToolResultText("Fork is in progress"), nil } - return nil, fmt.Errorf("failed to fork repository: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to fork repository", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -734,7 +759,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to // Get the commit object that the branch points to baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, fmt.Errorf("failed to get base commit: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get base commit", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -759,7 +788,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to // Create a new tree with the deletion newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, treeEntries) if err != nil { - return nil, fmt.Errorf("failed to create tree: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create tree", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -779,7 +812,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to } newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) if err != nil { - return nil, fmt.Errorf("failed to create commit: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create commit", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -795,7 +832,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to ref.Object.SHA = newCommit.SHA _, resp, err = client.Git.UpdateRef(ctx, owner, repo, ref, false) if err != nil { - return nil, fmt.Errorf("failed to update reference: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to update reference", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -876,7 +917,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( // Get default branch if from_branch not specified repository, resp, err := client.Repositories.Get(ctx, owner, repo) if err != nil { - return nil, fmt.Errorf("failed to get repository: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get repository", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -886,7 +931,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( // Get SHA of source branch ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) if err != nil { - return nil, fmt.Errorf("failed to get reference: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get reference", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -898,7 +947,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) if err != nil { - return nil, fmt.Errorf("failed to create branch: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create branch", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -988,14 +1041,22 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too // Get the reference for the branch ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) if err != nil { - return nil, fmt.Errorf("failed to get branch reference: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get branch reference", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() // Get the commit object that the branch points to baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, fmt.Errorf("failed to get base commit: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get base commit", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -1030,7 +1091,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too // Create a new tree with the file entries newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) if err != nil { - return nil, fmt.Errorf("failed to create tree: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create tree", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -1042,7 +1107,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too } newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) if err != nil { - return nil, fmt.Errorf("failed to create commit: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to create commit", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -1050,7 +1119,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too ref.Object.SHA = newCommit.SHA updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, ref, false) if err != nil { - return nil, fmt.Errorf("failed to update reference: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to update reference", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -1107,7 +1180,11 @@ func ListTags(getClient GetClientFn, t translations.TranslationHelperFunc) (tool tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list tags: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to list tags", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -1171,7 +1248,11 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m // First get the tag reference ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/tags/"+tag) if err != nil { - return nil, fmt.Errorf("failed to get tag reference: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get tag reference", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -1186,7 +1267,11 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m // Then get the tag object tagObj, resp, err := client.Git.GetTag(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, fmt.Errorf("failed to get tag object: %w", err) + return nil, ghErrors.NewGitHubAPIError( + "failed to get tag object", + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/search.go b/pkg/github/search.go index 157675c1..d10dfffc 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -49,7 +50,11 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } result, resp, err := client.Search.Repositories(ctx, query, opts) if err != nil { - return nil, fmt.Errorf("failed to search repositories: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to search repositories with query '%s'", query), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -125,7 +130,11 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (to result, resp, err := client.Search.Code(ctx, query, opts) if err != nil { - return nil, fmt.Errorf("failed to search code: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to search code with query '%s'", query), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -215,7 +224,11 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (t result, resp, err := client.Search.Users(ctx, "type:user "+query, opts) if err != nil { - return nil, fmt.Errorf("failed to search users: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to search users with query '%s'", query), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index ec0eb15a..ef7901d4 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -7,6 +7,7 @@ import ( "io" "net/http" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -55,7 +56,11 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel alert, resp, err := client.SecretScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { - return nil, fmt.Errorf("failed to get alert: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to get alert with number '%d'", alertNumber), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() @@ -132,7 +137,11 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH } alerts, resp, err := client.SecretScanning.ListAlertsForRepo(ctx, owner, repo, &github.SecretScanningAlertListOptions{State: state, SecretType: secretType, Resolution: resolution}) if err != nil { - return nil, fmt.Errorf("failed to list alerts: %w", err) + return nil, ghErrors.NewGitHubAPIError( + fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), + resp, + err, + ) } defer func() { _ = resp.Body.Close() }() From f90ff16f2e3708bf0f23034dcbb88ff55caad922 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 24 Jun 2025 14:12:51 +0200 Subject: [PATCH 2/3] move to new approach and update testing --- docs/error-handling.md | 125 ++++++++++ internal/ghmcp/server.go | 11 +- pkg/errors/error.go | 88 ++++++- pkg/errors/error_test.go | 379 +++++++++++++++++++++++++++++ pkg/github/actions.go | 50 ++-- pkg/github/code_scanning.go | 8 +- pkg/github/code_scanning_test.go | 14 +- pkg/github/context_tools.go | 4 +- pkg/github/issues.go | 58 +---- pkg/github/notifications.go | 24 +- pkg/github/notifications_test.go | 21 +- pkg/github/pullrequests.go | 88 +++---- pkg/github/pullrequests_test.go | 80 ++++-- pkg/github/repositories.go | 84 +++---- pkg/github/repositories_test.go | 68 ++++-- pkg/github/search.go | 12 +- pkg/github/search_test.go | 21 +- pkg/github/secret_scanning.go | 8 +- pkg/github/secret_scanning_test.go | 14 +- 19 files changed, 904 insertions(+), 253 deletions(-) create mode 100644 docs/error-handling.md create mode 100644 pkg/errors/error_test.go diff --git a/docs/error-handling.md b/docs/error-handling.md new file mode 100644 index 00000000..9bb27e0f --- /dev/null +++ b/docs/error-handling.md @@ -0,0 +1,125 @@ +# Error Handling + +This document describes the error handling patterns used in the GitHub MCP Server, specifically how we handle GitHub API errors and avoid direct use of mcp-go error types. + +## Overview + +The GitHub MCP Server implements a custom error handling approach that serves two primary purposes: + +1. **Tool Response Generation**: Return appropriate MCP tool error responses to clients +2. **Middleware Inspection**: Store detailed error information in the request context for middleware analysis + +This dual approach enables better observability and debugging capabilities, particularly for remote server deployments where understanding the nature of failures (rate limiting, authentication, 404s, 500s, etc.) is crucial for validation and monitoring. + +## Error Types + +### GitHubAPIError + +Used for REST API errors from the GitHub API: + +```go +type GitHubAPIError struct { + Message string `json:"message"` + Response *github.Response `json:"-"` + Err error `json:"-"` +} +``` + +### GitHubGraphQLError + +Used for GraphQL API errors from the GitHub API: + +```go +type GitHubGraphQLError struct { + Message string `json:"message"` + Err error `json:"-"` +} +``` + +## Usage Patterns + +### For GitHub REST API Errors + +Instead of directly returning `mcp.NewToolResultError()`, use: + +```go +return ghErrors.NewGitHubAPIErrorResponse(ctx, message, response, err), nil +``` + +This function: +- Creates a `GitHubAPIError` with the provided message, response, and error +- Stores the error in the context for middleware inspection +- Returns an appropriate MCP tool error response + +### For GitHub GraphQL API Errors + +```go +return ghErrors.NewGitHubGraphQLErrorResponse(ctx, message, err), nil +``` + +### Context Management + +The error handling system uses context to store errors for later inspection: + +```go +// Initialize context with error tracking +ctx = errors.ContextWithGitHubErrors(ctx) + +// Retrieve errors for inspection (typically in middleware) +apiErrors, err := errors.GetGitHubAPIErrors(ctx) +graphqlErrors, err := errors.GetGitHubGraphQLErrors(ctx) +``` + +## Design Principles + +### User-Actionable vs. Developer Errors + +- **User-actionable errors** (authentication failures, rate limits, 404s) should be returned as failed tool calls using the error response functions +- **Developer errors** (JSON marshaling failures, internal logic errors) should be returned as actual Go errors that bubble up through the MCP framework + +### Context Limitations + +This approach was designed to work around current limitations in mcp-go where context is not propagated through each step of request processing. By storing errors in context values, middleware can inspect them without requiring context propagation. + +### Graceful Error Handling + +Error storage operations in context are designed to fail gracefully - if context storage fails, the tool will still return an appropriate error response to the client. + +## Benefits + +1. **Observability**: Middleware can inspect the specific types of GitHub API errors occurring +2. **Debugging**: Detailed error information is preserved without exposing potentially sensitive data in logs +3. **Validation**: Remote servers can use error types and HTTP status codes to validate that changes don't break functionality +4. **Privacy**: Error inspection can be done programmatically using `errors.Is` checks without logging PII + +## Example Implementation + +```go +func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_issue", /* ... */), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := RequiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get issue", + resp, + err, + ), nil + } + + return MarshalledTextResult(issue), nil + } +} +``` + +This approach ensures that both the client receives an appropriate error response and any middleware can inspect the underlying GitHub API error for monitoring and debugging purposes. diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index ca38e76b..568af10d 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -12,6 +12,7 @@ import ( "strings" "syscall" + "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" @@ -90,6 +91,13 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { hooks := &server.Hooks{ OnBeforeInitialize: []server.OnBeforeInitializeFunc{beforeInit}, + OnBeforeAny: []server.BeforeAnyHookFunc{ + func(ctx context.Context, _ any, _ mcp.MCPMethod, _ any) { + // Ensure the context is cleared of any previous errors + // as context isn't propagated through middleware + errors.ContextWithGitHubErrors(ctx) + }, + }, } ghServer := github.NewServer(cfg.Version, server.WithHooks(hooks)) @@ -222,7 +230,8 @@ func RunStdioServer(cfg StdioServerConfig) error { loggedIO := mcplog.NewIOLogger(in, out, logrusLogger) in, out = loggedIO, loggedIO } - + // enable GitHub errors in the context + ctx := errors.ContextWithGitHubErrors(ctx) errC <- stdioServer.Listen(ctx, in, out) }() diff --git a/pkg/errors/error.go b/pkg/errors/error.go index d5d93a28..9d81e901 100644 --- a/pkg/errors/error.go +++ b/pkg/errors/error.go @@ -1,9 +1,11 @@ package errors import ( + "context" "fmt" "github.com/google/go-github/v72/github" + "github.com/mark3labs/mcp-go/mcp" ) type GitHubAPIError struct { @@ -12,7 +14,8 @@ type GitHubAPIError struct { Err error `json:"-"` } -func NewGitHubAPIError(message string, resp *github.Response, err error) *GitHubAPIError { +// NewGitHubAPIError creates a new GitHubAPIError with the provided message, response, and error. +func newGitHubAPIError(message string, resp *github.Response, err error) *GitHubAPIError { return &GitHubAPIError{ Message: message, Response: resp, @@ -29,7 +32,7 @@ type GitHubGraphQLError struct { Err error `json:"-"` } -func NewGitHubGraphQLError(message string, err error) *GitHubGraphQLError { +func newGitHubGraphQLError(message string, err error) *GitHubGraphQLError { return &GitHubGraphQLError{ Message: message, Err: err, @@ -39,3 +42,84 @@ func NewGitHubGraphQLError(message string, err error) *GitHubGraphQLError { func (e *GitHubGraphQLError) Error() string { return fmt.Errorf("%s: %w", e.Message, e.Err).Error() } + +type GitHubErrorKey struct{} +type GitHubCtxErrors struct { + api []*GitHubAPIError + graphQL []*GitHubGraphQLError +} + +// ContextWithGitHubErrors updates or creates a context with a pointer to GitHub error information (to be used by middleware). +func ContextWithGitHubErrors(ctx context.Context) context.Context { + if ctx == nil { + ctx = context.Background() + } + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + // If the context already has GitHubCtxErrors, we just empty the slices to start fresh + val.api = []*GitHubAPIError{} + val.graphQL = []*GitHubGraphQLError{} + } else { + // If not, we create a new GitHubCtxErrors and set it in the context + ctx = context.WithValue(ctx, GitHubErrorKey{}, &GitHubCtxErrors{}) + } + + return ctx +} + +// GetGitHubAPIErrors retrieves the slice of GitHubAPIErrors from the context. +func GetGitHubAPIErrors(ctx context.Context) ([]*GitHubAPIError, error) { + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + return val.api, nil // return the slice of API errors from the context + } + return nil, fmt.Errorf("context does not contain GitHubCtxErrors") +} + +// GetGitHubGraphQLErrors retrieves the slice of GitHubGraphQLErrors from the context. +func GetGitHubGraphQLErrors(ctx context.Context) ([]*GitHubGraphQLError, error) { + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + return val.graphQL, nil // return the slice of GraphQL errors from the context + } + return nil, fmt.Errorf("context does not contain GitHubCtxErrors") +} + +func NewGitHubAPIErrorToCtx(ctx context.Context, message string, resp *github.Response, err error) (context.Context, error) { + apiErr := newGitHubAPIError(message, resp, err) + if ctx != nil { + _, _ = addGitHubAPIErrorToContext(ctx, apiErr) // Explicitly ignore error for graceful handling + } + return ctx, nil +} + +func addGitHubAPIErrorToContext(ctx context.Context, err *GitHubAPIError) (context.Context, error) { + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + val.api = append(val.api, err) // append the error to the existing slice in the context + return ctx, nil + } + return nil, fmt.Errorf("context does not contain GitHubCtxErrors") +} + +func addGitHubGraphQLErrorToContext(ctx context.Context, err *GitHubGraphQLError) (context.Context, error) { + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + val.graphQL = append(val.graphQL, err) // append the error to the existing slice in the context + return ctx, nil + } + return nil, fmt.Errorf("context does not contain GitHubCtxErrors") +} + +// NewGitHubAPIErrorResponse returns an mcp.NewToolResultError and retains the error in the context for access via middleware +func NewGitHubAPIErrorResponse(ctx context.Context, message string, resp *github.Response, err error) *mcp.CallToolResult { + apiErr := newGitHubAPIError(message, resp, err) + if ctx != nil { + _, _ = addGitHubAPIErrorToContext(ctx, apiErr) // Explicitly ignore error for graceful handling + } + return mcp.NewToolResultErrorFromErr(message, err) +} + +// NewGitHubGraphQLErrorResponse returns an mcp.NewToolResultError and retains the error in the context for access via middleware +func NewGitHubGraphQLErrorResponse(ctx context.Context, message string, err error) *mcp.CallToolResult { + graphQLErr := newGitHubGraphQLError(message, err) + if ctx != nil { + _, _ = addGitHubGraphQLErrorToContext(ctx, graphQLErr) // Explicitly ignore error for graceful handling + } + return mcp.NewToolResultErrorFromErr(message, err) +} diff --git a/pkg/errors/error_test.go b/pkg/errors/error_test.go new file mode 100644 index 00000000..409f2054 --- /dev/null +++ b/pkg/errors/error_test.go @@ -0,0 +1,379 @@ +package errors + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/google/go-github/v72/github" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGitHubErrorContext(t *testing.T) { + t.Run("API errors can be added to context and retrieved", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + // Create a mock GitHub response + resp := &github.Response{ + Response: &http.Response{ + StatusCode: 404, + Status: "404 Not Found", + }, + } + originalErr := fmt.Errorf("resource not found") + + // When we add an API error to the context + updatedCtx, err := NewGitHubAPIErrorToCtx(ctx, "failed to fetch resource", resp, originalErr) + require.NoError(t, err) + + // Then we should be able to retrieve the error from the updated context + apiErrors, err := GetGitHubAPIErrors(updatedCtx) + require.NoError(t, err) + require.Len(t, apiErrors, 1) + + apiError := apiErrors[0] + assert.Equal(t, "failed to fetch resource", apiError.Message) + assert.Equal(t, resp, apiError.Response) + assert.Equal(t, originalErr, apiError.Err) + assert.Equal(t, "failed to fetch resource: resource not found", apiError.Error()) + }) + + t.Run("GraphQL errors can be added to context and retrieved", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + originalErr := fmt.Errorf("GraphQL query failed") + + // When we add a GraphQL error to the context + graphQLErr := newGitHubGraphQLError("failed to execute mutation", originalErr) + updatedCtx, err := addGitHubGraphQLErrorToContext(ctx, graphQLErr) + require.NoError(t, err) + + // Then we should be able to retrieve the error from the updated context + gqlErrors, err := GetGitHubGraphQLErrors(updatedCtx) + require.NoError(t, err) + require.Len(t, gqlErrors, 1) + + gqlError := gqlErrors[0] + assert.Equal(t, "failed to execute mutation", gqlError.Message) + assert.Equal(t, originalErr, gqlError.Err) + assert.Equal(t, "failed to execute mutation: GraphQL query failed", gqlError.Error()) + }) + + t.Run("multiple errors can be accumulated in context", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + // When we add multiple API errors + resp1 := &github.Response{Response: &http.Response{StatusCode: 404}} + resp2 := &github.Response{Response: &http.Response{StatusCode: 403}} + + ctx, err := NewGitHubAPIErrorToCtx(ctx, "first error", resp1, fmt.Errorf("not found")) + require.NoError(t, err) + + ctx, err = NewGitHubAPIErrorToCtx(ctx, "second error", resp2, fmt.Errorf("forbidden")) + require.NoError(t, err) + + // And add a GraphQL error + gqlErr := newGitHubGraphQLError("graphql error", fmt.Errorf("query failed")) + ctx, err = addGitHubGraphQLErrorToContext(ctx, gqlErr) + require.NoError(t, err) + + // Then we should be able to retrieve all errors + apiErrors, err := GetGitHubAPIErrors(ctx) + require.NoError(t, err) + assert.Len(t, apiErrors, 2) + + gqlErrors, err := GetGitHubGraphQLErrors(ctx) + require.NoError(t, err) + assert.Len(t, gqlErrors, 1) + + // Verify error details + assert.Equal(t, "first error", apiErrors[0].Message) + assert.Equal(t, "second error", apiErrors[1].Message) + assert.Equal(t, "graphql error", gqlErrors[0].Message) + }) + + t.Run("context pointer sharing allows middleware to inspect errors without context propagation", func(t *testing.T) { + // This test demonstrates the key behavior: even when the context itself + // isn't propagated through function calls, the pointer to the error slice + // is shared, allowing middleware to inspect errors that were added later. + + // Given a context with GitHub error tracking enabled + originalCtx := ContextWithGitHubErrors(context.Background()) + + // Simulate a middleware that captures the context early + var middlewareCtx context.Context + + // Middleware function that captures the context + middleware := func(ctx context.Context) { + middlewareCtx = ctx // Middleware saves the context reference + } + + // Call middleware with the original context + middleware(originalCtx) + + // Simulate some business logic that adds errors to the context + // but doesn't propagate the updated context back to middleware + businessLogic := func(ctx context.Context) { + resp := &github.Response{Response: &http.Response{StatusCode: 500}} + + // Add an error to the context (this modifies the shared pointer) + _, err := NewGitHubAPIErrorToCtx(ctx, "business logic failed", resp, fmt.Errorf("internal error")) + require.NoError(t, err) + + // Add another error + _, err = NewGitHubAPIErrorToCtx(ctx, "second failure", resp, fmt.Errorf("another error")) + require.NoError(t, err) + } + + // Execute business logic - note that we don't propagate the returned context + businessLogic(originalCtx) + + // Then the middleware should be able to see the errors that were added + // even though it only has a reference to the original context + apiErrors, err := GetGitHubAPIErrors(middlewareCtx) + require.NoError(t, err) + assert.Len(t, apiErrors, 2, "Middleware should see errors added after it captured the context") + + assert.Equal(t, "business logic failed", apiErrors[0].Message) + assert.Equal(t, "second failure", apiErrors[1].Message) + }) + + t.Run("context without GitHub errors returns error", func(t *testing.T) { + // Given a regular context without GitHub error tracking + ctx := context.Background() + + // When we try to retrieve errors + apiErrors, err := GetGitHubAPIErrors(ctx) + + // Then it should return an error + assert.Error(t, err) + assert.Contains(t, err.Error(), "context does not contain GitHubCtxErrors") + assert.Nil(t, apiErrors) + + // Same for GraphQL errors + gqlErrors, err := GetGitHubGraphQLErrors(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context does not contain GitHubCtxErrors") + assert.Nil(t, gqlErrors) + }) + + t.Run("ContextWithGitHubErrors resets existing errors", func(t *testing.T) { + // Given a context with existing errors + ctx := ContextWithGitHubErrors(context.Background()) + resp := &github.Response{Response: &http.Response{StatusCode: 404}} + ctx, err := NewGitHubAPIErrorToCtx(ctx, "existing error", resp, fmt.Errorf("error")) + require.NoError(t, err) + + // Verify error exists + apiErrors, err := GetGitHubAPIErrors(ctx) + require.NoError(t, err) + assert.Len(t, apiErrors, 1) + + // When we call ContextWithGitHubErrors again + resetCtx := ContextWithGitHubErrors(ctx) + + // Then the errors should be cleared + apiErrors, err = GetGitHubAPIErrors(resetCtx) + require.NoError(t, err) + assert.Len(t, apiErrors, 0, "Errors should be reset") + }) + + t.Run("NewGitHubAPIErrorResponse creates MCP error result and stores context error", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + resp := &github.Response{Response: &http.Response{StatusCode: 422}} + originalErr := fmt.Errorf("validation failed") + + // When we create an API error response + result := NewGitHubAPIErrorResponse(ctx, "API call failed", resp, originalErr) + + // Then it should return an MCP error result + require.NotNil(t, result) + assert.True(t, result.IsError) + + // And the error should be stored in the context + apiErrors, err := GetGitHubAPIErrors(ctx) + require.NoError(t, err) + require.Len(t, apiErrors, 1) + + apiError := apiErrors[0] + assert.Equal(t, "API call failed", apiError.Message) + assert.Equal(t, resp, apiError.Response) + assert.Equal(t, originalErr, apiError.Err) + }) + + t.Run("NewGitHubGraphQLErrorResponse creates MCP error result and stores context error", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + originalErr := fmt.Errorf("mutation failed") + + // When we create a GraphQL error response + result := NewGitHubGraphQLErrorResponse(ctx, "GraphQL call failed", originalErr) + + // Then it should return an MCP error result + require.NotNil(t, result) + assert.True(t, result.IsError) + + // And the error should be stored in the context + gqlErrors, err := GetGitHubGraphQLErrors(ctx) + require.NoError(t, err) + require.Len(t, gqlErrors, 1) + + gqlError := gqlErrors[0] + assert.Equal(t, "GraphQL call failed", gqlError.Message) + assert.Equal(t, originalErr, gqlError.Err) + }) + + t.Run("NewGitHubAPIErrorToCtx with uninitialized context does not error", func(t *testing.T) { + // Given a regular context without GitHub error tracking initialized + ctx := context.Background() + + // Create a mock GitHub response + resp := &github.Response{ + Response: &http.Response{ + StatusCode: 500, + Status: "500 Internal Server Error", + }, + } + originalErr := fmt.Errorf("internal server error") + + // When we try to add an API error to an uninitialized context + updatedCtx, err := NewGitHubAPIErrorToCtx(ctx, "failed operation", resp, originalErr) + + // Then it should not return an error (graceful handling) + assert.NoError(t, err, "NewGitHubAPIErrorToCtx should handle uninitialized context gracefully") + assert.Equal(t, ctx, updatedCtx, "Context should be returned unchanged when not initialized") + + // And attempting to retrieve errors should still return an error since context wasn't initialized + apiErrors, err := GetGitHubAPIErrors(updatedCtx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context does not contain GitHubCtxErrors") + assert.Nil(t, apiErrors) + }) + + t.Run("NewGitHubAPIErrorToCtx with nil context does not error", func(t *testing.T) { + // Given a nil context + var ctx context.Context = nil + + // Create a mock GitHub response + resp := &github.Response{ + Response: &http.Response{ + StatusCode: 400, + Status: "400 Bad Request", + }, + } + originalErr := fmt.Errorf("bad request") + + // When we try to add an API error to a nil context + updatedCtx, err := NewGitHubAPIErrorToCtx(ctx, "failed with nil context", resp, originalErr) + + // Then it should not return an error (graceful handling) + assert.NoError(t, err, "NewGitHubAPIErrorToCtx should handle nil context gracefully") + assert.Nil(t, updatedCtx, "Context should remain nil when passed as nil") + }) +} + +func TestGitHubErrorTypes(t *testing.T) { + t.Run("GitHubAPIError implements error interface", func(t *testing.T) { + resp := &github.Response{Response: &http.Response{StatusCode: 404}} + originalErr := fmt.Errorf("not found") + + apiErr := newGitHubAPIError("test message", resp, originalErr) + + // Should implement error interface + var err error = apiErr + assert.Equal(t, "test message: not found", err.Error()) + }) + + t.Run("GitHubGraphQLError implements error interface", func(t *testing.T) { + originalErr := fmt.Errorf("query failed") + + gqlErr := newGitHubGraphQLError("test message", originalErr) + + // Should implement error interface + var err error = gqlErr + assert.Equal(t, "test message: query failed", err.Error()) + }) +} + +// TestMiddlewareScenario demonstrates a realistic middleware scenario +func TestMiddlewareScenario(t *testing.T) { + t.Run("realistic middleware error collection scenario", func(t *testing.T) { + // Simulate a realistic HTTP middleware scenario + + // 1. Request comes in, middleware sets up error tracking + ctx := ContextWithGitHubErrors(context.Background()) + + // 2. Middleware stores reference to context for later inspection + var middlewareCtx context.Context + setupMiddleware := func(ctx context.Context) context.Context { + middlewareCtx = ctx + return ctx + } + + // 3. Setup middleware + ctx = setupMiddleware(ctx) + + // 4. Simulate multiple service calls that add errors + simulateServiceCall1 := func(ctx context.Context) { + resp := &github.Response{Response: &http.Response{StatusCode: 403}} + _, err := NewGitHubAPIErrorToCtx(ctx, "insufficient permissions", resp, fmt.Errorf("forbidden")) + require.NoError(t, err) + } + + simulateServiceCall2 := func(ctx context.Context) { + resp := &github.Response{Response: &http.Response{StatusCode: 404}} + _, err := NewGitHubAPIErrorToCtx(ctx, "resource not found", resp, fmt.Errorf("not found")) + require.NoError(t, err) + } + + simulateGraphQLCall := func(ctx context.Context) { + gqlErr := newGitHubGraphQLError("mutation failed", fmt.Errorf("invalid input")) + _, err := addGitHubGraphQLErrorToContext(ctx, gqlErr) + require.NoError(t, err) + } + + // 5. Execute service calls (without context propagation) + simulateServiceCall1(ctx) + simulateServiceCall2(ctx) + simulateGraphQLCall(ctx) + + // 6. Middleware inspects errors at the end of request processing + finalizeMiddleware := func(ctx context.Context) ([]string, []string) { + var apiErrorMessages []string + var gqlErrorMessages []string + + if apiErrors, err := GetGitHubAPIErrors(ctx); err == nil { + for _, apiErr := range apiErrors { + apiErrorMessages = append(apiErrorMessages, apiErr.Message) + } + } + + if gqlErrors, err := GetGitHubGraphQLErrors(ctx); err == nil { + for _, gqlErr := range gqlErrors { + gqlErrorMessages = append(gqlErrorMessages, gqlErr.Message) + } + } + + return apiErrorMessages, gqlErrorMessages + } + + // 7. Middleware can see all errors that were added during request processing + apiMessages, gqlMessages := finalizeMiddleware(middlewareCtx) + + // Verify all errors were captured + assert.Len(t, apiMessages, 2) + assert.Contains(t, apiMessages, "insufficient permissions") + assert.Contains(t, apiMessages, "resource not found") + + assert.Len(t, gqlMessages, 1) + assert.Contains(t, gqlMessages, "mutation failed") + }) +} diff --git a/pkg/github/actions.go b/pkg/github/actions.go index 527a426e..cf33fb5a 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -644,7 +645,7 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo Filter: "latest", }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list workflow jobs: %v", err)), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow jobs", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -670,7 +671,7 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo // Collect logs for all failed jobs var logResults []map[string]any for _, job := range failedJobs { - jobResult, err := getJobLogData(ctx, client, owner, repo, job.GetID(), job.GetName(), returnContent) + jobResult, resp, err := getJobLogData(ctx, client, owner, repo, job.GetID(), job.GetName(), returnContent) if err != nil { // Continue with other jobs even if one fails jobResult = map[string]any{ @@ -678,7 +679,10 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo "job_name": job.GetName(), "error": err.Error(), } + // Enable reporting of status codes and error causes + _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get job logs", resp, err) // Explicitly ignore error for graceful handling } + logResults = append(logResults, jobResult) } @@ -701,9 +705,9 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo // handleSingleJobLogs gets logs for a single job func handleSingleJobLogs(ctx context.Context, client *github.Client, owner, repo string, jobID int64, returnContent bool) (*mcp.CallToolResult, error) { - jobResult, err := getJobLogData(ctx, client, owner, repo, jobID, "", returnContent) + jobResult, resp, err := getJobLogData(ctx, client, owner, repo, jobID, "", returnContent) if err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get job logs", resp, err), nil } r, err := json.Marshal(jobResult) @@ -715,11 +719,11 @@ func handleSingleJobLogs(ctx context.Context, client *github.Client, owner, repo } // getJobLogData retrieves log data for a single job, either as URL or content -func getJobLogData(ctx context.Context, client *github.Client, owner, repo string, jobID int64, jobName string, returnContent bool) (map[string]any, error) { +func getJobLogData(ctx context.Context, client *github.Client, owner, repo string, jobID int64, jobName string, returnContent bool) (map[string]any, *github.Response, error) { // Get the download URL for the job logs url, resp, err := client.Actions.GetWorkflowJobLogs(ctx, owner, repo, jobID, 1) if err != nil { - return nil, fmt.Errorf("failed to get job logs for job %d: %w", jobID, err) + return nil, resp, fmt.Errorf("failed to get job logs for job %d: %w", jobID, err) } defer func() { _ = resp.Body.Close() }() @@ -732,9 +736,13 @@ func getJobLogData(ctx context.Context, client *github.Client, owner, repo strin if returnContent { // Download and return the actual log content - content, err := downloadLogContent(url.String()) + content, httpResp, err := downloadLogContent(url.String()) //nolint:bodyclose // Response body is closed in downloadLogContent, but we need to return httpResp if err != nil { - return nil, fmt.Errorf("failed to download log content for job %d: %w", jobID, err) + // To keep the return value consistent wrap the response as a GitHub Response + ghRes := &github.Response{ + Response: httpResp, + } + return nil, ghRes, fmt.Errorf("failed to download log content for job %d: %w", jobID, err) } result["logs_content"] = content result["message"] = "Job logs content retrieved successfully" @@ -745,29 +753,29 @@ func getJobLogData(ctx context.Context, client *github.Client, owner, repo strin result["note"] = "The logs_url provides a download link for the individual job logs in plain text format. Use return_content=true to get the actual log content." } - return result, nil + return result, resp, nil } // downloadLogContent downloads the actual log content from a GitHub logs URL -func downloadLogContent(logURL string) (string, error) { +func downloadLogContent(logURL string) (string, *http.Response, error) { httpResp, err := http.Get(logURL) //nolint:gosec // URLs are provided by GitHub API and are safe if err != nil { - return "", fmt.Errorf("failed to download logs: %w", err) + return "", httpResp, fmt.Errorf("failed to download logs: %w", err) } defer func() { _ = httpResp.Body.Close() }() if httpResp.StatusCode != http.StatusOK { - return "", fmt.Errorf("failed to download logs: HTTP %d", httpResp.StatusCode) + return "", httpResp, fmt.Errorf("failed to download logs: HTTP %d", httpResp.StatusCode) } content, err := io.ReadAll(httpResp.Body) if err != nil { - return "", fmt.Errorf("failed to read log content: %w", err) + return "", httpResp, fmt.Errorf("failed to read log content: %w", err) } // Clean up and format the log content for better readability logContent := strings.TrimSpace(string(content)) - return logContent, nil + return logContent, httpResp, nil } // RerunWorkflowRun creates a tool to re-run an entire workflow run @@ -813,7 +821,7 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to rerun workflow run: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun workflow run", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -876,7 +884,7 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to rerun failed jobs: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun failed jobs", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -939,7 +947,7 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to cancel workflow run: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to cancel workflow run", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -1024,7 +1032,7 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH artifacts, resp, err := client.Actions.ListWorkflowRunArtifacts(ctx, owner, repo, runID, opts) if err != nil { - return nil, fmt.Errorf("failed to list workflow run artifacts: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow run artifacts", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -1081,7 +1089,7 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati // Get the download URL for the artifact url, resp, err := client.Actions.DownloadArtifact(ctx, owner, repo, artifactID, 1) if err != nil { - return nil, fmt.Errorf("failed to get artifact download URL: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get artifact download URL", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -1146,7 +1154,7 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to delete workflow run logs: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to delete workflow run logs", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -1209,7 +1217,7 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to get workflow run usage: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run usage", resp, err), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index e2110d3d..3b07692c 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -55,11 +55,11 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get alert", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -143,11 +143,11 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName}) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list alerts", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 5c0131a7..bd76ccba 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -94,12 +94,15 @@ func Test_GetCodeScanningAlert(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -217,12 +220,15 @@ func Test_ListCodeScanningAlerts(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index b35a5c14..bed2f4a3 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -31,11 +31,11 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too user, res, err := client.Users.Get(ctx, "") if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get user", res, err, - ) + ), nil } return MarshalledTextResult(user), nil diff --git a/pkg/github/issues.go b/pkg/github/issues.go index d513fa45..b4c64c8d 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -9,7 +9,6 @@ import ( "strings" "time" - ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/go-viper/mapstructure/v2" "github.com/google/go-github/v72/github" @@ -59,11 +58,7 @@ func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool } issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { - return nil, ghErrors.NewGitHubAPIError( - fmt.Sprintf("failed to get issue with number '%d'", issueNumber), - resp, - err, - ) + return nil, fmt.Errorf("failed to get issue: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -137,11 +132,7 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc } createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) if err != nil { - return nil, ghErrors.NewGitHubAPIError( - fmt.Sprintf("failed to create comment on issue '%d'", issueNumber), - resp, - err, - ) + return nil, fmt.Errorf("failed to create comment: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -229,11 +220,7 @@ func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) ( } result, resp, err := client.Search.Issues(ctx, query, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( - "failed to search issues", - resp, - err, - ) + return nil, fmt.Errorf("failed to search issues: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -355,11 +342,7 @@ func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t } issue, resp, err := client.Issues.Create(ctx, owner, repo, issueRequest) if err != nil { - return nil, ghErrors.NewGitHubAPIError( - "failed to create issue", - resp, - err, - ) + return nil, fmt.Errorf("failed to create issue: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -481,11 +464,7 @@ func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (to } issues, resp, err := client.Issues.ListByRepo(ctx, owner, repo, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( - "failed to list issues", - resp, - err, - ) + return nil, fmt.Errorf("failed to list issues: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -631,11 +610,7 @@ func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t } updatedIssue, resp, err := client.Issues.Edit(ctx, owner, repo, issueNumber, issueRequest) if err != nil { - return nil, ghErrors.NewGitHubAPIError( - "failed to update issue", - resp, - err, - ) + return nil, fmt.Errorf("failed to update issue: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -718,11 +693,7 @@ func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFun } comments, resp, err := client.Issues.ListComments(ctx, owner, repo, issueNumber, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( - "failed to get issue comments", - resp, - err, - ) + return nil, fmt.Errorf("failed to get issue comments: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -853,10 +824,7 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio var query suggestedActorsQuery err := client.Query(ctx, &query, variables) if err != nil { - return nil, ghErrors.NewGitHubGraphQLError( - "failed to list suggested actors", - err, - ) + return nil, err } // Iterate all the returned nodes looking for the copilot bot, which is supposed to have the @@ -902,10 +870,7 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio } if err := client.Query(ctx, &getIssueQuery, variables); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( - "failed to get issue ID", - err, - ) + return mcp.NewToolResultError(fmt.Sprintf("failed to get issue ID: %v", err)), nil } // Finally, do the assignment. Just for reference, assigning copilot to an issue that it is already @@ -931,10 +896,7 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio }, nil, ); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( - "failed to replace actors for assignable", - err, - ) + return nil, fmt.Errorf("failed to replace actors for assignable: %w", err) } return mcp.NewToolResultText("successfully assigned copilot to issue"), nil diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 9b81878f..b6b6bfd7 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -119,11 +119,11 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu notifications, resp, err = client.Activity.ListNotifications(ctx, opts) } if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list notifications", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -192,11 +192,11 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper } if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to mark notification as %s", state), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -271,11 +271,11 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) } if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to mark all notifications as read", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -317,11 +317,11 @@ func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHel thread, resp, err := client.Activity.GetThread(ctx, notificationID) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to get notification details for ID '%s'", notificationID), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -402,11 +402,11 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl } if apiErr != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to %s notification subscription", action), resp, apiErr, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -495,11 +495,11 @@ func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translati } if apiErr != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to %s repository subscription", action), resp, apiErr, - ) + ), nil } if resp != nil { defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index 77372f02..a83df3ed 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -127,14 +127,17 @@ func Test_ListNotifications(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { - assert.Contains(t, err.Error(), tc.expectedErrMsg) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) } return } require.NoError(t, err) + require.False(t, result.IsError) textContent := getTextResult(t, result) t.Logf("textContent: %s", textContent.Text) var returned []*github.Notification @@ -663,14 +666,17 @@ func Test_MarkAllNotificationsRead(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { - assert.Contains(t, err.Error(), tc.expectedErrMsg) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) } return } require.NoError(t, err) + require.False(t, result.IsError) textContent := getTextResult(t, result) if tc.expectMarked { assert.Contains(t, textContent.Text, "All notifications marked as read") @@ -738,14 +744,17 @@ func Test_GetNotificationDetails(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { - assert.Contains(t, err.Error(), tc.expectedErrMsg) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) } return } require.NoError(t, err) + require.False(t, result.IsError) textContent := getTextResult(t, result) var returned github.Notification err = json.Unmarshal([]byte(textContent.Text), &returned) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 89a3c1bf..7dcc2c4f 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -58,11 +58,11 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -177,11 +177,11 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create pull request", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -302,11 +302,11 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to update pull request", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -415,11 +415,11 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun } prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list pull requests", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -508,11 +508,11 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun } result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to merge pull request", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -575,11 +575,11 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper opts := &github.ListOptions{} files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request files", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -641,11 +641,11 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -660,11 +660,11 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe // Get combined status for the head SHA status, resp, err := client.Repositories.GetCombinedStatus(ctx, owner, repo, *pr.Head.SHA, nil) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get combined status", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -742,11 +742,11 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { return mcp.NewToolResultText("Pull request branch update is in progress"), nil } - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to update pull request branch", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -814,11 +814,11 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel } comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request comments", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -880,11 +880,11 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request reviews", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -971,10 +971,10 @@ func CreateAndSubmitPullRequestReview(getGQLClient GetGQLClientFn, t translation "repo": githubv4.String(params.Repo), "prNum": githubv4.Int(params.PullNumber), }); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to get pull request", err, - ) + ), nil } // Now we have the GQL ID, we can create a review @@ -1065,10 +1065,10 @@ func CreatePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. "repo": githubv4.String(params.Repo), "prNum": githubv4.Int(params.PullNumber), }); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to get pull request", err, - ) + ), nil } // Now we have the GQL ID, we can create a pending review @@ -1186,10 +1186,10 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to get current user", err, - ) + ), nil } var getLatestReviewForViewerQuery struct { @@ -1214,10 +1214,10 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to get latest review for current user", err, - ) + ), nil } // Validate there is one review and the state is pending @@ -1323,10 +1323,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to get current user", err, - ) + ), nil } var getLatestReviewForViewerQuery struct { @@ -1351,10 +1351,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to get latest review for current user", err, - ) + ), nil } // Validate there is one review and the state is pending @@ -1387,10 +1387,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. }, nil, ); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to submit pull request review", err, - ) + ), nil } // Return nothing interesting, just indicate success for the time being. @@ -1447,10 +1447,10 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to get current user", err, - ) + ), nil } var getLatestReviewForViewerQuery struct { @@ -1475,10 +1475,10 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return nil, ghErrors.NewGitHubGraphQLError( + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "failed to get latest review for current user", err, - ) + ), nil } // Validate there is one review and the state is pending @@ -1562,11 +1562,11 @@ func GetPullRequestDiff(getClient GetClientFn, t translations.TranslationHelperF github.RawOptions{Type: github.Diff}, ) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request diff", resp, err, - ) + ), nil } if resp.StatusCode != http.StatusOK { @@ -1639,11 +1639,11 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe }, ) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to request copilot review", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 144c6b38..02575c43 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -109,12 +109,15 @@ func Test_GetPullRequest(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -272,23 +275,22 @@ func Test_UpdatePullRequest(t *testing.T) { result, err := handler(context.Background(), request) // Verify results - if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + if tc.expectError || tc.expectedErrMsg != "" { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + if tc.expectedErrMsg != "" { + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + } return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content textContent := getTextResult(t, result) - // Check for expected error message within the result text - if tc.expectedErrMsg != "" { - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - return - } - // Unmarshal and verify the successful result var returnedPR github.PullRequest err = json.Unmarshal([]byte(textContent.Text), &returnedPR) @@ -420,12 +422,15 @@ func Test_ListPullRequests(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -536,12 +541,15 @@ func Test_MergePullRequest(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -649,12 +657,15 @@ func Test_GetPullRequestFiles(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -810,12 +821,15 @@ func Test_GetPullRequestStatus(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -938,12 +952,15 @@ func Test_UpdatePullRequestBranch(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1055,12 +1072,15 @@ func Test_GetPullRequestComments(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1179,12 +1199,15 @@ func Test_GetPullRequestReviews(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1653,12 +1676,15 @@ func Test_RequestCopilotReview(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) assert.Len(t, result.Content, 1) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 52b2b4cc..fa5d7338 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -69,11 +69,11 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too } commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to get commit: %s", sha), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -155,11 +155,11 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t } commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to list commits: %s", sha), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -226,11 +226,11 @@ func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) ( branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list branches", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -339,11 +339,11 @@ func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperF } fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create/update file", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -417,11 +417,11 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create repository", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -662,11 +662,11 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { return mcp.NewToolResultText("Fork is in progress"), nil } - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to fork repository", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -759,11 +759,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to // Get the commit object that the branch points to baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get base commit", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -788,11 +788,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to // Create a new tree with the deletion newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, treeEntries) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create tree", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -812,11 +812,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to } newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create commit", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -832,11 +832,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to ref.Object.SHA = newCommit.SHA _, resp, err = client.Git.UpdateRef(ctx, owner, repo, ref, false) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to update reference", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -917,11 +917,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( // Get default branch if from_branch not specified repository, resp, err := client.Repositories.Get(ctx, owner, repo) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get repository", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -931,11 +931,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( // Get SHA of source branch ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get reference", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -947,11 +947,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create branch", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1041,22 +1041,22 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too // Get the reference for the branch ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get branch reference", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() // Get the commit object that the branch points to baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get base commit", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1091,11 +1091,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too // Create a new tree with the file entries newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create tree", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1107,11 +1107,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too } newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create commit", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1119,11 +1119,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too ref.Object.SHA = newCommit.SHA updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, ref, false) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to update reference", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1180,11 +1180,11 @@ func ListTags(getClient GetClientFn, t translations.TranslationHelperFunc) (tool tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list tags", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1248,11 +1248,11 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m // First get the tag reference ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/tags/"+tag) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get tag reference", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1267,11 +1267,11 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m // Then get the tag object tagObj, resp, err := client.Git.GetTag(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get tag object", resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 7ce2fec1..b621cec4 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -302,12 +302,15 @@ func Test_ForkRepository(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -489,12 +492,15 @@ func Test_CreateBranch(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -612,12 +618,15 @@ func Test_GetCommit(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -788,12 +797,15 @@ func Test_ListCommits(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -951,12 +963,15 @@ func Test_CreateOrUpdateFile(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1100,12 +1115,15 @@ func Test_CreateRepository(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1434,19 +1452,23 @@ func Test_PushFiles(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } if tc.expectedErrMsg != "" { require.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1847,12 +1869,15 @@ func Test_ListTags(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1998,12 +2023,15 @@ func Test_GetTag(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) diff --git a/pkg/github/search.go b/pkg/github/search.go index d10dfffc..13d01712 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -50,11 +50,11 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } result, resp, err := client.Search.Repositories(ctx, query, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to search repositories with query '%s'", query), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -130,11 +130,11 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (to result, resp, err := client.Search.Code(ctx, query, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to search code with query '%s'", query), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -224,11 +224,11 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (t result, resp, err := client.Search.Users(ctx, "type:user "+query, opts) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to search users with query '%s'", query), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index b76fe804..f206ebb4 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -134,12 +134,15 @@ func Test_SearchRepositories(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -286,12 +289,15 @@ func Test_SearchCode(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -437,12 +443,15 @@ func Test_SearchUsers(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error require.NotNil(t, result) diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index ef7901d4..bea6df2a 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -56,11 +56,11 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel alert, resp, err := client.SecretScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to get alert with number '%d'", alertNumber), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() @@ -137,11 +137,11 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH } alerts, resp, err := client.SecretScanning.ListAlertsForRepo(ctx, owner, repo, &github.SecretScanningAlertListOptions{State: state, SecretType: secretType, Resolution: resolution}) if err != nil { - return nil, ghErrors.NewGitHubAPIError( + return ghErrors.NewGitHubAPIErrorResponse(ctx, fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), resp, err, - ) + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/secret_scanning_test.go b/pkg/github/secret_scanning_test.go index 4ec5539e..38b573e0 100644 --- a/pkg/github/secret_scanning_test.go +++ b/pkg/github/secret_scanning_test.go @@ -90,12 +90,15 @@ func Test_GetSecretScanningAlert(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -217,12 +220,15 @@ func Test_ListSecretScanningAlerts(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) textContent := getTextResult(t, result) From 2a2df24af2b0d83cb6156842266d9f3b6f263dd2 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 24 Jun 2025 16:16:26 +0200 Subject: [PATCH 3/3] update tool snaps --- pkg/github/__toolsnaps__/create_or_update_file.snap | 2 +- pkg/github/__toolsnaps__/list_commits.snap | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/github/__toolsnaps__/create_or_update_file.snap b/pkg/github/__toolsnaps__/create_or_update_file.snap index 53f643df..dfbb3442 100644 --- a/pkg/github/__toolsnaps__/create_or_update_file.snap +++ b/pkg/github/__toolsnaps__/create_or_update_file.snap @@ -3,7 +3,7 @@ "title": "Create or update file", "readOnlyHint": false }, - "description": "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update.", + "description": "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update. Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations.", "inputSchema": { "properties": { "branch": { diff --git a/pkg/github/__toolsnaps__/list_commits.snap b/pkg/github/__toolsnaps__/list_commits.snap index 6603bdf5..1e769c71 100644 --- a/pkg/github/__toolsnaps__/list_commits.snap +++ b/pkg/github/__toolsnaps__/list_commits.snap @@ -6,6 +6,10 @@ "description": "Get list of commits of a branch in a GitHub repository", "inputSchema": { "properties": { + "author": { + "description": "Author username or email address", + "type": "string" + }, "owner": { "description": "Repository owner", "type": "string" @@ -28,10 +32,6 @@ "sha": { "description": "SHA or Branch name", "type": "string" - }, - "author": { - "description": "Author username or email address", - "type": "string" } }, "required": [