From 3234755fa4c67a28011716576dff4664d38f6bb6 Mon Sep 17 00:00:00 2001 From: MayorFaj Date: Tue, 15 Apr 2025 23:25:47 +0100 Subject: [PATCH 01/11] feat: add reviewers parameter to UpdatePullRequest and update tests --- README.md | 1 + pkg/github/pullrequests.go | 101 ++++++++++++++++++++++++++++---- pkg/github/pullrequests_test.go | 100 ++++++++++++++++++++++++++++++- 3 files changed, 190 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 288d7548b..98784b54b 100644 --- a/README.md +++ b/README.md @@ -327,6 +327,7 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `state`: New state ('open' or 'closed') (string, optional) - `base`: New base branch name (string, optional) - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + - `reviewers`: GitHub usernames to request reviews from (string[], optional) ### Repositories diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 2be249c8a..1c8b35aaa 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -103,6 +103,12 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.WithBoolean("maintainer_can_modify", mcp.Description("Allow maintainer edits"), ), + mcp.WithArray("reviewers", + mcp.Description("GitHub usernames to request reviews from"), + mcp.Items(map[string]interface{}{ + "type": "string", + }), + ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := requiredParam[string](request, "owner") @@ -157,26 +163,101 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu updateNeeded = true } - if !updateNeeded { - return mcp.NewToolResultError("No update parameters provided."), nil + // Handle reviewers separately + var reviewers []string + if reviewersArr, ok := request.Params.Arguments["reviewers"].([]interface{}); ok && len(reviewersArr) > 0 { + for _, reviewer := range reviewersArr { + if reviewerStr, ok := reviewer.(string); ok { + reviewers = append(reviewers, reviewerStr) + } + } } + // Create the GitHub client client, err := getClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) - if err != nil { - return nil, fmt.Errorf("failed to update pull request: %w", err) + + var pr *github.PullRequest + var resp *http.Response + + // First, update the PR if needed + if updateNeeded { + var ghResp *github.Response + pr, ghResp, err = client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return nil, fmt.Errorf("failed to update pull request: %w", err) + } + resp = ghResp.Response + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } + } else { + // If no update needed, just get the current PR + var ghResp *github.Response + pr, ghResp, err = client.PullRequests.Get(ctx, owner, repo, pullNumber) + if err != nil { + return nil, fmt.Errorf("failed to get pull request: %w", err) + } + resp = ghResp.Response + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil + } } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + // Add reviewers if specified + if len(reviewers) > 0 { + reviewersRequest := github.ReviewersRequest{ + Reviewers: reviewers, + } + + // Use the direct result of RequestReviewers which includes the requested reviewers + updatedPR, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to request reviewers: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(body))), nil + } + + // Use the updated PR with reviewers + pr = updatedPR + } + + // If no updates and no reviewers, return error + if !updateNeeded && len(reviewers) == 0 { + return mcp.NewToolResultError("No update parameters provided"), nil } r, err := json.Marshal(pr) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index bb3726249..3a064a399 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -141,6 +141,7 @@ func Test_UpdatePullRequest(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "state") assert.Contains(t, tool.InputSchema.Properties, "base") assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") + assert.Contains(t, tool.InputSchema.Properties, "reviewers") assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR for success case @@ -162,6 +163,23 @@ func Test_UpdatePullRequest(t *testing.T) { State: github.Ptr("closed"), // State updated } + // Mock PR for when there are no updates but we still need a response + mockNoUpdatePR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + } + + mockPRWithReviewers := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + RequestedReviewers: []*github.User{ + {Login: github.Ptr("reviewer1")}, + {Login: github.Ptr("reviewer2")}, + }, + } + tests := []struct { name string mockedClient *http.Client @@ -220,8 +238,40 @@ func Test_UpdatePullRequest(t *testing.T) { expectedPR: mockClosedPR, }, { - name: "no update parameters provided", - mockedClient: mock.NewMockedHTTPClient(), // No API call expected + name: "successful PR update with reviewers", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + }, + ), + // Mock for RequestReviewers call, returning the PR with reviewers + mock.WithRequestMatch( + mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, + mockPRWithReviewers, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "reviewers": []interface{}{"reviewer1", "reviewer2"}, + }, + expectError: false, + expectedPR: mockPRWithReviewers, + }, + { + name: "no update parameters provided", + mockedClient: mock.NewMockedHTTPClient( + // Mock a response for the GET PR request in case of no updates + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockNoUpdatePR, + ), + ), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -251,6 +301,32 @@ func Test_UpdatePullRequest(t *testing.T) { expectError: true, expectedErrMsg: "failed to update pull request", }, + { + name: "request reviewers fails", + mockedClient: mock.NewMockedHTTPClient( + // First it gets the PR (no fields to update) + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockNoUpdatePR, + ), + // Then reviewer request fails + mock.WithRequestMatchHandler( + mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid reviewers"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "reviewers": []interface{}{"invalid-user"}, + }, + expectError: true, + expectedErrMsg: "failed to request reviewers", + }, } for _, tc := range tests { @@ -304,6 +380,26 @@ func Test_UpdatePullRequest(t *testing.T) { if tc.expectedPR.MaintainerCanModify != nil { assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify) } + + // Check reviewers if they exist in the expected PR + if tc.expectedPR.RequestedReviewers != nil && len(tc.expectedPR.RequestedReviewers) > 0 { + assert.NotNil(t, returnedPR.RequestedReviewers) + assert.Equal(t, len(tc.expectedPR.RequestedReviewers), len(returnedPR.RequestedReviewers)) + + // Create maps of reviewer logins for easy comparison + expectedReviewers := make(map[string]bool) + for _, reviewer := range tc.expectedPR.RequestedReviewers { + expectedReviewers[*reviewer.Login] = true + } + + actualReviewers := make(map[string]bool) + for _, reviewer := range returnedPR.RequestedReviewers { + actualReviewers[*reviewer.Login] = true + } + + // Compare the maps + assert.Equal(t, expectedReviewers, actualReviewers) + } }) } } From 5c85a0940da7292f3c753edb34219437c8b4c3b4 Mon Sep 17 00:00:00 2001 From: MayorFaj <127399119+MayorFaj@users.noreply.github.com> Date: Wed, 25 Jun 2025 21:15:17 +0100 Subject: [PATCH 02/11] Update pullrequests.go --- pkg/github/pullrequests.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 6fb63fae3..caf4ef568 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -313,7 +313,6 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } -<<<<<<< feat/259/assign-reviewers var pr *github.PullRequest var resp *http.Response @@ -360,7 +359,7 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil } -======= + pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, @@ -368,7 +367,6 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu resp, err, ), nil ->>>>>>> main } // Add reviewers if specified From b09f5895e7c842fbd3c4ececc56d4882041191b2 Mon Sep 17 00:00:00 2001 From: MayorFaj Date: Wed, 25 Jun 2025 21:47:25 +0100 Subject: [PATCH 03/11] feat: enhance update pull request functionality with reviewers support --- .../__toolsnaps__/update_pull_request.snap | 7 ++++ pkg/github/pullrequests.go | 36 +++++++++---------- pkg/github/pullrequests_test.go | 2 +- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/pkg/github/__toolsnaps__/update_pull_request.snap b/pkg/github/__toolsnaps__/update_pull_request.snap index 765983afd..621299e43 100644 --- a/pkg/github/__toolsnaps__/update_pull_request.snap +++ b/pkg/github/__toolsnaps__/update_pull_request.snap @@ -30,6 +30,13 @@ "description": "Repository name", "type": "string" }, + "reviewers": { + "description": "GitHub usernames to request reviews from", + "items": { + "type": "string" + }, + "type": "array" + }, "state": { "description": "New state", "enum": [ diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index caf4ef568..f5be0b381 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -299,13 +299,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } // Handle reviewers separately - var reviewers []string - if reviewersArr, ok := request.Params.Arguments["reviewers"].([]interface{}); ok && len(reviewersArr) > 0 { - for _, reviewer := range reviewersArr { - if reviewerStr, ok := reviewer.(string); ok { - reviewers = append(reviewers, reviewerStr) - } - } + reviewers, err := OptionalStringArrayParam(request, "reviewers") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } // Create the GitHub client @@ -322,7 +318,11 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu var ghResp *github.Response pr, ghResp, err = client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) if err != nil { - return nil, fmt.Errorf("failed to update pull request: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + ghResp, + err, + ), nil } resp = ghResp.Response defer func() { @@ -343,7 +343,11 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu var ghResp *github.Response pr, ghResp, err = client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return nil, fmt.Errorf("failed to get pull request: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get pull request", + ghResp, + err, + ), nil } resp = ghResp.Response defer func() { @@ -359,14 +363,6 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil } - - pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request", - resp, - err, - ), nil } // Add reviewers if specified @@ -378,7 +374,11 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu // Use the direct result of RequestReviewers which includes the requested reviewers updatedPR, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) if err != nil { - return nil, fmt.Errorf("failed to request reviewers: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request reviewers", + resp, + err, + ), nil } defer func() { if resp != nil && resp.Body != nil { diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 892fe1599..cd66460f6 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -390,7 +390,7 @@ func Test_UpdatePullRequest(t *testing.T) { } // Check reviewers if they exist in the expected PR - if tc.expectedPR.RequestedReviewers != nil && len(tc.expectedPR.RequestedReviewers) > 0 { + if len(tc.expectedPR.RequestedReviewers) > 0 { assert.NotNil(t, returnedPR.RequestedReviewers) assert.Equal(t, len(tc.expectedPR.RequestedReviewers), len(returnedPR.RequestedReviewers)) From 6f21c3fd3feef033f093195933dd843ad826712d Mon Sep 17 00:00:00 2001 From: MayorFaj Date: Mon, 30 Jun 2025 22:04:52 +0100 Subject: [PATCH 04/11] update README to clarify optional reviewers parameter in API documentation- go run ./cmd/github-mcp-server generate-docs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 39572e5f3..3b4089518 100644 --- a/README.md +++ b/README.md @@ -794,9 +794,9 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `owner`: Repository owner (string, required) - `pullNumber`: Pull request number to update (number, required) - `repo`: Repository name (string, required) + - `reviewers`: GitHub usernames to request reviews from (string[], optional) - `state`: New state (string, optional) - `title`: New title (string, optional) - - `reviewers`: GitHub usernames to request reviews from (string[], optional) - **update_pull_request_branch** - Update pull request branch - `expectedHeadSha`: The expected SHA of the pull request's HEAD ref (string, optional) From 046f994328e0dd714e9d8980552971f88fbf2514 Mon Sep 17 00:00:00 2001 From: MayorFaj Date: Fri, 25 Jul 2025 14:53:47 +0100 Subject: [PATCH 05/11] feat: enhance UpdatePullRequest to return early if no updates or reviewers are provided --- pkg/github/pullrequests.go | 37 ++++++------------------------------- 1 file changed, 6 insertions(+), 31 deletions(-) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index de924bc68..33c7db2c8 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -304,6 +304,11 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(err.Error()), nil } + // If no updates and no reviewers, return error early + if !updateNeeded && len(reviewers) == 0 { + return mcp.NewToolResultError("No update parameters provided"), nil + } + // Create the GitHub client client, err := getClient(ctx) if err != nil { @@ -313,7 +318,7 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu var pr *github.PullRequest var resp *http.Response - // First, update the PR if needed + // Update the PR if needed if updateNeeded { var ghResp *github.Response pr, ghResp, err = client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) @@ -338,31 +343,6 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil } - } else { - // If no update needed, just get the current PR - var ghResp *github.Response - pr, ghResp, err = client.PullRequests.Get(ctx, owner, repo, pullNumber) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get pull request", - ghResp, - err, - ), nil - } - resp = ghResp.Response - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil - } } // Add reviewers if specified @@ -398,11 +378,6 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu pr = updatedPR } - // If no updates and no reviewers, return error - if !updateNeeded && len(reviewers) == 0 { - return mcp.NewToolResultError("No update parameters provided"), nil - } - r, err := json.Marshal(pr) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) From 90eb11aa60196973e17cc350eaacdb5d4e4ee874 Mon Sep 17 00:00:00 2001 From: Matt Holloway Date: Tue, 29 Jul 2025 10:58:06 +0100 Subject: [PATCH 06/11] Add updating draft state to `update_pull_request` tool (#774) * initial impl of pull request draft state update * appease linter * update README * add nosec * fixed err return type for json marshalling * add gql test --- README.md | 1 + .../__toolsnaps__/update_pull_request.snap | 4 + pkg/github/pullrequests.go | 164 ++++++++----- pkg/github/pullrequests_test.go | 226 +++++++++++++++++- pkg/github/tools.go | 2 +- 5 files changed, 334 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 9ad17cfed..34cbd5e7a 100644 --- a/README.md +++ b/README.md @@ -736,6 +736,7 @@ The following sets of tools are available (all are on by default): - **update_pull_request** - Edit pull request - `base`: New base branch name (string, optional) - `body`: New description (string, optional) + - `draft`: Mark pull request as draft (true) or ready for review (false) (boolean, optional) - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) - `owner`: Repository owner (string, required) - `pullNumber`: Pull request number to update (number, required) diff --git a/pkg/github/__toolsnaps__/update_pull_request.snap b/pkg/github/__toolsnaps__/update_pull_request.snap index 621299e43..25170ed5f 100644 --- a/pkg/github/__toolsnaps__/update_pull_request.snap +++ b/pkg/github/__toolsnaps__/update_pull_request.snap @@ -14,6 +14,10 @@ "description": "New description", "type": "string" }, + "draft": { + "description": "Mark pull request as draft (true) or ready for review (false)", + "type": "boolean" + }, "maintainer_can_modify": { "description": "Allow maintainer edits", "type": "boolean" diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 33c7db2c8..384ad8fb8 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { +func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.Description("New state"), mcp.Enum("open", "closed"), ), + mcp.WithBoolean("draft", + mcp.Description("Mark pull request as draft (true) or ready for review (false)"), + ), mcp.WithString("base", mcp.Description("New base branch name"), ), @@ -259,43 +262,51 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError(err.Error()), nil } - // Build the update struct only with provided fields + draftProvided := request.GetArguments()["draft"] != nil + var draftValue bool + if draftProvided { + draftValue, err = OptionalParam[bool](request, "draft") + if err != nil { + return nil, err + } + } + update := &github.PullRequest{} - updateNeeded := false + restUpdateNeeded := false if title, ok, err := OptionalParamOK[string](request, "title"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Title = github.Ptr(title) - updateNeeded = true + restUpdateNeeded = true } if body, ok, err := OptionalParamOK[string](request, "body"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Body = github.Ptr(body) - updateNeeded = true + restUpdateNeeded = true } if state, ok, err := OptionalParamOK[string](request, "state"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.State = github.Ptr(state) - updateNeeded = true + restUpdateNeeded = true } if base, ok, err := OptionalParamOK[string](request, "base"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)} - updateNeeded = true + restUpdateNeeded = true } if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil { return mcp.NewToolResultError(err.Error()), nil } else if ok { update.MaintainerCanModify = github.Ptr(maintainerCanModify) - updateNeeded = true + restUpdateNeeded = true } // Handle reviewers separately @@ -305,82 +316,115 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } // If no updates and no reviewers, return error early - if !updateNeeded && len(reviewers) == 0 { + if !restUpdateNeeded && len(reviewers) == 0 && !draftProvided { return mcp.NewToolResultError("No update parameters provided"), nil } - // Create the GitHub client client, err := getClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } + pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + resp, + err, + ), nil + } + defer func() { _ = resp.Body.Close() }() - var pr *github.PullRequest - var resp *http.Response - - // Update the PR if needed - if updateNeeded { - var ghResp *github.Response - pr, ghResp, err = client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request", - ghResp, - err, - ), nil + return nil, fmt.Errorf("failed to read response body: %w", err) } - resp = ghResp.Response - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + if draftProvided { + gqlClient, err := getGQLClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err) } - } - // Add reviewers if specified - if len(reviewers) > 0 { - reviewersRequest := github.ReviewersRequest{ - Reviewers: reviewers, + var prQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` } - // Use the direct result of RequestReviewers which includes the requested reviewers - updatedPR, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) + err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers + }) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to request reviewers", - resp, - err, - ), nil + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() - if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) + + if currentIsDraft != draftValue { + if draftValue { + // Convert to draft + var mutation struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil + } + } else { + // Mark as ready for review + var mutation struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + } + + err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: prQuery.Repository.PullRequest.ID, + }, nil) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil + } } - return mcp.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(body))), nil } + } - // Use the updated PR with reviewers - pr = updatedPR + client, err := getClient(ctx) + if err != nil { + return nil, err } - r, err := json.Marshal(pr) + finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + r, err := json.Marshal(finalPR) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil } return mcp.NewToolResultText(string(r)), nil diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index cb1d67668..a6595f13b 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -137,7 +137,7 @@ func Test_GetPullRequest(t *testing.T) { func Test_UpdatePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_pull_request", tool.Name) @@ -145,6 +145,7 @@ func Test_UpdatePullRequest(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "owner") assert.Contains(t, tool.InputSchema.Properties, "repo") assert.Contains(t, tool.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.InputSchema.Properties, "draft") assert.Contains(t, tool.InputSchema.Properties, "title") assert.Contains(t, tool.InputSchema.Properties, "body") assert.Contains(t, tool.InputSchema.Properties, "state") @@ -161,6 +162,7 @@ func Test_UpdatePullRequest(t *testing.T) { HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), Body: github.Ptr("Updated test PR body."), MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), Base: &github.PullRequestBranch{ Ref: github.Ptr("develop"), }, @@ -212,6 +214,10 @@ func Test_UpdatePullRequest(t *testing.T) { mockResponse(t, http.StatusOK, mockUpdatedPR), ), ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), ), requestArgs: map[string]interface{}{ "owner": "owner", @@ -236,6 +242,10 @@ func Test_UpdatePullRequest(t *testing.T) { mockResponse(t, http.StatusOK, mockClosedPR), ), ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockClosedPR, + ), ), requestArgs: map[string]interface{}{ "owner": "owner", @@ -247,6 +257,7 @@ func Test_UpdatePullRequest(t *testing.T) { expectedPR: mockClosedPR, }, { +<<<<<<< HEAD name: "successful PR update with reviewers", mockedClient: mock.NewMockedHTTPClient( mock.WithRequestMatch( @@ -261,12 +272,28 @@ func Test_UpdatePullRequest(t *testing.T) { mock.WithRequestMatch( mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, mockPRWithReviewers, +======= + name: "successful PR update (title only)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + expectRequestBody(t, map[string]interface{}{ + "title": "Updated Test PR Title", + }).andThen( + mockResponse(t, http.StatusOK, mockUpdatedPR), + ), + ), + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, +>>>>>>> d5e1f48 (Add updating draft state to `update_pull_request` tool (#774)) ), ), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", "pullNumber": float64(42), +<<<<<<< HEAD "reviewers": []interface{}{"reviewer1", "reviewer2"}, }, expectError: false, @@ -281,6 +308,16 @@ func Test_UpdatePullRequest(t *testing.T) { mockNoUpdatePR, ), ), +======= + "title": "Updated Test PR Title", + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + { + name: "no update parameters provided", + mockedClient: mock.NewMockedHTTPClient(), // No API call expected +>>>>>>> d5e1f48 (Add updating draft state to `update_pull_request` tool (#774)) requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -342,7 +379,7 @@ func Test_UpdatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := UpdatePullRequest(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -412,6 +449,191 @@ func Test_UpdatePullRequest(t *testing.T) { } } +func Test_UpdatePullRequest_Draft(t *testing.T) { + // Setup mock PR for success case + mockUpdatedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR Title"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Body: github.Ptr("Test PR body."), + MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), // Updated to ready for review + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful draft update to ready for review", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + }{}, + githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "markPullRequestReadyForReview": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": false, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + { + name: "successful convert pull request to draft", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + }{}, + githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "convertPullRequestToDraft": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": true, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // For draft-only tests, we need to mock both GraphQL and the final REST GET call + restClient := github.NewClient(mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), + )) + gqlClient := githubv4.NewClient(tc.mockedClient) + + _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + + if tc.expectError || tc.expectedErrMsg != "" { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + if tc.expectedErrMsg != "" { + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + + // Unmarshal and verify the successful result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + if tc.expectedPR.Draft != nil { + assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft) + } + }) + } +} + func Test_ListPullRequests(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index e01b7cc40..caa4f9cfe 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -87,7 +87,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG toolsets.NewServerTool(MergePullRequest(getClient, t)), toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), toolsets.NewServerTool(CreatePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequest(getClient, t)), + toolsets.NewServerTool(UpdatePullRequest(getClient, getGQLClient, t)), toolsets.NewServerTool(RequestCopilotReview(getClient, t)), // Reviews From 033f613f407662bd68c165b4f256ae20ff3a2964 Mon Sep 17 00:00:00 2001 From: Tommaso Moro <37270480+tommaso-moro@users.noreply.github.com> Date: Tue, 29 Jul 2025 15:17:03 +0200 Subject: [PATCH 07/11] Add support for org-level discussions in list_discussions tool (#775) * make repo optional, and default to .github when not provided. improve tool description * autogen * update tests * small copy paste error fixes --- README.md | 2 +- pkg/github/discussions.go | 12 ++++-- pkg/github/discussions_test.go | 77 +++++++++++++++++++++++++++++++++- 3 files changed, 85 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 34cbd5e7a..2b2bc23d9 100644 --- a/README.md +++ b/README.md @@ -466,7 +466,7 @@ The following sets of tools are available (all are on by default): - `orderBy`: Order discussions by field. If provided, the 'direction' also needs to be provided. (string, optional) - `owner`: Repository owner (string, required) - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) - - `repo`: Repository name (string, required) + - `repo`: Repository name. If not provided, discussions will be queried at the organisation level. (string, optional) diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index fce07ecdb..905a1b709 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -119,7 +119,7 @@ func getQueryType(useOrdering bool, categoryID *githubv4.ID) any { func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_discussions", - mcp.WithDescription(t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository")), + mcp.WithDescription(t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository or organisation.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_LIST_DISCUSSIONS_USER_TITLE", "List discussions"), ReadOnlyHint: ToBoolPtr(true), @@ -129,8 +129,7 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp mcp.Description("Repository owner"), ), mcp.WithString("repo", - mcp.Required(), - mcp.Description("Repository name"), + mcp.Description("Repository name. If not provided, discussions will be queried at the organisation level."), ), mcp.WithString("category", mcp.Description("Optional filter by discussion category ID. If provided, only discussions with this category are listed."), @@ -150,10 +149,15 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp if err != nil { return mcp.NewToolResultError(err.Error()), nil } - repo, err := RequiredParam[string](request, "repo") + repo, err := OptionalParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + // when not provided, default to the .github repository + // this will query discussions at the organisation level + if repo == "" { + repo = ".github" + } category, err := OptionalParam[string](request, "category") if err != nil { diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index aefaf2f8c..1fa90b403 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -50,6 +50,46 @@ var ( }, } + discussionsOrgLevel = []map[string]any{ + { + "number": 1, + "title": "Org Discussion 1 - Community Guidelines", + "createdAt": "2023-01-15T00:00:00Z", + "updatedAt": "2023-01-15T00:00:00Z", + "author": map[string]any{"login": "org-admin"}, + "url": "https://github.com/owner/.github/discussions/1", + "category": map[string]any{"name": "Announcements"}, + }, + { + "number": 2, + "title": "Org Discussion 2 - Roadmap 2023", + "createdAt": "2023-02-20T00:00:00Z", + "updatedAt": "2023-02-20T00:00:00Z", + "author": map[string]any{"login": "org-admin"}, + "url": "https://github.com/owner/.github/discussions/2", + "category": map[string]any{"name": "General"}, + }, + { + "number": 3, + "title": "Org Discussion 3 - Roadmap 2024", + "createdAt": "2023-02-20T00:00:00Z", + "updatedAt": "2023-02-20T00:00:00Z", + "author": map[string]any{"login": "org-admin"}, + "url": "https://github.com/owner/.github/discussions/3", + "category": map[string]any{"name": "General"}, + }, + { + "number": 4, + "title": "Org Discussion 4 - Roadmap 2025", + "createdAt": "2023-02-20T00:00:00Z", + "updatedAt": "2023-02-20T00:00:00Z", + "author": map[string]any{"login": "org-admin"}, + "url": "https://github.com/owner/.github/discussions/4", + "category": map[string]any{"name": "General"}, + }, + + } + // Ordered mock responses discussionsOrderedCreatedAsc = []map[string]any{ discussionsAll[0], // Discussion 1 (created 2023-01-01) @@ -139,6 +179,22 @@ var ( }, }, }) + + mockResponseOrgLevel = githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "discussions": map[string]any{ + "nodes": discussionsOrgLevel, + "pageInfo": map[string]any{ + "hasNextPage": false, + "hasPreviousPage": false, + "startCursor": "", + "endCursor": "", + }, + "totalCount": 4, + }, + }, + }) + mockErrorRepoNotFound = githubv4mock.ErrorResponse("repository not found") ) @@ -151,7 +207,7 @@ func Test_ListDiscussions(t *testing.T) { assert.Contains(t, toolDef.InputSchema.Properties, "repo") assert.Contains(t, toolDef.InputSchema.Properties, "orderBy") assert.Contains(t, toolDef.InputSchema.Properties, "direction") - assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo"}) + assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner"}) // Variables matching what GraphQL receives after JSON marshaling/unmarshaling varsListAll := map[string]interface{}{ @@ -204,6 +260,13 @@ func Test_ListDiscussions(t *testing.T) { "after": (*string)(nil), } + varsOrgLevel := map[string]interface{}{ + "owner": "owner", + "repo": ".github", // This is what gets set when repo is not provided + "first": float64(30), + "after": (*string)(nil), + } + tests := []struct { name string reqParams map[string]interface{} @@ -314,6 +377,15 @@ func Test_ListDiscussions(t *testing.T) { expectError: true, errContains: "repository not found", }, + { + name: "list org-level discussions (no repo provided)", + reqParams: map[string]interface{}{ + "owner": "owner", + // repo is not provided, it will default to ".github" + }, + expectError: false, + expectedCount: 4, + }, } // Define the actual query strings that match the implementation @@ -351,6 +423,9 @@ func Test_ListDiscussions(t *testing.T) { case "repository not found error": matcher := githubv4mock.NewQueryMatcher(qBasicNoOrder, varsRepoNotFound, mockErrorRepoNotFound) httpClient = githubv4mock.NewMockedHTTPClient(matcher) + case "list org-level discussions (no repo provided)": + matcher := githubv4mock.NewQueryMatcher(qBasicNoOrder, varsOrgLevel, mockResponseOrgLevel) + httpClient = githubv4mock.NewMockedHTTPClient(matcher) } gqlClient := githubv4.NewClient(httpClient) From 94cef701ba9d15b5dba2baba3ef732d40d93c04a Mon Sep 17 00:00:00 2001 From: MayorFaj Date: Tue, 29 Jul 2025 19:11:46 +0100 Subject: [PATCH 08/11] refactor: streamline UpdatePullRequest logic and enhance test cases for reviewer updates --- pkg/github/discussions_test.go | 3 +- pkg/github/pullrequests.go | 80 ++++++++++++++++++++++++--------- pkg/github/pullrequests_test.go | 42 +++++++---------- 3 files changed, 76 insertions(+), 49 deletions(-) diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 1fa90b403..945783ae1 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -87,7 +87,6 @@ var ( "url": "https://github.com/owner/.github/discussions/4", "category": map[string]any{"name": "General"}, }, - } // Ordered mock responses @@ -190,7 +189,7 @@ var ( "startCursor": "", "endCursor": "", }, - "totalCount": 4, + "totalCount": 4, }, }, }) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 384ad8fb8..b51e88e8f 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -315,33 +315,38 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra return mcp.NewToolResultError(err.Error()), nil } - // If no updates and no reviewers, return error early - if !restUpdateNeeded && len(reviewers) == 0 && !draftProvided { - return mcp.NewToolResultError("No update parameters provided"), nil + // If no updates, no draft change, and no reviewers, return error early + if !restUpdateNeeded && !draftProvided && len(reviewers) == 0 { + return mcp.NewToolResultError("No update parameters provided."), nil } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update pull request", - resp, - err, - ), nil - } - defer func() { _ = resp.Body.Close() }() + // Handle REST API updates + if restUpdateNeeded { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + _, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update pull request", + resp, + err, + ), nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil } + // Handle draft status changes using GraphQL if draftProvided { gqlClient, err := getGQLClient(ctx) if err != nil { @@ -407,6 +412,41 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra } } + // Handle reviewer requests + if len(reviewers) > 0 { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + reviewersRequest := github.ReviewersRequest{ + Reviewers: reviewers, + } + + _, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to request reviewers", + resp, + err, + ), nil + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(body))), nil + } + } + + // Get the final state of the PR to return client, err := getClient(ctx) if err != nil { return nil, err diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index a6595f13b..57face355 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -257,58 +257,47 @@ func Test_UpdatePullRequest(t *testing.T) { expectedPR: mockClosedPR, }, { -<<<<<<< HEAD name: "successful PR update with reviewers", mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposPullsByOwnerByRepoByPullNumber, - &github.PullRequest{ - Number: github.Ptr(42), - Title: github.Ptr("Test PR"), - State: github.Ptr("open"), - }, - ), // Mock for RequestReviewers call, returning the PR with reviewers mock.WithRequestMatch( mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, mockPRWithReviewers, -======= - name: "successful PR update (title only)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PatchReposPullsByOwnerByRepoByPullNumber, - expectRequestBody(t, map[string]interface{}{ - "title": "Updated Test PR Title", - }).andThen( - mockResponse(t, http.StatusOK, mockUpdatedPR), - ), ), mock.WithRequestMatch( mock.GetReposPullsByOwnerByRepoByPullNumber, - mockUpdatedPR, ->>>>>>> d5e1f48 (Add updating draft state to `update_pull_request` tool (#774)) + mockPRWithReviewers, ), ), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", "pullNumber": float64(42), -<<<<<<< HEAD "reviewers": []interface{}{"reviewer1", "reviewer2"}, }, expectError: false, expectedPR: mockPRWithReviewers, }, { - name: "no update parameters provided", + name: "successful PR update (title only)", mockedClient: mock.NewMockedHTTPClient( - // Mock a response for the GET PR request in case of no updates + mock.WithRequestMatchHandler( + mock.PatchReposPullsByOwnerByRepoByPullNumber, + expectRequestBody(t, map[string]interface{}{ + "title": "Updated Test PR Title", + }).andThen( + mockResponse(t, http.StatusOK, mockUpdatedPR), + ), + ), mock.WithRequestMatch( mock.GetReposPullsByOwnerByRepoByPullNumber, - mockNoUpdatePR, + mockUpdatedPR, ), ), -======= + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), "title": "Updated Test PR Title", }, expectError: false, @@ -317,7 +306,6 @@ func Test_UpdatePullRequest(t *testing.T) { { name: "no update parameters provided", mockedClient: mock.NewMockedHTTPClient(), // No API call expected ->>>>>>> d5e1f48 (Add updating draft state to `update_pull_request` tool (#774)) requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", From 5ea322b1fba4f503d856b1ca939ee32587067c1c Mon Sep 17 00:00:00 2001 From: MayorFaj Date: Tue, 29 Jul 2025 19:42:12 +0100 Subject: [PATCH 09/11] refactor: remove redundant draft update tests and streamline UpdatePullRequest logic --- pkg/github/discussions_test.go | 1 - pkg/github/pullrequests.go | 96 +-------- pkg/github/pullrequests_test.go | 370 -------------------------------- 3 files changed, 4 insertions(+), 463 deletions(-) diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index dd965e5c6..945783ae1 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -87,7 +87,6 @@ var ( "url": "https://github.com/owner/.github/discussions/4", "category": map[string]any{"name": "General"}, }, - } // Ordered mock responses diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index bbfab30b3..f82117cad 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -262,15 +262,17 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra return mcp.NewToolResultError(err.Error()), nil } + // Check if draft parameter is provided draftProvided := request.GetArguments()["draft"] != nil var draftValue bool if draftProvided { draftValue, err = OptionalParam[bool](request, "draft") if err != nil { - return nil, err + return mcp.NewToolResultError(err.Error()), nil } } + // Build the update struct only with provided fields update := &github.PullRequest{} restUpdateNeeded := false @@ -320,13 +322,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra return mcp.NewToolResultError("No update parameters provided."), nil } - // Handle REST API updates - } - - if !restUpdateNeeded && !draftProvided { - return mcp.NewToolResultError("No update parameters provided."), nil - } - + // Handle REST API updates (title, body, state, base, maintainer_can_modify) if restUpdateNeeded { client, err := getClient(ctx) if err != nil { @@ -468,90 +464,6 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra } }() - r, err := json.Marshal(finalPR) - if err != nil { - } - - if draftProvided { - gqlClient, err := getGQLClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err) - } - - var prQuery struct { - Repository struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - - err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers - }) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil - } - - currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft) - - if currentIsDraft != draftValue { - if draftValue { - // Convert to draft - var mutation struct { - ConvertPullRequestToDraft struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"convertPullRequestToDraft(input: $input)"` - } - - err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{ - PullRequestID: prQuery.Repository.PullRequest.ID, - }, nil) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil - } - } else { - // Mark as ready for review - var mutation struct { - MarkPullRequestReadyForReview struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"markPullRequestReadyForReview(input: $input)"` - } - - err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{ - PullRequestID: prQuery.Repository.PullRequest.ID, - }, nil) - if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil - } - } - } - } - - client, err := getClient(ctx) - if err != nil { - return nil, err - } - - finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil - } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() - r, err := json.Marshal(finalPR) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal response: %v", err)), nil diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 7ca12b5eb..896b71ac1 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -437,376 +437,6 @@ func Test_UpdatePullRequest(t *testing.T) { } } -func Test_UpdatePullRequest_Draft(t *testing.T) { - // Setup mock PR for success case - mockUpdatedPR := &github.PullRequest{ - Number: github.Ptr(42), - Title: github.Ptr("Test PR Title"), - State: github.Ptr("open"), - HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), - Body: github.Ptr("Test PR body."), - MaintainerCanModify: github.Ptr(false), - Draft: github.Ptr(false), // Updated to ready for review - Base: &github.PullRequestBranch{ - Ref: github.Ptr("main"), - }, - } - - tests := []struct { - name string - mockedClient *http.Client - requestArgs map[string]interface{} - expectError bool - expectedPR *github.PullRequest - expectedErrMsg string - }{ - { - name: "successful draft update to ready for review", - mockedClient: githubv4mock.NewMockedHTTPClient( - githubv4mock.NewQueryMatcher( - struct { - Repository struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - }{}, - map[string]any{ - "owner": githubv4.String("owner"), - "repo": githubv4.String("repo"), - "prNum": githubv4.Int(42), - }, - githubv4mock.DataResponse(map[string]any{ - "repository": map[string]any{ - "pullRequest": map[string]any{ - "id": "PR_kwDOA0xdyM50BPaO", - "isDraft": true, // Current state is draft - }, - }, - }), - ), - githubv4mock.NewMutationMatcher( - struct { - MarkPullRequestReadyForReview struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"markPullRequestReadyForReview(input: $input)"` - }{}, - githubv4.MarkPullRequestReadyForReviewInput{ - PullRequestID: "PR_kwDOA0xdyM50BPaO", - }, - nil, - githubv4mock.DataResponse(map[string]any{ - "markPullRequestReadyForReview": map[string]any{ - "pullRequest": map[string]any{ - "id": "PR_kwDOA0xdyM50BPaO", - "isDraft": false, - }, - }, - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "draft": false, - }, - expectError: false, - expectedPR: mockUpdatedPR, - }, - { - name: "successful convert pull request to draft", - mockedClient: githubv4mock.NewMockedHTTPClient( - githubv4mock.NewQueryMatcher( - struct { - Repository struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - }{}, - map[string]any{ - "owner": githubv4.String("owner"), - "repo": githubv4.String("repo"), - "prNum": githubv4.Int(42), - }, - githubv4mock.DataResponse(map[string]any{ - "repository": map[string]any{ - "pullRequest": map[string]any{ - "id": "PR_kwDOA0xdyM50BPaO", - "isDraft": false, // Current state is draft - }, - }, - }), - ), - githubv4mock.NewMutationMatcher( - struct { - ConvertPullRequestToDraft struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"convertPullRequestToDraft(input: $input)"` - }{}, - githubv4.ConvertPullRequestToDraftInput{ - PullRequestID: "PR_kwDOA0xdyM50BPaO", - }, - nil, - githubv4mock.DataResponse(map[string]any{ - "convertPullRequestToDraft": map[string]any{ - "pullRequest": map[string]any{ - "id": "PR_kwDOA0xdyM50BPaO", - "isDraft": true, - }, - }, - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "draft": true, - }, - expectError: false, - expectedPR: mockUpdatedPR, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - // For draft-only tests, we need to mock both GraphQL and the final REST GET call - restClient := github.NewClient(mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposPullsByOwnerByRepoByPullNumber, - mockUpdatedPR, - ), - )) - gqlClient := githubv4.NewClient(tc.mockedClient) - - _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) - - request := createMCPRequest(tc.requestArgs) - - result, err := handler(context.Background(), request) - - if tc.expectError || tc.expectedErrMsg != "" { - require.NoError(t, err) - require.True(t, result.IsError) - errorContent := getErrorResult(t, result) - if tc.expectedErrMsg != "" { - assert.Contains(t, errorContent.Text, tc.expectedErrMsg) - } - return - } - - require.NoError(t, err) - require.False(t, result.IsError) - - textContent := getTextResult(t, result) - - // Unmarshal and verify the successful result - var returnedPR github.PullRequest - err = json.Unmarshal([]byte(textContent.Text), &returnedPR) - require.NoError(t, err) - assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) - if tc.expectedPR.Draft != nil { - assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft) - } - }) - } -} - -func Test_UpdatePullRequest_Draft(t *testing.T) { - // Setup mock PR for success case - mockUpdatedPR := &github.PullRequest{ - Number: github.Ptr(42), - Title: github.Ptr("Test PR Title"), - State: github.Ptr("open"), - HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), - Body: github.Ptr("Test PR body."), - MaintainerCanModify: github.Ptr(false), - Draft: github.Ptr(false), // Updated to ready for review - Base: &github.PullRequestBranch{ - Ref: github.Ptr("main"), - }, - } - - tests := []struct { - name string - mockedClient *http.Client - requestArgs map[string]interface{} - expectError bool - expectedPR *github.PullRequest - expectedErrMsg string - }{ - { - name: "successful draft update to ready for review", - mockedClient: githubv4mock.NewMockedHTTPClient( - githubv4mock.NewQueryMatcher( - struct { - Repository struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - }{}, - map[string]any{ - "owner": githubv4.String("owner"), - "repo": githubv4.String("repo"), - "prNum": githubv4.Int(42), - }, - githubv4mock.DataResponse(map[string]any{ - "repository": map[string]any{ - "pullRequest": map[string]any{ - "id": "PR_kwDOA0xdyM50BPaO", - "isDraft": true, // Current state is draft - }, - }, - }), - ), - githubv4mock.NewMutationMatcher( - struct { - MarkPullRequestReadyForReview struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"markPullRequestReadyForReview(input: $input)"` - }{}, - githubv4.MarkPullRequestReadyForReviewInput{ - PullRequestID: "PR_kwDOA0xdyM50BPaO", - }, - nil, - githubv4mock.DataResponse(map[string]any{ - "markPullRequestReadyForReview": map[string]any{ - "pullRequest": map[string]any{ - "id": "PR_kwDOA0xdyM50BPaO", - "isDraft": false, - }, - }, - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "draft": false, - }, - expectError: false, - expectedPR: mockUpdatedPR, - }, - { - name: "successful convert pull request to draft", - mockedClient: githubv4mock.NewMockedHTTPClient( - githubv4mock.NewQueryMatcher( - struct { - Repository struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } `graphql:"pullRequest(number: $prNum)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - }{}, - map[string]any{ - "owner": githubv4.String("owner"), - "repo": githubv4.String("repo"), - "prNum": githubv4.Int(42), - }, - githubv4mock.DataResponse(map[string]any{ - "repository": map[string]any{ - "pullRequest": map[string]any{ - "id": "PR_kwDOA0xdyM50BPaO", - "isDraft": false, // Current state is draft - }, - }, - }), - ), - githubv4mock.NewMutationMatcher( - struct { - ConvertPullRequestToDraft struct { - PullRequest struct { - ID githubv4.ID - IsDraft githubv4.Boolean - } - } `graphql:"convertPullRequestToDraft(input: $input)"` - }{}, - githubv4.ConvertPullRequestToDraftInput{ - PullRequestID: "PR_kwDOA0xdyM50BPaO", - }, - nil, - githubv4mock.DataResponse(map[string]any{ - "convertPullRequestToDraft": map[string]any{ - "pullRequest": map[string]any{ - "id": "PR_kwDOA0xdyM50BPaO", - "isDraft": true, - }, - }, - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "draft": true, - }, - expectError: false, - expectedPR: mockUpdatedPR, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - // For draft-only tests, we need to mock both GraphQL and the final REST GET call - restClient := github.NewClient(mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposPullsByOwnerByRepoByPullNumber, - mockUpdatedPR, - ), - )) - gqlClient := githubv4.NewClient(tc.mockedClient) - - _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) - - request := createMCPRequest(tc.requestArgs) - - result, err := handler(context.Background(), request) - - if tc.expectError || tc.expectedErrMsg != "" { - require.NoError(t, err) - require.True(t, result.IsError) - errorContent := getErrorResult(t, result) - if tc.expectedErrMsg != "" { - assert.Contains(t, errorContent.Text, tc.expectedErrMsg) - } - return - } - - require.NoError(t, err) - require.False(t, result.IsError) - - textContent := getTextResult(t, result) - - // Unmarshal and verify the successful result - var returnedPR github.PullRequest - err = json.Unmarshal([]byte(textContent.Text), &returnedPR) - require.NoError(t, err) - assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) - if tc.expectedPR.Draft != nil { - assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft) - } - }) - } -} - func Test_ListPullRequests(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) From be359eaaba84452e61e43c0cba7932f563f4bda6 Mon Sep 17 00:00:00 2001 From: MayorFaj Date: Tue, 29 Jul 2025 19:56:44 +0100 Subject: [PATCH 10/11] test: add unit tests for updating pull request draft state --- pkg/github/pullrequests_test.go | 201 ++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 896b71ac1..67c1e25e7 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -437,6 +437,207 @@ func Test_UpdatePullRequest(t *testing.T) { } } +func Test_UpdatePullRequest_Draft(t *testing.T) { + // Setup mock PR for success case + mockUpdatedPR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR Title"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + Body: github.Ptr("Test PR body."), + MaintainerCanModify: github.Ptr(false), + Draft: github.Ptr(false), // Updated to ready for review + Base: &github.PullRequestBranch{ + Ref: github.Ptr("main"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedPR *github.PullRequest + expectedErrMsg string + }{ + { + name: "successful draft update to ready for review", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + }{}, + githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "markPullRequestReadyForReview": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": false, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + { + name: "successful convert pull request to draft", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": false, // Current state is draft + }, + }, + }), + ), + githubv4mock.NewMutationMatcher( + struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + }{}, + githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: "PR_kwDOA0xdyM50BPaO", + }, + nil, + githubv4mock.DataResponse(map[string]any{ + "convertPullRequestToDraft": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDOA0xdyM50BPaO", + "isDraft": true, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "draft": true, + }, + expectError: false, + expectedPR: mockUpdatedPR, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // For draft-only tests, we need to mock both GraphQL and the final REST GET call + restClient := github.NewClient(mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockUpdatedPR, + ), + )) + gqlClient := githubv4.NewClient(tc.mockedClient) + + _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + + if tc.expectError || tc.expectedErrMsg != "" { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + if tc.expectedErrMsg != "" { + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + + // Unmarshal and verify the successful result + var returnedPR github.PullRequest + err = json.Unmarshal([]byte(textContent.Text), &returnedPR) + require.NoError(t, err) + assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) + if tc.expectedPR.Title != nil { + assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title) + } + if tc.expectedPR.Body != nil { + assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body) + } + if tc.expectedPR.State != nil { + assert.Equal(t, *tc.expectedPR.State, *returnedPR.State) + } + if tc.expectedPR.Base != nil && tc.expectedPR.Base.Ref != nil { + assert.NotNil(t, returnedPR.Base) + assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref) + } + if tc.expectedPR.MaintainerCanModify != nil { + assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify) + } + if tc.expectedPR.Draft != nil { + assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft) + } + }) + } +} + func Test_ListPullRequests(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) From c1b9a462b60b6da04efa788b904b49815ea8563d Mon Sep 17 00:00:00 2001 From: MayorFaj Date: Thu, 31 Jul 2025 01:14:57 +0100 Subject: [PATCH 11/11] refactor: simplify UpdatePullRequest tests by removing unused mock data --- pkg/github/pullrequests_test.go | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 67c1e25e7..3a99d9f46 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -175,12 +175,6 @@ func Test_UpdatePullRequest(t *testing.T) { } // Mock PR for when there are no updates but we still need a response - mockNoUpdatePR := &github.PullRequest{ - Number: github.Ptr(42), - Title: github.Ptr("Test PR"), - State: github.Ptr("open"), - } - mockPRWithReviewers := &github.PullRequest{ Number: github.Ptr(42), Title: github.Ptr("Test PR"), @@ -338,11 +332,6 @@ func Test_UpdatePullRequest(t *testing.T) { { name: "request reviewers fails", mockedClient: mock.NewMockedHTTPClient( - // First it gets the PR (no fields to update) - mock.WithRequestMatch( - mock.GetReposPullsByOwnerByRepoByPullNumber, - mockNoUpdatePR, - ), // Then reviewer request fails mock.WithRequestMatchHandler( mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, @@ -615,25 +604,6 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { err = json.Unmarshal([]byte(textContent.Text), &returnedPR) require.NoError(t, err) assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number) - if tc.expectedPR.Title != nil { - assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title) - } - if tc.expectedPR.Body != nil { - assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body) - } - if tc.expectedPR.State != nil { - assert.Equal(t, *tc.expectedPR.State, *returnedPR.State) - } - if tc.expectedPR.Base != nil && tc.expectedPR.Base.Ref != nil { - assert.NotNil(t, returnedPR.Base) - assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref) - } - if tc.expectedPR.MaintainerCanModify != nil { - assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify) - } - if tc.expectedPR.Draft != nil { - assert.Equal(t, *tc.expectedPR.Draft, *returnedPR.Draft) - } }) } }