Skip to content

Commit e48c0c7

Browse files
committed
Update list_pull_requests tool to add client-side author filtering
1 parent 72b6dc8 commit e48c0c7

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,7 @@ The following sets of tools are available:
10631063

10641064
- **list_pull_requests** - List pull requests
10651065
- **Required OAuth Scopes**: `repo`
1066+
- `author`: Filter by PR author username (client-side filter) (string, optional)
10661067
- `base`: Filter by base branch (string, optional)
10671068
- `direction`: Sort direction (string, optional)
10681069
- `head`: Filter by head user/org and branch (string, optional)

pkg/github/pullrequests.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,10 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool
10471047
Description: "Sort direction",
10481048
Enum: []any{"asc", "desc"},
10491049
},
1050+
"author": {
1051+
Type: "string",
1052+
Description: "Filter by PR author username (client-side filter)",
1053+
},
10501054
},
10511055
Required: []string{"owner", "repo"},
10521056
}
@@ -1056,7 +1060,7 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool
10561060
ToolsetMetadataPullRequests,
10571061
mcp.Tool{
10581062
Name: "list_pull_requests",
1059-
Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead. If you receive a 422 error from search_pull_requests, then use the get_prs_reviewed_by tool instead."),
1063+
Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead. If you receive a 422 error from search_pull_requests, then use the get_prs_reviewed_by tool (to list by reviewer), or this tool with the author parameter (for filtering by author) depending on what you need."),
10601064
Annotations: &mcp.ToolAnnotations{
10611065
Title: t("TOOL_LIST_PULL_REQUESTS_USER_TITLE", "List pull requests"),
10621066
ReadOnlyHint: true,
@@ -1093,6 +1097,10 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool
10931097
if err != nil {
10941098
return utils.NewToolResultError(err.Error()), nil, nil
10951099
}
1100+
author, err := OptionalParam[string](args, "author")
1101+
if err != nil {
1102+
return utils.NewToolResultError(err.Error()), nil, nil
1103+
}
10961104
pagination, err := OptionalPaginationParams(args)
10971105
if err != nil {
10981106
return utils.NewToolResultError(err.Error()), nil, nil
@@ -1132,6 +1140,18 @@ func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool
11321140
return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list pull requests", resp, bodyBytes), nil, nil
11331141
}
11341142

1143+
// Filter by author if specified (client-side filtering)
1144+
if author != "" {
1145+
filtered := make([]*github.PullRequest, 0)
1146+
for _, pr := range prs {
1147+
if pr != nil && pr.User != nil && pr.User.Login != nil &&
1148+
strings.EqualFold(*pr.User.Login, author) {
1149+
filtered = append(filtered, pr)
1150+
}
1151+
}
1152+
prs = filtered
1153+
}
1154+
11351155
// sanitize title/body on each PR
11361156
for _, pr := range prs {
11371157
if pr == nil {

pkg/github/pullrequests_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,120 @@ func Test_MergePullRequest(t *testing.T) {
978978
}
979979
}
980980

981+
func Test_ListPullRequests_AuthorFilter(t *testing.T) {
982+
// Setup mock PRs from multiple authors
983+
mockPRs := []*github.PullRequest{
984+
{
985+
Number: github.Ptr(42),
986+
Title: github.Ptr("PR by user1"),
987+
State: github.Ptr("open"),
988+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"),
989+
User: &github.User{
990+
Login: github.Ptr("user1"),
991+
},
992+
},
993+
{
994+
Number: github.Ptr(43),
995+
Title: github.Ptr("PR by user2"),
996+
State: github.Ptr("open"),
997+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/43"),
998+
User: &github.User{
999+
Login: github.Ptr("user2"),
1000+
},
1001+
},
1002+
{
1003+
Number: github.Ptr(44),
1004+
Title: github.Ptr("Another PR by user1"),
1005+
State: github.Ptr("closed"),
1006+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/44"),
1007+
User: &github.User{
1008+
Login: github.Ptr("user1"),
1009+
},
1010+
},
1011+
}
1012+
1013+
tests := []struct {
1014+
name string
1015+
author string
1016+
expectedCount int
1017+
expectedNumbers []int
1018+
}{
1019+
{
1020+
name: "filter by user1 returns 2 PRs",
1021+
author: "user1",
1022+
expectedCount: 2,
1023+
expectedNumbers: []int{42, 44},
1024+
},
1025+
{
1026+
name: "filter by user2 returns 1 PR",
1027+
author: "user2",
1028+
expectedCount: 1,
1029+
expectedNumbers: []int{43},
1030+
},
1031+
{
1032+
name: "filter by USER1 (case insensitive) returns 2 PRs",
1033+
author: "USER1",
1034+
expectedCount: 2,
1035+
expectedNumbers: []int{42, 44},
1036+
},
1037+
{
1038+
name: "filter by nonexistent user returns 0 PRs",
1039+
author: "nonexistent",
1040+
expectedCount: 0,
1041+
expectedNumbers: []int{},
1042+
},
1043+
{
1044+
name: "no author filter returns all PRs",
1045+
author: "",
1046+
expectedCount: 3,
1047+
expectedNumbers: []int{42, 43, 44},
1048+
},
1049+
}
1050+
1051+
for _, tc := range tests {
1052+
t.Run(tc.name, func(t *testing.T) {
1053+
mockedClient := MockHTTPClientWithHandlers(map[string]http.HandlerFunc{
1054+
GetReposPullsByOwnerByRepo: mockResponse(t, http.StatusOK, mockPRs),
1055+
})
1056+
1057+
client := github.NewClient(mockedClient)
1058+
serverTool := ListPullRequests(translations.NullTranslationHelper)
1059+
deps := BaseDeps{
1060+
Client: client,
1061+
}
1062+
handler := serverTool.Handler(deps)
1063+
1064+
requestArgs := map[string]interface{}{
1065+
"owner": "owner",
1066+
"repo": "repo",
1067+
}
1068+
if tc.author != "" {
1069+
requestArgs["author"] = tc.author
1070+
}
1071+
1072+
request := createMCPRequest(requestArgs)
1073+
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
1074+
1075+
require.NoError(t, err)
1076+
require.False(t, result.IsError)
1077+
1078+
textContent := getTextResult(t, result)
1079+
var returnedPRs []*github.PullRequest
1080+
err = json.Unmarshal([]byte(textContent.Text), &returnedPRs)
1081+
require.NoError(t, err)
1082+
1083+
assert.Len(t, returnedPRs, tc.expectedCount)
1084+
1085+
// Verify the expected PR numbers are returned
1086+
returnedNumbers := make([]int, len(returnedPRs))
1087+
for i, pr := range returnedPRs {
1088+
returnedNumbers[i] = *pr.Number
1089+
}
1090+
assert.ElementsMatch(t, tc.expectedNumbers, returnedNumbers)
1091+
})
1092+
}
1093+
}
1094+
9811095
func Test_SearchPullRequests(t *testing.T) {
9821096
serverTool := SearchPullRequests(translations.NullTranslationHelper)
9831097
tool := serverTool.Tool

0 commit comments

Comments
 (0)