From fc99b14fef6870071bd4db18bc429a8e7db5b727 Mon Sep 17 00:00:00 2001 From: Javier Uruen Val Date: Tue, 1 Apr 2025 16:40:16 +0200 Subject: [PATCH] add support for create_pull_request --- README.md | 11 +++ pkg/github/pullrequests.go | 107 +++++++++++++++++++++++ pkg/github/pullrequests_test.go | 149 ++++++++++++++++++++++++++++++++ pkg/github/server.go | 1 + 4 files changed, 268 insertions(+) diff --git a/README.md b/README.md index 1d7eba365..39a380377 100644 --- a/README.md +++ b/README.md @@ -255,6 +255,17 @@ The flag `--gh-host` and the environment variable `GH_HOST` can be used to set t - `commit_id`: SHA of commit to review (string, optional) - `comments`: Line-specific comments array of objects, each object with path (string), position (number), and body (string) (array, optional) +- **create_pull_request** - Create a new pull request + + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `title`: PR title (string, required) + - `body`: PR description (string, optional) + - `head`: Branch containing changes (string, required) + - `base`: Branch to merge into (string, required) + - `draft`: Create as draft PR (boolean, optional) + - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + ### Repositories - **create_or_update_file** - Create or update a single file in a repository diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index fc77caded..ddec1e6ef 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -712,3 +712,110 @@ func createPullRequestReview(client *github.Client, t translations.TranslationHe return mcp.NewToolResultText(string(r)), nil } } + +// createPullRequest creates a tool to create a new pull request. +func createPullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("create_pull_request", + mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository")), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("title", + mcp.Required(), + mcp.Description("PR title"), + ), + mcp.WithString("body", + mcp.Description("PR description"), + ), + mcp.WithString("head", + mcp.Required(), + mcp.Description("Branch containing changes"), + ), + mcp.WithString("base", + mcp.Required(), + mcp.Description("Branch to merge into"), + ), + mcp.WithBoolean("draft", + mcp.Description("Create as draft PR"), + ), + mcp.WithBoolean("maintainer_can_modify", + mcp.Description("Allow maintainer edits"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + title, err := requiredParam[string](request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + head, err := requiredParam[string](request, "head") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + base, err := requiredParam[string](request, "base") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + body, err := optionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + draft, err := optionalParam[bool](request, "draft") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + maintainerCanModify, err := optionalParam[bool](request, "maintainer_can_modify") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + newPR := &github.NewPullRequest{ + Title: github.Ptr(title), + Head: github.Ptr(head), + Base: github.Ptr(base), + } + + if body != "" { + newPR.Body = github.Ptr(body) + } + + newPR.Draft = github.Ptr(draft) + newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) + + pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) + if err != nil { + return nil, fmt.Errorf("failed to create pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil + } + + r, err := json.Marshal(pr) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 6432c5710..30efe4f69 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -1187,3 +1187,152 @@ func Test_CreatePullRequestReview(t *testing.T) { }) } } + +func Test_CreatePullRequest(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := createPullRequest(mockClient, translations.NullTranslationHelper) + + assert.Equal(t, "create_pull_request", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "title") + assert.Contains(t, tool.InputSchema.Properties, "body") + assert.Contains(t, tool.InputSchema.Properties, "head") + assert.Contains(t, tool.InputSchema.Properties, "base") + assert.Contains(t, tool.InputSchema.Properties, "draft") + assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "title", "head", "base"}) + + // Setup mock PR for success case + mockPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Head: &github.PullRequestBranch{ + SHA: github.Ptr("abcd1234"), + Ref: github.Ptr("feature-branch"), + }, + Base: &github.PullRequestBranch{ + SHA: github.Ptr("efgh5678"), + Ref: github.Ptr("main"), + }, + Body: github.Ptr("This is a test PR"), + Draft: github.Ptr(false), + MaintainerCanModify: github.Ptr(true), + User: &github.User{ + Login: github.Ptr("testuser"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful PR creation", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposPullsByOwnerByRepo, + mockResponse(t, http.StatusCreated, mockPR), + ), + ), + + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "title": "Test PR", + "body": "This is a test PR", + "head": "feature-branch", + "base": "main", + "draft": false, + "maintainer_can_modify": true, + }, + expectError: false, + expectedPR: mockPR, + }, + { + name: "missing required parameter", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + // missing title, head, base + }, + expectError: true, + expectedErrMsg: "missing required parameter: title", + }, + { + name: "PR creation fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposPullsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message":"Validation failed","errors":[{"resource":"PullRequest","code":"invalid"}]}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "title": "Test PR", + "head": "feature-branch", + "base": "main", + }, + expectError: true, + expectedErrMsg: "failed to create pull request", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := createPullRequest(client, translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + if err != nil { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + // If no error returned but in the result + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title) + assert.Equal(t, *tc.expectedPR.State, *returnedPR.State) + assert.Equal(t, *tc.expectedPR.HTMLURL, *returnedPR.HTMLURL) + assert.Equal(t, *tc.expectedPR.Head.SHA, *returnedPR.Head.SHA) + assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref) + assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body) + assert.Equal(t, *tc.expectedPR.User.Login, *returnedPR.User.Login) + }) + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index c01e0918f..ce39c87e9 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -55,6 +55,7 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH s.AddTool(mergePullRequest(client, t)) s.AddTool(updatePullRequestBranch(client, t)) s.AddTool(createPullRequestReview(client, t)) + s.AddTool(createPullRequest(client, t)) } // Add GitHub tools - Repositories