diff --git a/docs/error-handling.md b/docs/error-handling.md new file mode 100644 index 00000000..9bb27e0f --- /dev/null +++ b/docs/error-handling.md @@ -0,0 +1,125 @@ +# Error Handling + +This document describes the error handling patterns used in the GitHub MCP Server, specifically how we handle GitHub API errors and avoid direct use of mcp-go error types. + +## Overview + +The GitHub MCP Server implements a custom error handling approach that serves two primary purposes: + +1. **Tool Response Generation**: Return appropriate MCP tool error responses to clients +2. **Middleware Inspection**: Store detailed error information in the request context for middleware analysis + +This dual approach enables better observability and debugging capabilities, particularly for remote server deployments where understanding the nature of failures (rate limiting, authentication, 404s, 500s, etc.) is crucial for validation and monitoring. + +## Error Types + +### GitHubAPIError + +Used for REST API errors from the GitHub API: + +```go +type GitHubAPIError struct { + Message string `json:"message"` + Response *github.Response `json:"-"` + Err error `json:"-"` +} +``` + +### GitHubGraphQLError + +Used for GraphQL API errors from the GitHub API: + +```go +type GitHubGraphQLError struct { + Message string `json:"message"` + Err error `json:"-"` +} +``` + +## Usage Patterns + +### For GitHub REST API Errors + +Instead of directly returning `mcp.NewToolResultError()`, use: + +```go +return ghErrors.NewGitHubAPIErrorResponse(ctx, message, response, err), nil +``` + +This function: +- Creates a `GitHubAPIError` with the provided message, response, and error +- Stores the error in the context for middleware inspection +- Returns an appropriate MCP tool error response + +### For GitHub GraphQL API Errors + +```go +return ghErrors.NewGitHubGraphQLErrorResponse(ctx, message, err), nil +``` + +### Context Management + +The error handling system uses context to store errors for later inspection: + +```go +// Initialize context with error tracking +ctx = errors.ContextWithGitHubErrors(ctx) + +// Retrieve errors for inspection (typically in middleware) +apiErrors, err := errors.GetGitHubAPIErrors(ctx) +graphqlErrors, err := errors.GetGitHubGraphQLErrors(ctx) +``` + +## Design Principles + +### User-Actionable vs. Developer Errors + +- **User-actionable errors** (authentication failures, rate limits, 404s) should be returned as failed tool calls using the error response functions +- **Developer errors** (JSON marshaling failures, internal logic errors) should be returned as actual Go errors that bubble up through the MCP framework + +### Context Limitations + +This approach was designed to work around current limitations in mcp-go where context is not propagated through each step of request processing. By storing errors in context values, middleware can inspect them without requiring context propagation. + +### Graceful Error Handling + +Error storage operations in context are designed to fail gracefully - if context storage fails, the tool will still return an appropriate error response to the client. + +## Benefits + +1. **Observability**: Middleware can inspect the specific types of GitHub API errors occurring +2. **Debugging**: Detailed error information is preserved without exposing potentially sensitive data in logs +3. **Validation**: Remote servers can use error types and HTTP status codes to validate that changes don't break functionality +4. **Privacy**: Error inspection can be done programmatically using `errors.Is` checks without logging PII + +## Example Implementation + +```go +func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_issue", /* ... */), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := RequiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get issue", + resp, + err, + ), nil + } + + return MarshalledTextResult(issue), nil + } +} +``` + +This approach ensures that both the client receives an appropriate error response and any middleware can inspect the underlying GitHub API error for monitoring and debugging purposes. diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index ca38e76b..568af10d 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -12,6 +12,7 @@ import ( "strings" "syscall" + "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" @@ -90,6 +91,13 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { hooks := &server.Hooks{ OnBeforeInitialize: []server.OnBeforeInitializeFunc{beforeInit}, + OnBeforeAny: []server.BeforeAnyHookFunc{ + func(ctx context.Context, _ any, _ mcp.MCPMethod, _ any) { + // Ensure the context is cleared of any previous errors + // as context isn't propagated through middleware + errors.ContextWithGitHubErrors(ctx) + }, + }, } ghServer := github.NewServer(cfg.Version, server.WithHooks(hooks)) @@ -222,7 +230,8 @@ func RunStdioServer(cfg StdioServerConfig) error { loggedIO := mcplog.NewIOLogger(in, out, logrusLogger) in, out = loggedIO, loggedIO } - + // enable GitHub errors in the context + ctx := errors.ContextWithGitHubErrors(ctx) errC <- stdioServer.Listen(ctx, in, out) }() diff --git a/pkg/errors/error.go b/pkg/errors/error.go new file mode 100644 index 00000000..9d81e901 --- /dev/null +++ b/pkg/errors/error.go @@ -0,0 +1,125 @@ +package errors + +import ( + "context" + "fmt" + + "github.com/google/go-github/v72/github" + "github.com/mark3labs/mcp-go/mcp" +) + +type GitHubAPIError struct { + Message string `json:"message"` + Response *github.Response `json:"-"` + Err error `json:"-"` +} + +// NewGitHubAPIError creates a new GitHubAPIError with the provided message, response, and error. +func newGitHubAPIError(message string, resp *github.Response, err error) *GitHubAPIError { + return &GitHubAPIError{ + Message: message, + Response: resp, + Err: err, + } +} + +func (e *GitHubAPIError) Error() string { + return fmt.Errorf("%s: %w", e.Message, e.Err).Error() +} + +type GitHubGraphQLError struct { + Message string `json:"message"` + Err error `json:"-"` +} + +func newGitHubGraphQLError(message string, err error) *GitHubGraphQLError { + return &GitHubGraphQLError{ + Message: message, + Err: err, + } +} + +func (e *GitHubGraphQLError) Error() string { + return fmt.Errorf("%s: %w", e.Message, e.Err).Error() +} + +type GitHubErrorKey struct{} +type GitHubCtxErrors struct { + api []*GitHubAPIError + graphQL []*GitHubGraphQLError +} + +// ContextWithGitHubErrors updates or creates a context with a pointer to GitHub error information (to be used by middleware). +func ContextWithGitHubErrors(ctx context.Context) context.Context { + if ctx == nil { + ctx = context.Background() + } + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + // If the context already has GitHubCtxErrors, we just empty the slices to start fresh + val.api = []*GitHubAPIError{} + val.graphQL = []*GitHubGraphQLError{} + } else { + // If not, we create a new GitHubCtxErrors and set it in the context + ctx = context.WithValue(ctx, GitHubErrorKey{}, &GitHubCtxErrors{}) + } + + return ctx +} + +// GetGitHubAPIErrors retrieves the slice of GitHubAPIErrors from the context. +func GetGitHubAPIErrors(ctx context.Context) ([]*GitHubAPIError, error) { + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + return val.api, nil // return the slice of API errors from the context + } + return nil, fmt.Errorf("context does not contain GitHubCtxErrors") +} + +// GetGitHubGraphQLErrors retrieves the slice of GitHubGraphQLErrors from the context. +func GetGitHubGraphQLErrors(ctx context.Context) ([]*GitHubGraphQLError, error) { + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + return val.graphQL, nil // return the slice of GraphQL errors from the context + } + return nil, fmt.Errorf("context does not contain GitHubCtxErrors") +} + +func NewGitHubAPIErrorToCtx(ctx context.Context, message string, resp *github.Response, err error) (context.Context, error) { + apiErr := newGitHubAPIError(message, resp, err) + if ctx != nil { + _, _ = addGitHubAPIErrorToContext(ctx, apiErr) // Explicitly ignore error for graceful handling + } + return ctx, nil +} + +func addGitHubAPIErrorToContext(ctx context.Context, err *GitHubAPIError) (context.Context, error) { + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + val.api = append(val.api, err) // append the error to the existing slice in the context + return ctx, nil + } + return nil, fmt.Errorf("context does not contain GitHubCtxErrors") +} + +func addGitHubGraphQLErrorToContext(ctx context.Context, err *GitHubGraphQLError) (context.Context, error) { + if val, ok := ctx.Value(GitHubErrorKey{}).(*GitHubCtxErrors); ok { + val.graphQL = append(val.graphQL, err) // append the error to the existing slice in the context + return ctx, nil + } + return nil, fmt.Errorf("context does not contain GitHubCtxErrors") +} + +// NewGitHubAPIErrorResponse returns an mcp.NewToolResultError and retains the error in the context for access via middleware +func NewGitHubAPIErrorResponse(ctx context.Context, message string, resp *github.Response, err error) *mcp.CallToolResult { + apiErr := newGitHubAPIError(message, resp, err) + if ctx != nil { + _, _ = addGitHubAPIErrorToContext(ctx, apiErr) // Explicitly ignore error for graceful handling + } + return mcp.NewToolResultErrorFromErr(message, err) +} + +// NewGitHubGraphQLErrorResponse returns an mcp.NewToolResultError and retains the error in the context for access via middleware +func NewGitHubGraphQLErrorResponse(ctx context.Context, message string, err error) *mcp.CallToolResult { + graphQLErr := newGitHubGraphQLError(message, err) + if ctx != nil { + _, _ = addGitHubGraphQLErrorToContext(ctx, graphQLErr) // Explicitly ignore error for graceful handling + } + return mcp.NewToolResultErrorFromErr(message, err) +} diff --git a/pkg/errors/error_test.go b/pkg/errors/error_test.go new file mode 100644 index 00000000..409f2054 --- /dev/null +++ b/pkg/errors/error_test.go @@ -0,0 +1,379 @@ +package errors + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/google/go-github/v72/github" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGitHubErrorContext(t *testing.T) { + t.Run("API errors can be added to context and retrieved", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + // Create a mock GitHub response + resp := &github.Response{ + Response: &http.Response{ + StatusCode: 404, + Status: "404 Not Found", + }, + } + originalErr := fmt.Errorf("resource not found") + + // When we add an API error to the context + updatedCtx, err := NewGitHubAPIErrorToCtx(ctx, "failed to fetch resource", resp, originalErr) + require.NoError(t, err) + + // Then we should be able to retrieve the error from the updated context + apiErrors, err := GetGitHubAPIErrors(updatedCtx) + require.NoError(t, err) + require.Len(t, apiErrors, 1) + + apiError := apiErrors[0] + assert.Equal(t, "failed to fetch resource", apiError.Message) + assert.Equal(t, resp, apiError.Response) + assert.Equal(t, originalErr, apiError.Err) + assert.Equal(t, "failed to fetch resource: resource not found", apiError.Error()) + }) + + t.Run("GraphQL errors can be added to context and retrieved", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + originalErr := fmt.Errorf("GraphQL query failed") + + // When we add a GraphQL error to the context + graphQLErr := newGitHubGraphQLError("failed to execute mutation", originalErr) + updatedCtx, err := addGitHubGraphQLErrorToContext(ctx, graphQLErr) + require.NoError(t, err) + + // Then we should be able to retrieve the error from the updated context + gqlErrors, err := GetGitHubGraphQLErrors(updatedCtx) + require.NoError(t, err) + require.Len(t, gqlErrors, 1) + + gqlError := gqlErrors[0] + assert.Equal(t, "failed to execute mutation", gqlError.Message) + assert.Equal(t, originalErr, gqlError.Err) + assert.Equal(t, "failed to execute mutation: GraphQL query failed", gqlError.Error()) + }) + + t.Run("multiple errors can be accumulated in context", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + // When we add multiple API errors + resp1 := &github.Response{Response: &http.Response{StatusCode: 404}} + resp2 := &github.Response{Response: &http.Response{StatusCode: 403}} + + ctx, err := NewGitHubAPIErrorToCtx(ctx, "first error", resp1, fmt.Errorf("not found")) + require.NoError(t, err) + + ctx, err = NewGitHubAPIErrorToCtx(ctx, "second error", resp2, fmt.Errorf("forbidden")) + require.NoError(t, err) + + // And add a GraphQL error + gqlErr := newGitHubGraphQLError("graphql error", fmt.Errorf("query failed")) + ctx, err = addGitHubGraphQLErrorToContext(ctx, gqlErr) + require.NoError(t, err) + + // Then we should be able to retrieve all errors + apiErrors, err := GetGitHubAPIErrors(ctx) + require.NoError(t, err) + assert.Len(t, apiErrors, 2) + + gqlErrors, err := GetGitHubGraphQLErrors(ctx) + require.NoError(t, err) + assert.Len(t, gqlErrors, 1) + + // Verify error details + assert.Equal(t, "first error", apiErrors[0].Message) + assert.Equal(t, "second error", apiErrors[1].Message) + assert.Equal(t, "graphql error", gqlErrors[0].Message) + }) + + t.Run("context pointer sharing allows middleware to inspect errors without context propagation", func(t *testing.T) { + // This test demonstrates the key behavior: even when the context itself + // isn't propagated through function calls, the pointer to the error slice + // is shared, allowing middleware to inspect errors that were added later. + + // Given a context with GitHub error tracking enabled + originalCtx := ContextWithGitHubErrors(context.Background()) + + // Simulate a middleware that captures the context early + var middlewareCtx context.Context + + // Middleware function that captures the context + middleware := func(ctx context.Context) { + middlewareCtx = ctx // Middleware saves the context reference + } + + // Call middleware with the original context + middleware(originalCtx) + + // Simulate some business logic that adds errors to the context + // but doesn't propagate the updated context back to middleware + businessLogic := func(ctx context.Context) { + resp := &github.Response{Response: &http.Response{StatusCode: 500}} + + // Add an error to the context (this modifies the shared pointer) + _, err := NewGitHubAPIErrorToCtx(ctx, "business logic failed", resp, fmt.Errorf("internal error")) + require.NoError(t, err) + + // Add another error + _, err = NewGitHubAPIErrorToCtx(ctx, "second failure", resp, fmt.Errorf("another error")) + require.NoError(t, err) + } + + // Execute business logic - note that we don't propagate the returned context + businessLogic(originalCtx) + + // Then the middleware should be able to see the errors that were added + // even though it only has a reference to the original context + apiErrors, err := GetGitHubAPIErrors(middlewareCtx) + require.NoError(t, err) + assert.Len(t, apiErrors, 2, "Middleware should see errors added after it captured the context") + + assert.Equal(t, "business logic failed", apiErrors[0].Message) + assert.Equal(t, "second failure", apiErrors[1].Message) + }) + + t.Run("context without GitHub errors returns error", func(t *testing.T) { + // Given a regular context without GitHub error tracking + ctx := context.Background() + + // When we try to retrieve errors + apiErrors, err := GetGitHubAPIErrors(ctx) + + // Then it should return an error + assert.Error(t, err) + assert.Contains(t, err.Error(), "context does not contain GitHubCtxErrors") + assert.Nil(t, apiErrors) + + // Same for GraphQL errors + gqlErrors, err := GetGitHubGraphQLErrors(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context does not contain GitHubCtxErrors") + assert.Nil(t, gqlErrors) + }) + + t.Run("ContextWithGitHubErrors resets existing errors", func(t *testing.T) { + // Given a context with existing errors + ctx := ContextWithGitHubErrors(context.Background()) + resp := &github.Response{Response: &http.Response{StatusCode: 404}} + ctx, err := NewGitHubAPIErrorToCtx(ctx, "existing error", resp, fmt.Errorf("error")) + require.NoError(t, err) + + // Verify error exists + apiErrors, err := GetGitHubAPIErrors(ctx) + require.NoError(t, err) + assert.Len(t, apiErrors, 1) + + // When we call ContextWithGitHubErrors again + resetCtx := ContextWithGitHubErrors(ctx) + + // Then the errors should be cleared + apiErrors, err = GetGitHubAPIErrors(resetCtx) + require.NoError(t, err) + assert.Len(t, apiErrors, 0, "Errors should be reset") + }) + + t.Run("NewGitHubAPIErrorResponse creates MCP error result and stores context error", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + resp := &github.Response{Response: &http.Response{StatusCode: 422}} + originalErr := fmt.Errorf("validation failed") + + // When we create an API error response + result := NewGitHubAPIErrorResponse(ctx, "API call failed", resp, originalErr) + + // Then it should return an MCP error result + require.NotNil(t, result) + assert.True(t, result.IsError) + + // And the error should be stored in the context + apiErrors, err := GetGitHubAPIErrors(ctx) + require.NoError(t, err) + require.Len(t, apiErrors, 1) + + apiError := apiErrors[0] + assert.Equal(t, "API call failed", apiError.Message) + assert.Equal(t, resp, apiError.Response) + assert.Equal(t, originalErr, apiError.Err) + }) + + t.Run("NewGitHubGraphQLErrorResponse creates MCP error result and stores context error", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + originalErr := fmt.Errorf("mutation failed") + + // When we create a GraphQL error response + result := NewGitHubGraphQLErrorResponse(ctx, "GraphQL call failed", originalErr) + + // Then it should return an MCP error result + require.NotNil(t, result) + assert.True(t, result.IsError) + + // And the error should be stored in the context + gqlErrors, err := GetGitHubGraphQLErrors(ctx) + require.NoError(t, err) + require.Len(t, gqlErrors, 1) + + gqlError := gqlErrors[0] + assert.Equal(t, "GraphQL call failed", gqlError.Message) + assert.Equal(t, originalErr, gqlError.Err) + }) + + t.Run("NewGitHubAPIErrorToCtx with uninitialized context does not error", func(t *testing.T) { + // Given a regular context without GitHub error tracking initialized + ctx := context.Background() + + // Create a mock GitHub response + resp := &github.Response{ + Response: &http.Response{ + StatusCode: 500, + Status: "500 Internal Server Error", + }, + } + originalErr := fmt.Errorf("internal server error") + + // When we try to add an API error to an uninitialized context + updatedCtx, err := NewGitHubAPIErrorToCtx(ctx, "failed operation", resp, originalErr) + + // Then it should not return an error (graceful handling) + assert.NoError(t, err, "NewGitHubAPIErrorToCtx should handle uninitialized context gracefully") + assert.Equal(t, ctx, updatedCtx, "Context should be returned unchanged when not initialized") + + // And attempting to retrieve errors should still return an error since context wasn't initialized + apiErrors, err := GetGitHubAPIErrors(updatedCtx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context does not contain GitHubCtxErrors") + assert.Nil(t, apiErrors) + }) + + t.Run("NewGitHubAPIErrorToCtx with nil context does not error", func(t *testing.T) { + // Given a nil context + var ctx context.Context = nil + + // Create a mock GitHub response + resp := &github.Response{ + Response: &http.Response{ + StatusCode: 400, + Status: "400 Bad Request", + }, + } + originalErr := fmt.Errorf("bad request") + + // When we try to add an API error to a nil context + updatedCtx, err := NewGitHubAPIErrorToCtx(ctx, "failed with nil context", resp, originalErr) + + // Then it should not return an error (graceful handling) + assert.NoError(t, err, "NewGitHubAPIErrorToCtx should handle nil context gracefully") + assert.Nil(t, updatedCtx, "Context should remain nil when passed as nil") + }) +} + +func TestGitHubErrorTypes(t *testing.T) { + t.Run("GitHubAPIError implements error interface", func(t *testing.T) { + resp := &github.Response{Response: &http.Response{StatusCode: 404}} + originalErr := fmt.Errorf("not found") + + apiErr := newGitHubAPIError("test message", resp, originalErr) + + // Should implement error interface + var err error = apiErr + assert.Equal(t, "test message: not found", err.Error()) + }) + + t.Run("GitHubGraphQLError implements error interface", func(t *testing.T) { + originalErr := fmt.Errorf("query failed") + + gqlErr := newGitHubGraphQLError("test message", originalErr) + + // Should implement error interface + var err error = gqlErr + assert.Equal(t, "test message: query failed", err.Error()) + }) +} + +// TestMiddlewareScenario demonstrates a realistic middleware scenario +func TestMiddlewareScenario(t *testing.T) { + t.Run("realistic middleware error collection scenario", func(t *testing.T) { + // Simulate a realistic HTTP middleware scenario + + // 1. Request comes in, middleware sets up error tracking + ctx := ContextWithGitHubErrors(context.Background()) + + // 2. Middleware stores reference to context for later inspection + var middlewareCtx context.Context + setupMiddleware := func(ctx context.Context) context.Context { + middlewareCtx = ctx + return ctx + } + + // 3. Setup middleware + ctx = setupMiddleware(ctx) + + // 4. Simulate multiple service calls that add errors + simulateServiceCall1 := func(ctx context.Context) { + resp := &github.Response{Response: &http.Response{StatusCode: 403}} + _, err := NewGitHubAPIErrorToCtx(ctx, "insufficient permissions", resp, fmt.Errorf("forbidden")) + require.NoError(t, err) + } + + simulateServiceCall2 := func(ctx context.Context) { + resp := &github.Response{Response: &http.Response{StatusCode: 404}} + _, err := NewGitHubAPIErrorToCtx(ctx, "resource not found", resp, fmt.Errorf("not found")) + require.NoError(t, err) + } + + simulateGraphQLCall := func(ctx context.Context) { + gqlErr := newGitHubGraphQLError("mutation failed", fmt.Errorf("invalid input")) + _, err := addGitHubGraphQLErrorToContext(ctx, gqlErr) + require.NoError(t, err) + } + + // 5. Execute service calls (without context propagation) + simulateServiceCall1(ctx) + simulateServiceCall2(ctx) + simulateGraphQLCall(ctx) + + // 6. Middleware inspects errors at the end of request processing + finalizeMiddleware := func(ctx context.Context) ([]string, []string) { + var apiErrorMessages []string + var gqlErrorMessages []string + + if apiErrors, err := GetGitHubAPIErrors(ctx); err == nil { + for _, apiErr := range apiErrors { + apiErrorMessages = append(apiErrorMessages, apiErr.Message) + } + } + + if gqlErrors, err := GetGitHubGraphQLErrors(ctx); err == nil { + for _, gqlErr := range gqlErrors { + gqlErrorMessages = append(gqlErrorMessages, gqlErr.Message) + } + } + + return apiErrorMessages, gqlErrorMessages + } + + // 7. Middleware can see all errors that were added during request processing + apiMessages, gqlMessages := finalizeMiddleware(middlewareCtx) + + // Verify all errors were captured + assert.Len(t, apiMessages, 2) + assert.Contains(t, apiMessages, "insufficient permissions") + assert.Contains(t, apiMessages, "resource not found") + + assert.Len(t, gqlMessages, 1) + assert.Contains(t, gqlMessages, "mutation failed") + }) +} diff --git a/pkg/github/__toolsnaps__/create_or_update_file.snap b/pkg/github/__toolsnaps__/create_or_update_file.snap index 53f643df..dfbb3442 100644 --- a/pkg/github/__toolsnaps__/create_or_update_file.snap +++ b/pkg/github/__toolsnaps__/create_or_update_file.snap @@ -3,7 +3,7 @@ "title": "Create or update file", "readOnlyHint": false }, - "description": "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update.", + "description": "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update. Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations.", "inputSchema": { "properties": { "branch": { diff --git a/pkg/github/__toolsnaps__/list_commits.snap b/pkg/github/__toolsnaps__/list_commits.snap index 6603bdf5..1e769c71 100644 --- a/pkg/github/__toolsnaps__/list_commits.snap +++ b/pkg/github/__toolsnaps__/list_commits.snap @@ -6,6 +6,10 @@ "description": "Get list of commits of a branch in a GitHub repository", "inputSchema": { "properties": { + "author": { + "description": "Author username or email address", + "type": "string" + }, "owner": { "description": "Repository owner", "type": "string" @@ -28,10 +32,6 @@ "sha": { "description": "SHA or Branch name", "type": "string" - }, - "author": { - "description": "Author username or email address", - "type": "string" } }, "required": [ diff --git a/pkg/github/actions.go b/pkg/github/actions.go index 527a426e..cf33fb5a 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -644,7 +645,7 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo Filter: "latest", }) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to list workflow jobs: %v", err)), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow jobs", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -670,7 +671,7 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo // Collect logs for all failed jobs var logResults []map[string]any for _, job := range failedJobs { - jobResult, err := getJobLogData(ctx, client, owner, repo, job.GetID(), job.GetName(), returnContent) + jobResult, resp, err := getJobLogData(ctx, client, owner, repo, job.GetID(), job.GetName(), returnContent) if err != nil { // Continue with other jobs even if one fails jobResult = map[string]any{ @@ -678,7 +679,10 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo "job_name": job.GetName(), "error": err.Error(), } + // Enable reporting of status codes and error causes + _, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get job logs", resp, err) // Explicitly ignore error for graceful handling } + logResults = append(logResults, jobResult) } @@ -701,9 +705,9 @@ func handleFailedJobLogs(ctx context.Context, client *github.Client, owner, repo // handleSingleJobLogs gets logs for a single job func handleSingleJobLogs(ctx context.Context, client *github.Client, owner, repo string, jobID int64, returnContent bool) (*mcp.CallToolResult, error) { - jobResult, err := getJobLogData(ctx, client, owner, repo, jobID, "", returnContent) + jobResult, resp, err := getJobLogData(ctx, client, owner, repo, jobID, "", returnContent) if err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get job logs", resp, err), nil } r, err := json.Marshal(jobResult) @@ -715,11 +719,11 @@ func handleSingleJobLogs(ctx context.Context, client *github.Client, owner, repo } // getJobLogData retrieves log data for a single job, either as URL or content -func getJobLogData(ctx context.Context, client *github.Client, owner, repo string, jobID int64, jobName string, returnContent bool) (map[string]any, error) { +func getJobLogData(ctx context.Context, client *github.Client, owner, repo string, jobID int64, jobName string, returnContent bool) (map[string]any, *github.Response, error) { // Get the download URL for the job logs url, resp, err := client.Actions.GetWorkflowJobLogs(ctx, owner, repo, jobID, 1) if err != nil { - return nil, fmt.Errorf("failed to get job logs for job %d: %w", jobID, err) + return nil, resp, fmt.Errorf("failed to get job logs for job %d: %w", jobID, err) } defer func() { _ = resp.Body.Close() }() @@ -732,9 +736,13 @@ func getJobLogData(ctx context.Context, client *github.Client, owner, repo strin if returnContent { // Download and return the actual log content - content, err := downloadLogContent(url.String()) + content, httpResp, err := downloadLogContent(url.String()) //nolint:bodyclose // Response body is closed in downloadLogContent, but we need to return httpResp if err != nil { - return nil, fmt.Errorf("failed to download log content for job %d: %w", jobID, err) + // To keep the return value consistent wrap the response as a GitHub Response + ghRes := &github.Response{ + Response: httpResp, + } + return nil, ghRes, fmt.Errorf("failed to download log content for job %d: %w", jobID, err) } result["logs_content"] = content result["message"] = "Job logs content retrieved successfully" @@ -745,29 +753,29 @@ func getJobLogData(ctx context.Context, client *github.Client, owner, repo strin result["note"] = "The logs_url provides a download link for the individual job logs in plain text format. Use return_content=true to get the actual log content." } - return result, nil + return result, resp, nil } // downloadLogContent downloads the actual log content from a GitHub logs URL -func downloadLogContent(logURL string) (string, error) { +func downloadLogContent(logURL string) (string, *http.Response, error) { httpResp, err := http.Get(logURL) //nolint:gosec // URLs are provided by GitHub API and are safe if err != nil { - return "", fmt.Errorf("failed to download logs: %w", err) + return "", httpResp, fmt.Errorf("failed to download logs: %w", err) } defer func() { _ = httpResp.Body.Close() }() if httpResp.StatusCode != http.StatusOK { - return "", fmt.Errorf("failed to download logs: HTTP %d", httpResp.StatusCode) + return "", httpResp, fmt.Errorf("failed to download logs: HTTP %d", httpResp.StatusCode) } content, err := io.ReadAll(httpResp.Body) if err != nil { - return "", fmt.Errorf("failed to read log content: %w", err) + return "", httpResp, fmt.Errorf("failed to read log content: %w", err) } // Clean up and format the log content for better readability logContent := strings.TrimSpace(string(content)) - return logContent, nil + return logContent, httpResp, nil } // RerunWorkflowRun creates a tool to re-run an entire workflow run @@ -813,7 +821,7 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to rerun workflow run: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun workflow run", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -876,7 +884,7 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to rerun failed jobs: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun failed jobs", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -939,7 +947,7 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to cancel workflow run: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to cancel workflow run", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -1024,7 +1032,7 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH artifacts, resp, err := client.Actions.ListWorkflowRunArtifacts(ctx, owner, repo, runID, opts) if err != nil { - return nil, fmt.Errorf("failed to list workflow run artifacts: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow run artifacts", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -1081,7 +1089,7 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati // Get the download URL for the artifact url, resp, err := client.Actions.DownloadArtifact(ctx, owner, repo, artifactID, 1) if err != nil { - return nil, fmt.Errorf("failed to get artifact download URL: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get artifact download URL", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -1146,7 +1154,7 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to delete workflow run logs: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to delete workflow run logs", resp, err), nil } defer func() { _ = resp.Body.Close() }() @@ -1209,7 +1217,7 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, runID) if err != nil { - return nil, fmt.Errorf("failed to get workflow run usage: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run usage", resp, err), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 98714b6c..3b07692c 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -7,6 +7,7 @@ import ( "io" "net/http" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -54,7 +55,11 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { - return nil, fmt.Errorf("failed to get alert: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get alert", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -138,7 +143,11 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName}) if err != nil { - return nil, fmt.Errorf("failed to list alerts: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list alerts", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 5c0131a7..bd76ccba 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -94,12 +94,15 @@ func Test_GetCodeScanningAlert(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -217,12 +220,15 @@ func Test_ListCodeScanningAlerts(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 62a953de..bed2f4a3 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -3,6 +3,7 @@ package github import ( "context" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -28,9 +29,13 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too return mcp.NewToolResultErrorFromErr("failed to get GitHub client", err), nil } - user, _, err := client.Users.Get(ctx, "") + user, res, err := client.Users.Get(ctx, "") if err != nil { - return mcp.NewToolResultErrorFromErr("failed to get user", err), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get user", + res, + err, + ), nil } return MarshalledTextResult(user), nil diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 677ee99f..b6b6bfd7 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -118,7 +119,11 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu notifications, resp, err = client.Activity.ListNotifications(ctx, opts) } if err != nil { - return nil, fmt.Errorf("failed to get notifications: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list notifications", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -187,7 +192,11 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper } if err != nil { - return nil, fmt.Errorf("failed to mark notification as %s: %w", state, err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to mark notification as %s", state), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -262,7 +271,11 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH resp, err = client.Activity.MarkNotificationsRead(ctx, markReadOptions) } if err != nil { - return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to mark all notifications as read", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -304,7 +317,11 @@ func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHel thread, resp, err := client.Activity.GetThread(ctx, notificationID) if err != nil { - return nil, fmt.Errorf("failed to get notification details: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get notification details for ID '%s'", notificationID), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -385,7 +402,11 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl } if apiErr != nil { - return nil, fmt.Errorf("failed to %s notification subscription: %w", action, apiErr) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to %s notification subscription", action), + resp, + apiErr, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -474,7 +495,11 @@ func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translati } if apiErr != nil { - return nil, fmt.Errorf("failed to %s repository subscription: %w", action, apiErr) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to %s repository subscription", action), + resp, + apiErr, + ), nil } if resp != nil { defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index 77372f02..a83df3ed 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -127,14 +127,17 @@ func Test_ListNotifications(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { - assert.Contains(t, err.Error(), tc.expectedErrMsg) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) } return } require.NoError(t, err) + require.False(t, result.IsError) textContent := getTextResult(t, result) t.Logf("textContent: %s", textContent.Text) var returned []*github.Notification @@ -663,14 +666,17 @@ func Test_MarkAllNotificationsRead(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { - assert.Contains(t, err.Error(), tc.expectedErrMsg) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) } return } require.NoError(t, err) + require.False(t, result.IsError) textContent := getTextResult(t, result) if tc.expectMarked { assert.Contains(t, textContent.Text, "All notifications marked as read") @@ -738,14 +744,17 @@ func Test_GetNotificationDetails(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { - assert.Contains(t, err.Error(), tc.expectedErrMsg) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) } return } require.NoError(t, err) + require.False(t, result.IsError) textContent := getTextResult(t, result) var returned github.Notification err = json.Unmarshal([]byte(textContent.Text), &returned) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index b16920aa..7dcc2c4f 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -13,6 +13,7 @@ import ( "github.com/mark3labs/mcp-go/server" "github.com/shurcooL/githubv4" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" ) @@ -57,7 +58,11 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get pull request", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -172,7 +177,11 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) if err != nil { - return nil, fmt.Errorf("failed to create pull request: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create pull request", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -293,7 +302,11 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) if err != nil { - return nil, fmt.Errorf("failed to update pull request: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -402,7 +415,11 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun } prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list pull requests: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list pull requests", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -491,7 +508,11 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun } result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) if err != nil { - return nil, fmt.Errorf("failed to merge pull request: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to merge pull request", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -554,7 +575,11 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper opts := &github.ListOptions{} files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) if err != nil { - return nil, fmt.Errorf("failed to get pull request files: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get pull request files", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -616,7 +641,11 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get pull request", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -631,7 +660,11 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe // Get combined status for the head SHA status, resp, err := client.Repositories.GetCombinedStatus(ctx, owner, repo, *pr.Head.SHA, nil) if err != nil { - return nil, fmt.Errorf("failed to get combined status: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get combined status", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -709,7 +742,11 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { return mcp.NewToolResultText("Pull request branch update is in progress"), nil } - return nil, fmt.Errorf("failed to update pull request branch: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request branch", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -777,7 +814,11 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel } comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) if err != nil { - return nil, fmt.Errorf("failed to get pull request comments: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get pull request comments", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -839,7 +880,11 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { - return nil, fmt.Errorf("failed to get pull request reviews: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get pull request reviews", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -926,7 +971,10 @@ func CreateAndSubmitPullRequestReview(getGQLClient GetGQLClientFn, t translation "repo": githubv4.String(params.Repo), "prNum": githubv4.Int(params.PullNumber), }); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get pull request", + err, + ), nil } // Now we have the GQL ID, we can create a review @@ -1017,7 +1065,10 @@ func CreatePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. "repo": githubv4.String(params.Repo), "prNum": githubv4.Int(params.PullNumber), }); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get pull request", + err, + ), nil } // Now we have the GQL ID, we can create a pending review @@ -1135,7 +1186,10 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get current user", + err, + ), nil } var getLatestReviewForViewerQuery struct { @@ -1160,7 +1214,10 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get latest review for current user", + err, + ), nil } // Validate there is one review and the state is pending @@ -1266,7 +1323,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get current user", + err, + ), nil } var getLatestReviewForViewerQuery struct { @@ -1291,7 +1351,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get latest review for current user", + err, + ), nil } // Validate there is one review and the state is pending @@ -1324,7 +1387,10 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. }, nil, ); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to submit pull request review", + err, + ), nil } // Return nothing interesting, just indicate success for the time being. @@ -1381,7 +1447,10 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(ctx, &getViewerQuery, nil); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get current user", + err, + ), nil } var getLatestReviewForViewerQuery struct { @@ -1406,7 +1475,10 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, + "failed to get latest review for current user", + err, + ), nil } // Validate there is one review and the state is pending @@ -1490,7 +1562,11 @@ func GetPullRequestDiff(getClient GetClientFn, t translations.TranslationHelperF github.RawOptions{Type: github.Diff}, ) if err != nil { - return mcp.NewToolResultError(err.Error()), nil + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get pull request diff", + resp, + err, + ), nil } if resp.StatusCode != http.StatusOK { @@ -1563,7 +1639,11 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe }, ) if err != nil { - return nil, fmt.Errorf("failed to request copilot review: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request copilot review", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 144c6b38..02575c43 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -109,12 +109,15 @@ func Test_GetPullRequest(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -272,23 +275,22 @@ func Test_UpdatePullRequest(t *testing.T) { result, err := handler(context.Background(), request) // Verify results - if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + if tc.expectError || tc.expectedErrMsg != "" { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + if tc.expectedErrMsg != "" { + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + } return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content textContent := getTextResult(t, result) - // Check for expected error message within the result text - if tc.expectedErrMsg != "" { - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - return - } - // Unmarshal and verify the successful result var returnedPR github.PullRequest err = json.Unmarshal([]byte(textContent.Text), &returnedPR) @@ -420,12 +422,15 @@ func Test_ListPullRequests(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -536,12 +541,15 @@ func Test_MergePullRequest(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -649,12 +657,15 @@ func Test_GetPullRequestFiles(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -810,12 +821,15 @@ func Test_GetPullRequestStatus(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -938,12 +952,15 @@ func Test_UpdatePullRequestBranch(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1055,12 +1072,15 @@ func Test_GetPullRequestComments(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1179,12 +1199,15 @@ func Test_GetPullRequestReviews(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1653,12 +1676,15 @@ func Test_RequestCopilotReview(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) assert.Len(t, result.Content, 1) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index eafb71ac..fa5d7338 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" @@ -68,7 +69,11 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too } commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) if err != nil { - return nil, fmt.Errorf("failed to get commit: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get commit: %s", sha), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -150,7 +155,11 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t } commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list commits: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list commits: %s", sha), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -217,7 +226,11 @@ func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) ( branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list branches: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list branches", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -326,7 +339,11 @@ func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperF } fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) if err != nil { - return nil, fmt.Errorf("failed to create/update file: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create/update file", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -400,7 +417,11 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) if err != nil { - return nil, fmt.Errorf("failed to create repository: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create repository", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -641,7 +662,11 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { return mcp.NewToolResultText("Fork is in progress"), nil } - return nil, fmt.Errorf("failed to fork repository: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to fork repository", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -734,7 +759,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to // Get the commit object that the branch points to baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, fmt.Errorf("failed to get base commit: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get base commit", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -759,7 +788,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to // Create a new tree with the deletion newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, treeEntries) if err != nil { - return nil, fmt.Errorf("failed to create tree: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create tree", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -779,7 +812,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to } newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) if err != nil { - return nil, fmt.Errorf("failed to create commit: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create commit", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -795,7 +832,11 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to ref.Object.SHA = newCommit.SHA _, resp, err = client.Git.UpdateRef(ctx, owner, repo, ref, false) if err != nil { - return nil, fmt.Errorf("failed to update reference: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update reference", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -876,7 +917,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( // Get default branch if from_branch not specified repository, resp, err := client.Repositories.Get(ctx, owner, repo) if err != nil { - return nil, fmt.Errorf("failed to get repository: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get repository", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -886,7 +931,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( // Get SHA of source branch ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) if err != nil { - return nil, fmt.Errorf("failed to get reference: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get reference", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -898,7 +947,11 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) if err != nil { - return nil, fmt.Errorf("failed to create branch: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create branch", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -988,14 +1041,22 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too // Get the reference for the branch ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) if err != nil { - return nil, fmt.Errorf("failed to get branch reference: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get branch reference", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() // Get the commit object that the branch points to baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, fmt.Errorf("failed to get base commit: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get base commit", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1030,7 +1091,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too // Create a new tree with the file entries newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) if err != nil { - return nil, fmt.Errorf("failed to create tree: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create tree", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1042,7 +1107,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too } newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) if err != nil { - return nil, fmt.Errorf("failed to create commit: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create commit", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1050,7 +1119,11 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too ref.Object.SHA = newCommit.SHA updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, ref, false) if err != nil { - return nil, fmt.Errorf("failed to update reference: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update reference", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1107,7 +1180,11 @@ func ListTags(getClient GetClientFn, t translations.TranslationHelperFunc) (tool tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts) if err != nil { - return nil, fmt.Errorf("failed to list tags: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list tags", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1171,7 +1248,11 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m // First get the tag reference ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/tags/"+tag) if err != nil { - return nil, fmt.Errorf("failed to get tag reference: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get tag reference", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -1186,7 +1267,11 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m // Then get the tag object tagObj, resp, err := client.Git.GetTag(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, fmt.Errorf("failed to get tag object: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get tag object", + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 7ce2fec1..b621cec4 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -302,12 +302,15 @@ func Test_ForkRepository(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -489,12 +492,15 @@ func Test_CreateBranch(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -612,12 +618,15 @@ func Test_GetCommit(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -788,12 +797,15 @@ func Test_ListCommits(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -951,12 +963,15 @@ func Test_CreateOrUpdateFile(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1100,12 +1115,15 @@ func Test_CreateRepository(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1434,19 +1452,23 @@ func Test_PushFiles(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } if tc.expectedErrMsg != "" { require.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1847,12 +1869,15 @@ func Test_ListTags(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -1998,12 +2023,15 @@ func Test_GetTag(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) diff --git a/pkg/github/search.go b/pkg/github/search.go index 157675c1..13d01712 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -49,7 +50,11 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } result, resp, err := client.Search.Repositories(ctx, query, opts) if err != nil { - return nil, fmt.Errorf("failed to search repositories: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search repositories with query '%s'", query), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -125,7 +130,11 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (to result, resp, err := client.Search.Code(ctx, query, opts) if err != nil { - return nil, fmt.Errorf("failed to search code: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search code with query '%s'", query), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -215,7 +224,11 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (t result, resp, err := client.Search.Users(ctx, "type:user "+query, opts) if err != nil { - return nil, fmt.Errorf("failed to search users: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search users with query '%s'", query), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index b76fe804..f206ebb4 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -134,12 +134,15 @@ func Test_SearchRepositories(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -286,12 +289,15 @@ func Test_SearchCode(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -437,12 +443,15 @@ func Test_SearchUsers(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error require.NotNil(t, result) diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index ec0eb15a..bea6df2a 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -7,6 +7,7 @@ import ( "io" "net/http" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v72/github" "github.com/mark3labs/mcp-go/mcp" @@ -55,7 +56,11 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel alert, resp, err := client.SecretScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { - return nil, fmt.Errorf("failed to get alert: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get alert with number '%d'", alertNumber), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() @@ -132,7 +137,11 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH } alerts, resp, err := client.SecretScanning.ListAlertsForRepo(ctx, owner, repo, &github.SecretScanningAlertListOptions{State: state, SecretType: secretType, Resolution: resolution}) if err != nil { - return nil, fmt.Errorf("failed to list alerts: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), + resp, + err, + ), nil } defer func() { _ = resp.Body.Close() }() diff --git a/pkg/github/secret_scanning_test.go b/pkg/github/secret_scanning_test.go index 4ec5539e..38b573e0 100644 --- a/pkg/github/secret_scanning_test.go +++ b/pkg/github/secret_scanning_test.go @@ -90,12 +90,15 @@ func Test_GetSecretScanningAlert(t *testing.T) { // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -217,12 +220,15 @@ func Test_ListSecretScanningAlerts(t *testing.T) { result, err := handler(context.Background(), request) if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } require.NoError(t, err) + require.False(t, result.IsError) textContent := getTextResult(t, result)