package github import ( "bytes" "context" "encoding/json" "io" "net/http" "net/url" "strings" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" testifymock "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) // GitHub API endpoint patterns for testing // These constants define the URL patterns used in HTTP mocking for tests const ( // User endpoints GetUser = "GET /user" GetUserStarred = "GET /user/starred" GetUsersGistsByUsername = "GET /users/{username}/gists" GetUsersStarredByUsername = "GET /users/{username}/starred" PutUserStarredByOwnerByRepo = "PUT /user/starred/{owner}/{repo}" DeleteUserStarredByOwnerByRepo = "DELETE /user/starred/{owner}/{repo}" // Repository endpoints GetReposByOwnerByRepo = "GET /repos/{owner}/{repo}" GetReposBranchesByOwnerByRepo = "GET /repos/{owner}/{repo}/branches" GetReposTagsByOwnerByRepo = "GET /repos/{owner}/{repo}/tags" GetReposCommitsByOwnerByRepo = "GET /repos/{owner}/{repo}/commits" GetReposCommitsByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/commits/{ref}" GetReposContentsByOwnerByRepoByPath = "GET /repos/{owner}/{repo}/contents/{path}" PutReposContentsByOwnerByRepoByPath = "PUT /repos/{owner}/{repo}/contents/{path}" PostReposForksByOwnerByRepo = "POST /repos/{owner}/{repo}/forks" GetReposSubscriptionByOwnerByRepo = "GET /repos/{owner}/{repo}/subscription" PutReposSubscriptionByOwnerByRepo = "PUT /repos/{owner}/{repo}/subscription" DeleteReposSubscriptionByOwnerByRepo = "DELETE /repos/{owner}/{repo}/subscription" // Git endpoints GetReposGitTreesByOwnerByRepoByTree = "GET /repos/{owner}/{repo}/git/trees/{tree}" GetReposGitRefByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/git/ref/{ref:.*}" PostReposGitRefsByOwnerByRepo = "POST /repos/{owner}/{repo}/git/refs" PatchReposGitRefsByOwnerByRepoByRef = "PATCH /repos/{owner}/{repo}/git/refs/{ref:.*}" GetReposGitCommitsByOwnerByRepoByCommitSHA = "GET /repos/{owner}/{repo}/git/commits/{commit_sha}" PostReposGitCommitsByOwnerByRepo = "POST /repos/{owner}/{repo}/git/commits" GetReposGitTagsByOwnerByRepoByTagSHA = "GET /repos/{owner}/{repo}/git/tags/{tag_sha}" PostReposGitTreesByOwnerByRepo = "POST /repos/{owner}/{repo}/git/trees" GetReposCommitsStatusByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/commits/{ref}/status" GetReposCommitsStatusesByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/commits/{ref}/statuses" GetReposCommitsCheckRunsByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/commits/{ref}/check-runs" // Issues endpoints GetReposIssuesByOwnerByRepoByIssueNumber = "GET /repos/{owner}/{repo}/issues/{issue_number}" GetReposIssuesCommentsByOwnerByRepoByIssueNumber = "GET /repos/{owner}/{repo}/issues/{issue_number}/comments" PostReposIssuesByOwnerByRepo = "POST /repos/{owner}/{repo}/issues" PostReposIssuesCommentsByOwnerByRepoByIssueNumber = "POST /repos/{owner}/{repo}/issues/{issue_number}/comments" PatchReposIssuesByOwnerByRepoByIssueNumber = "PATCH /repos/{owner}/{repo}/issues/{issue_number}" GetReposIssuesSubIssuesByOwnerByRepoByIssueNumber = "GET /repos/{owner}/{repo}/issues/{issue_number}/sub_issues" PostReposIssuesSubIssuesByOwnerByRepoByIssueNumber = "POST /repos/{owner}/{repo}/issues/{issue_number}/sub_issues" DeleteReposIssuesSubIssueByOwnerByRepoByIssueNumber = "DELETE /repos/{owner}/{repo}/issues/{issue_number}/sub_issue" PatchReposIssuesSubIssuesPriorityByOwnerByRepoByIssueNumber = "PATCH /repos/{owner}/{repo}/issues/{issue_number}/sub_issues/priority" // Pull request endpoints GetReposPullsByOwnerByRepo = "GET /repos/{owner}/{repo}/pulls" GetReposPullsByOwnerByRepoByPullNumber = "GET /repos/{owner}/{repo}/pulls/{pull_number}" GetReposPullsFilesByOwnerByRepoByPullNumber = "GET /repos/{owner}/{repo}/pulls/{pull_number}/files" GetReposPullsReviewsByOwnerByRepoByPullNumber = "GET /repos/{owner}/{repo}/pulls/{pull_number}/reviews" PostReposPullsByOwnerByRepo = "POST /repos/{owner}/{repo}/pulls" PatchReposPullsByOwnerByRepoByPullNumber = "PATCH /repos/{owner}/{repo}/pulls/{pull_number}" PutReposPullsMergeByOwnerByRepoByPullNumber = "PUT /repos/{owner}/{repo}/pulls/{pull_number}/merge" PutReposPullsUpdateBranchByOwnerByRepoByPullNumber = "PUT /repos/{owner}/{repo}/pulls/{pull_number}/update-branch" PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber = "POST /repos/{owner}/{repo}/pulls/{pull_number}/requested_reviewers" PostReposPullsCommentsByOwnerByRepoByPullNumber = "POST /repos/{owner}/{repo}/pulls/{pull_number}/comments" // Notifications endpoints GetNotifications = "GET /notifications" PutNotifications = "PUT /notifications" GetReposNotificationsByOwnerByRepo = "GET /repos/{owner}/{repo}/notifications" PutReposNotificationsByOwnerByRepo = "PUT /repos/{owner}/{repo}/notifications" GetNotificationsThreadsByThreadID = "GET /notifications/threads/{thread_id}" PatchNotificationsThreadsByThreadID = "PATCH /notifications/threads/{thread_id}" DeleteNotificationsThreadsByThreadID = "DELETE /notifications/threads/{thread_id}" PutNotificationsThreadsSubscriptionByThreadID = "PUT /notifications/threads/{thread_id}/subscription" DeleteNotificationsThreadsSubscriptionByThreadID = "DELETE /notifications/threads/{thread_id}/subscription" // Gists endpoints GetGists = "GET /gists" GetGistsByGistID = "GET /gists/{gist_id}" PostGists = "POST /gists" PatchGistsByGistID = "PATCH /gists/{gist_id}" // Releases endpoints GetReposReleasesByOwnerByRepo = "GET /repos/{owner}/{repo}/releases" GetReposReleasesLatestByOwnerByRepo = "GET /repos/{owner}/{repo}/releases/latest" GetReposReleasesTagsByOwnerByRepoByTag = "GET /repos/{owner}/{repo}/releases/tags/{tag}" // Code scanning endpoints GetReposCodeScanningAlertsByOwnerByRepo = "GET /repos/{owner}/{repo}/code-scanning/alerts" GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber = "GET /repos/{owner}/{repo}/code-scanning/alerts/{alert_number}" // Secret scanning endpoints GetReposSecretScanningAlertsByOwnerByRepo = "GET /repos/{owner}/{repo}/secret-scanning/alerts" //nolint:gosec // False positive - this is an API endpoint pattern, not a credential GetReposSecretScanningAlertsByOwnerByRepoByAlertNumber = "GET /repos/{owner}/{repo}/secret-scanning/alerts/{alert_number}" //nolint:gosec // False positive - this is an API endpoint pattern, not a credential // Dependabot endpoints GetReposDependabotAlertsByOwnerByRepo = "GET /repos/{owner}/{repo}/dependabot/alerts" GetReposDependabotAlertsByOwnerByRepoByAlertNumber = "GET /repos/{owner}/{repo}/dependabot/alerts/{alert_number}" // Security advisories endpoints GetAdvisories = "GET /advisories" GetAdvisoriesByGhsaID = "GET /advisories/{ghsa_id}" GetReposSecurityAdvisoriesByOwnerByRepo = "GET /repos/{owner}/{repo}/security-advisories" GetOrgsSecurityAdvisoriesByOrg = "GET /orgs/{org}/security-advisories" // Actions endpoints GetReposActionsWorkflowsByOwnerByRepo = "GET /repos/{owner}/{repo}/actions/workflows" GetReposActionsWorkflowsByOwnerByRepoByWorkflowID = "GET /repos/{owner}/{repo}/actions/workflows/{workflow_id}" PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowID = "POST /repos/{owner}/{repo}/actions/workflows/{workflow_id}/dispatches" GetReposActionsWorkflowsRunsByOwnerByRepoByWorkflowID = "GET /repos/{owner}/{repo}/actions/workflows/{workflow_id}/runs" GetReposActionsRunsByOwnerByRepo = "GET /repos/{owner}/{repo}/actions/runs" GetReposActionsRunsByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}" GetReposActionsRunsLogsByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}/logs" GetReposActionsRunsJobsByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}/jobs" GetReposActionsRunsArtifactsByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}/artifacts" GetReposActionsRunsTimingByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}/timing" PostReposActionsRunsRerunByOwnerByRepoByRunID = "POST /repos/{owner}/{repo}/actions/runs/{run_id}/rerun" PostReposActionsRunsRerunFailedJobsByOwnerByRepoByRunID = "POST /repos/{owner}/{repo}/actions/runs/{run_id}/rerun-failed-jobs" PostReposActionsRunsCancelByOwnerByRepoByRunID = "POST /repos/{owner}/{repo}/actions/runs/{run_id}/cancel" GetReposActionsJobsLogsByOwnerByRepoByJobID = "GET /repos/{owner}/{repo}/actions/jobs/{job_id}/logs" DeleteReposActionsRunsLogsByOwnerByRepoByRunID = "DELETE /repos/{owner}/{repo}/actions/runs/{run_id}/logs" // Search endpoints GetSearchCode = "GET /search/code" GetSearchIssues = "GET /search/issues" GetSearchUsers = "GET /search/users" GetSearchRepositories = "GET /search/repositories" // Raw content endpoints (used for GitHub raw content API, not standard API) // These are used with the raw content client that interacts with raw.githubusercontent.com GetRawReposContentsByOwnerByRepoByPath = "GET /{owner}/{repo}/HEAD/{path:.*}" GetRawReposContentsByOwnerByRepoByBranchByPath = "GET /{owner}/{repo}/refs/heads/{branch}/{path:.*}" GetRawReposContentsByOwnerByRepoByTagByPath = "GET /{owner}/{repo}/refs/tags/{tag}/{path:.*}" GetRawReposContentsByOwnerByRepoBySHAByPath = "GET /{owner}/{repo}/{sha}/{path:.*}" // Projects (ProjectsV2) endpoints // Organization-scoped GetOrgsProjectsV2 = "GET /orgs/{org}/projectsV2" GetOrgsProjectsV2ByProject = "GET /orgs/{org}/projectsV2/{project}" GetOrgsProjectsV2FieldsByProject = "GET /orgs/{org}/projectsV2/{project}/fields" GetOrgsProjectsV2FieldsByProjectByFieldID = "GET /orgs/{org}/projectsV2/{project}/fields/{field_id}" GetOrgsProjectsV2ItemsByProject = "GET /orgs/{org}/projectsV2/{project}/items" GetOrgsProjectsV2ItemsByProjectByItemID = "GET /orgs/{org}/projectsV2/{project}/items/{item_id}" PostOrgsProjectsV2ItemsByProject = "POST /orgs/{org}/projectsV2/{project}/items" PatchOrgsProjectsV2ItemsByProjectByItemID = "PATCH /orgs/{org}/projectsV2/{project}/items/{item_id}" DeleteOrgsProjectsV2ItemsByProjectByItemID = "DELETE /orgs/{org}/projectsV2/{project}/items/{item_id}" // User-scoped GetUsersProjectsV2ByUsername = "GET /users/{username}/projectsV2" GetUsersProjectsV2ByUsernameByProject = "GET /users/{username}/projectsV2/{project}" GetUsersProjectsV2FieldsByUsernameByProject = "GET /users/{username}/projectsV2/{project}/fields" GetUsersProjectsV2FieldsByUsernameByProjectByFieldID = "GET /users/{username}/projectsV2/{project}/fields/{field_id}" GetUsersProjectsV2ItemsByUsernameByProject = "GET /users/{username}/projectsV2/{project}/items" GetUsersProjectsV2ItemsByUsernameByProjectByItemID = "GET /users/{username}/projectsV2/{project}/items/{item_id}" PostUsersProjectsV2ItemsByUsernameByProject = "POST /users/{username}/projectsV2/{project}/items" PatchUsersProjectsV2ItemsByUsernameByProjectByItemID = "PATCH /users/{username}/projectsV2/{project}/items/{item_id}" DeleteUsersProjectsV2ItemsByUsernameByProjectByItemID = "DELETE /users/{username}/projectsV2/{project}/items/{item_id}" // Organization issue types endpoints GetOrgsIssueTypesByOrg = "GET /orgs/{org}/issue-types" ) type expectations struct { path string queryParams map[string]string requestBody any } // expect is a helper function to create a partial mock that expects various // request behaviors, such as path, query parameters, and request body. func expect(t *testing.T, e expectations) *partialMock { return &partialMock{ t: t, expectedPath: e.path, expectedQueryParams: e.queryParams, expectedRequestBody: e.requestBody, } } // expectPath is a helper function to create a partial mock that expects a // request with the given path, with the ability to chain a response handler. func expectPath(t *testing.T, expectedPath string) *partialMock { return &partialMock{ t: t, expectedPath: expectedPath, } } // expectQueryParams is a helper function to create a partial mock that expects a // request with the given query parameters, with the ability to chain a response handler. func expectQueryParams(t *testing.T, expectedQueryParams map[string]string) *partialMock { return &partialMock{ t: t, expectedQueryParams: expectedQueryParams, } } // expectRequestBody is a helper function to create a partial mock that expects a // request with the given body, with the ability to chain a response handler. func expectRequestBody(t *testing.T, expectedRequestBody any) *partialMock { return &partialMock{ t: t, expectedRequestBody: expectedRequestBody, } } type partialMock struct { t *testing.T expectedPath string expectedQueryParams map[string]string expectedRequestBody any } func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc { p.t.Helper() return func(w http.ResponseWriter, r *http.Request) { if p.expectedPath != "" { require.Equal(p.t, p.expectedPath, r.URL.Path) } if p.expectedQueryParams != nil { require.Equal(p.t, len(p.expectedQueryParams), len(r.URL.Query())) for k, v := range p.expectedQueryParams { require.Equal(p.t, v, r.URL.Query().Get(k)) } } if p.expectedRequestBody != nil { var unmarshaledRequestBody any err := json.NewDecoder(r.Body).Decode(&unmarshaledRequestBody) require.NoError(p.t, err) require.Equal(p.t, p.expectedRequestBody, unmarshaledRequestBody) } responseHandler(w, r) } } // mockResponse is a helper function to create a mock HTTP response handler // that returns a specified status code and marshaled body. func mockResponse(t *testing.T, code int, body any) http.HandlerFunc { t.Helper() return func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(code) // Some tests do not expect to return a JSON object, such as fetching a raw pull request diff, // so allow strings to be returned directly. s, ok := body.(string) if ok { _, _ = w.Write([]byte(s)) return } b, err := json.Marshal(body) require.NoError(t, err) _, _ = w.Write(b) } } // createMCPRequest is a helper function to create a MCP request with the given arguments. func createMCPRequest(args any) mcp.CallToolRequest { // convert args to map[string]interface{} and serialize to JSON argsMap, ok := args.(map[string]any) if !ok { argsMap = make(map[string]any) } argsJSON, err := json.Marshal(argsMap) if err != nil { return mcp.CallToolRequest{} } jsonRawMessage := json.RawMessage(argsJSON) return mcp.CallToolRequest{ Params: &mcp.CallToolParamsRaw{ Arguments: jsonRawMessage, }, } } // Well-known MCP client names used in tests. const ( ClientNameVSCodeInsiders = "Visual Studio Code - Insiders" ClientNameVSCode = "Visual Studio Code" ) // createMCPRequestWithSession creates a CallToolRequest with a ServerSession // that has the given client name in its InitializeParams. When withUI is true // the session advertises MCP Apps UI support via the capability extension. func createMCPRequestWithSession(t *testing.T, clientName string, withUI bool, args any) mcp.CallToolRequest { t.Helper() argsMap, ok := args.(map[string]any) if !ok { argsMap = make(map[string]any) } argsJSON, err := json.Marshal(argsMap) require.NoError(t, err) srv := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) caps := &mcp.ClientCapabilities{} if withUI { caps.AddExtension("io.modelcontextprotocol/ui", map[string]any{ "mimeTypes": []string{"text/html;profile=mcp-app"}, }) } st, _ := mcp.NewInMemoryTransports() session, err := srv.Connect(context.Background(), st, &mcp.ServerSessionOptions{ State: &mcp.ServerSessionState{ InitializeParams: &mcp.InitializeParams{ ClientInfo: &mcp.Implementation{Name: clientName}, Capabilities: caps, }, }, }) require.NoError(t, err) // Close the unused client-side transport and session t.Cleanup(func() { _ = session.Close() }) return mcp.CallToolRequest{ Session: session, Params: &mcp.CallToolParamsRaw{ Arguments: json.RawMessage(argsJSON), }, } } // getTextResult is a helper function that returns a text result from a tool call. func getTextResult(t *testing.T, result *mcp.CallToolResult) *mcp.TextContent { t.Helper() assert.NotNil(t, result) require.Len(t, result.Content, 1) textContent, ok := result.Content[0].(*mcp.TextContent) require.True(t, ok, "expected content to be of type TextContent") return textContent } func getErrorResult(t *testing.T, result *mcp.CallToolResult) *mcp.TextContent { res := getTextResult(t, result) require.True(t, result.IsError, "expected tool call result to be an error") return res } // getTextResourceResult is a helper function that returns a text result from a tool call. // getBlobResourceResult is a helper function that returns a blob result from a tool call. func TestOptionalParamOK(t *testing.T) { tests := []struct { name string args map[string]any paramName string expectedVal any expectedOk bool expectError bool errorMsg string }{ { name: "present and correct type (string)", args: map[string]any{"myParam": "hello"}, paramName: "myParam", expectedVal: "hello", expectedOk: true, expectError: false, }, { name: "present and correct type (bool)", args: map[string]any{"myParam": true}, paramName: "myParam", expectedVal: true, expectedOk: true, expectError: false, }, { name: "present and correct type (number)", args: map[string]any{"myParam": float64(123)}, paramName: "myParam", expectedVal: float64(123), expectedOk: true, expectError: false, }, { name: "present but wrong type (string expected, got bool)", args: map[string]any{"myParam": true}, paramName: "myParam", expectedVal: "", // Zero value for string expectedOk: true, // ok is true because param exists expectError: true, errorMsg: "parameter myParam is not of type string, is bool", }, { name: "present but wrong type (bool expected, got string)", args: map[string]any{"myParam": "true"}, paramName: "myParam", expectedVal: false, // Zero value for bool expectedOk: true, // ok is true because param exists expectError: true, errorMsg: "parameter myParam is not of type bool, is string", }, { name: "parameter not present", args: map[string]any{"anotherParam": "value"}, paramName: "myParam", expectedVal: "", // Zero value for string expectedOk: false, expectError: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Test with string type assertion if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" { val, ok, err := OptionalParamOK[string](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) assert.Equal(t, tc.expectedOk, ok) // Check ok even on error assert.Equal(t, tc.expectedVal, val) // Check zero value on error } else { require.NoError(t, err) assert.Equal(t, tc.expectedOk, ok) assert.Equal(t, tc.expectedVal, val) } } // Test with bool type assertion if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" { val, ok, err := OptionalParamOK[bool](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) assert.Equal(t, tc.expectedOk, ok) // Check ok even on error assert.Equal(t, tc.expectedVal, val) // Check zero value on error } else { require.NoError(t, err) assert.Equal(t, tc.expectedOk, ok) assert.Equal(t, tc.expectedVal, val) } } // Test with float64 type assertion (for number case) if _, isFloat := tc.expectedVal.(float64); isFloat { val, ok, err := OptionalParamOK[float64](tc.args, tc.paramName) if tc.expectError { // This case shouldn't happen for float64 in the defined tests require.Fail(t, "Unexpected error case for float64") } else { require.NoError(t, err) assert.Equal(t, tc.expectedOk, ok) assert.Equal(t, tc.expectedVal, val) } } }) } } func getResourceResult(t *testing.T, result *mcp.CallToolResult) *mcp.ResourceContents { t.Helper() assert.NotNil(t, result) require.Len(t, result.Content, 2) content := result.Content[1] require.IsType(t, &mcp.EmbeddedResource{}, content) resource, ok := content.(*mcp.EmbeddedResource) require.True(t, ok, "expected content to be of type EmbeddedResource") require.IsType(t, &mcp.ResourceContents{}, resource.Resource) return resource.Resource } // MockRoundTripper is a mock HTTP transport using testify/mock type MockRoundTripper struct { testifymock.Mock handlers map[string]http.HandlerFunc } // NewMockRoundTripper creates a new mock round tripper func NewMockRoundTripper() *MockRoundTripper { return &MockRoundTripper{ handlers: make(map[string]http.HandlerFunc), } } // RoundTrip implements the http.RoundTripper interface func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // Normalize the request path and method for matching key := req.Method + " " + req.URL.Path // Check if we have a specific handler for this request if handler, ok := m.handlers[key]; ok { // Use httptest.ResponseRecorder to capture the handler's response recorder := &responseRecorder{ header: make(http.Header), body: &bytes.Buffer{}, } handler(recorder, req) return &http.Response{ StatusCode: recorder.statusCode, Header: recorder.header, Body: io.NopCloser(bytes.NewReader(recorder.body.Bytes())), Request: req, }, nil } // Fall back to mock.Mock assertions if defined args := m.Called(req) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*http.Response), args.Error(1) } // On registers an expectation using testify/mock func (m *MockRoundTripper) OnRequest(method, path string, handler http.HandlerFunc) *MockRoundTripper { key := method + " " + path m.handlers[key] = handler return m } // NewMockHTTPClient creates an HTTP client with a mock transport func NewMockHTTPClient() (*http.Client, *MockRoundTripper) { transport := NewMockRoundTripper() client := &http.Client{Transport: transport} return client, transport } // responseRecorder is a simple response recorder for the mock transport type responseRecorder struct { statusCode int header http.Header body *bytes.Buffer } func (r *responseRecorder) Header() http.Header { return r.header } func (r *responseRecorder) Write(data []byte) (int, error) { if r.statusCode == 0 { r.statusCode = http.StatusOK } return r.body.Write(data) } func (r *responseRecorder) WriteHeader(statusCode int) { r.statusCode = statusCode } // matchPath checks if a request path matches a pattern (supports simple wildcards) func matchPath(pattern, path string) bool { // Simple exact match for now if pattern == path { return true } // Support for path parameters like /repos/{owner}/{repo}/issues/{issue_number} patternParts := strings.Split(strings.Trim(pattern, "/"), "/") pathParts := strings.Split(strings.Trim(path, "/"), "/") // Handle patterns with wildcard path like {path:.*} if len(patternParts) > 0 { lastPart := patternParts[len(patternParts)-1] if strings.HasPrefix(lastPart, "{") && strings.Contains(lastPart, ":") && strings.HasSuffix(lastPart, "}") { // This is a wildcard pattern like {path:.*} // Check if all parts before the wildcard match if len(pathParts) < len(patternParts)-1 { return false } for i := range len(patternParts) - 1 { if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { continue // Path parameter matches anything } if patternParts[i] != pathParts[i] { return false } } return true } } if len(patternParts) != len(pathParts) { return false } for i := range patternParts { // Check if this is a path parameter (enclosed in {}) if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { continue // Path parameters match anything } if patternParts[i] != pathParts[i] { return false } } return true } // executeHandler executes an HTTP handler and returns the response func executeHandler(handler http.HandlerFunc, req *http.Request) *http.Response { recorder := &responseRecorder{ header: make(http.Header), body: &bytes.Buffer{}, } handler(recorder, req) return &http.Response{ StatusCode: recorder.statusCode, Header: recorder.header, Body: io.NopCloser(bytes.NewReader(recorder.body.Bytes())), Request: req, } } // MockHTTPClientWithHandler creates an HTTP client with a single handler function func MockHTTPClientWithHandler(handler http.HandlerFunc) *http.Client { handlers := map[string]http.HandlerFunc{ "": handler, // Empty key acts as catch-all } return MockHTTPClientWithHandlers(handlers) } // MockHTTPClientWithHandlers creates an HTTP client with multiple handlers for different paths func MockHTTPClientWithHandlers(handlers map[string]http.HandlerFunc) *http.Client { transport := &multiHandlerTransport{handlers: handlers} return &http.Client{Transport: transport} } // Compatibility helpers to replace github.com/migueleliasweb/go-github-mock in tests type EndpointPattern string type MockBackendOption func(map[string]http.HandlerFunc) func parseEndpointPattern(p EndpointPattern) (string, string) { parts := strings.SplitN(string(p), " ", 2) if len(parts) != 2 { return http.MethodGet, string(p) } return parts[0], parts[1] } func WithRequestMatch(pattern EndpointPattern, response any) MockBackendOption { return func(handlers map[string]http.HandlerFunc) { method, path := parseEndpointPattern(pattern) handlers[method+" "+path] = func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) switch v := response.(type) { case string: _, _ = w.Write([]byte(v)) case []byte: _, _ = w.Write(v) default: data, err := json.Marshal(v) if err == nil { _, _ = w.Write(data) } } } } } func WithRequestMatchHandler(pattern EndpointPattern, handler http.HandlerFunc) MockBackendOption { return func(handlers map[string]http.HandlerFunc) { method, path := parseEndpointPattern(pattern) handlers[method+" "+path] = handler } } func NewMockedHTTPClient(options ...MockBackendOption) *http.Client { handlers := map[string]http.HandlerFunc{} for _, opt := range options { if opt != nil { opt(handlers) } } return MockHTTPClientWithHandlers(handlers) } func MustMarshal(v any) []byte { data, err := json.Marshal(v) if err != nil { panic(err) } return data } type multiHandlerTransport struct { handlers map[string]http.HandlerFunc } func (m *multiHandlerTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Check for catch-all handler if handler, ok := m.handlers[""]; ok { return executeHandler(handler, req), nil } // Try to find a handler for this request key := req.Method + " " + req.URL.Path // First try exact match if handler, ok := m.handlers[key]; ok { return executeHandler(handler, req), nil } // Then try pattern matching, prioritizing patterns without wildcards // This is important because wildcard patterns like /{owner}/{repo}/{sha}/{path:.*} // can incorrectly match API paths like /repos/owner/repo/pulls/42 var wildcardPattern string var wildcardHandler http.HandlerFunc for pattern, handler := range m.handlers { if pattern == "" { continue // Skip catch-all } parts := strings.SplitN(pattern, " ", 2) if len(parts) != 2 { continue } method, pathPattern := parts[0], parts[1] if req.Method != method { continue } // Check if this pattern contains a wildcard like {path:.*} isWildcard := strings.Contains(pathPattern, ":.*}") if matchPath(pathPattern, req.URL.Path) { if isWildcard { // Save wildcard match for later, prefer non-wildcard patterns wildcardPattern = pattern wildcardHandler = handler } else { // Non-wildcard pattern takes priority return executeHandler(handler, req), nil } } } // If we found a wildcard match but no specific match, use it if wildcardPattern != "" && wildcardHandler != nil { return executeHandler(wildcardHandler, req), nil } // No handler found return &http.Response{ StatusCode: http.StatusNotFound, Body: io.NopCloser(bytes.NewReader([]byte("not found"))), Request: req, }, nil } // extractPathParams extracts path parameters from a URL path given a pattern func extractPathParams(pattern, path string) map[string]string { params := make(map[string]string) patternParts := strings.Split(strings.Trim(pattern, "/"), "/") pathParts := strings.Split(strings.Trim(path, "/"), "/") if len(patternParts) != len(pathParts) { return params } for i := range patternParts { if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { paramName := strings.Trim(patternParts[i], "{}") params[paramName] = pathParts[i] } } return params } // ParseRequestPath is a helper to extract path parameters func ParseRequestPath(t *testing.T, req *http.Request, pattern string) url.Values { t.Helper() params := extractPathParams(pattern, req.URL.Path) values := url.Values{} for k, v := range params { values.Set(k, v) } return values }