diff --git a/src/ghstack/cli.py b/src/ghstack/cli.py index 82c8cbc..1a7f4f7 100644 --- a/src/ghstack/cli.py +++ b/src/ghstack/cli.py @@ -263,6 +263,18 @@ def status(pull_request: str) -> None: "With --no-stack, we support only non-range identifiers, and will submit each commit " "listed in the command line.", ) +@click.option( + "--reviewer", + default=None, + help="Comma-separated list of GitHub usernames to add as reviewers to new PRs " + "(overrides .ghstackrc setting)", +) +@click.option( + "--label", + default=None, + help="Comma-separated list of labels to add to new PRs " + "(overrides .ghstackrc setting)", +) @click.option( "--direct/--no-direct", "direct_opt", @@ -286,6 +298,8 @@ def submit( base: Optional[str], revs: Tuple[str, ...], stack: bool, + reviewer: Optional[str], + label: Optional[str], ) -> None: """ Submit or update a PR stack @@ -307,6 +321,8 @@ def submit( revs=revs, stack=stack, direct_opt=direct_opt, + reviewer=reviewer if reviewer is not None else config.reviewer, + label=label if label is not None else config.label, ) diff --git a/src/ghstack/config.py b/src/ghstack/config.py index ff78fe9..735ca10 100644 --- a/src/ghstack/config.py +++ b/src/ghstack/config.py @@ -107,6 +107,10 @@ def get_gh_cli_credentials( ("github_url", str), # Name of the upstream remote ("remote_name", str), + # Default reviewers to add to new pull requests (comma-separated usernames) + ("reviewer", Optional[str]), + # Default labels to add to new pull requests (comma-separated labels) + ("label", Optional[str]), ], ) @@ -287,6 +291,16 @@ def read_config( else: remote_name = "origin" + if config.has_option("ghstack", "reviewer"): + reviewer = config.get("ghstack", "reviewer") + else: + reviewer = None + + if config.has_option("ghstack", "label"): + label = config.get("ghstack", "label") + else: + label = None + if write_back: with open(config_path, "w") as f: config.write(f) @@ -302,6 +316,8 @@ def read_config( default_project_dir=default_project_dir, github_url=github_url, remote_name=remote_name, + reviewer=reviewer, + label=label, ) logging.debug(f"conf = {conf}") return conf diff --git a/src/ghstack/github_fake.py b/src/ghstack/github_fake.py index d82e969..622b6af 100644 --- a/src/ghstack/github_fake.py +++ b/src/ghstack/github_fake.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import dataclasses import os.path import re from dataclasses import dataclass @@ -271,6 +272,8 @@ class PullRequest(Node): # state: PullRequestState title: str url: str + reviewers: List[str] = dataclasses.field(default_factory=list) + labels: List[str] = dataclasses.field(default_factory=list) def repository(self, info: GraphQLResolveInfo) -> Repository: return github_state(info).repositories[self._repository] @@ -473,6 +476,24 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: GitHubNumber(int(m.group(3))), cast(CreateIssueCommentInput, kwargs), ) + if m := re.match( + r"^repos/([^/]+)/([^/]+)/pulls/([^/]+)/requested_reviewers", path + ): + # Handle adding reviewers + state = self.state + repo = state.repository(m.group(1), m.group(2)) + pr = state.pull_request(repo, GitHubNumber(int(m.group(3)))) + reviewers = kwargs.get("reviewers", []) + pr.reviewers.extend(reviewers) + return {} + if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels", path): + # Handle adding labels + state = self.state + repo = state.repository(m.group(1), m.group(2)) + pr = state.pull_request(repo, GitHubNumber(int(m.group(3)))) + labels = kwargs.get("labels", []) + pr.labels.extend(labels) + return {} elif method == "patch": if m := re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path): owner, name, number = m.groups() diff --git a/src/ghstack/submit.py b/src/ghstack/submit.py index f7910c5..b9568b9 100644 --- a/src/ghstack/submit.py +++ b/src/ghstack/submit.py @@ -343,6 +343,12 @@ class Submitter: # merged. If None, infer whether or not the PR should be direct or not. direct_opt: Optional[bool] = None + # Default reviewers to add to new pull requests (comma-separated usernames) + reviewer: Optional[str] = None + + # Default labels to add to new pull requests (comma-separated labels) + label: Optional[str] = None + # ~~~~~~~~~~~~~~~~~~~~~~~~ # Computed in post init @@ -1434,6 +1440,32 @@ def _create_pull_request( ) comment_id = rc["id"] + # Add reviewers if specified + if self.reviewer: + reviewers = [r.strip() for r in self.reviewer.split(",") if r.strip()] + if reviewers: + try: + self.github.post( + f"repos/{self.repo_owner}/{self.repo_name}/pulls/{number}/requested_reviewers", + reviewers=reviewers, + ) + logging.info(f"Added reviewers: {', '.join(reviewers)}") + except Exception as e: + logging.warning(f"Failed to add reviewers: {e}") + + # Add labels if specified + if self.label: + labels = [label.strip() for label in self.label.split(",") if label.strip()] + if labels: + try: + self.github.post( + f"repos/{self.repo_owner}/{self.repo_name}/issues/{number}/labels", + labels=labels, + ) + logging.info(f"Added labels: {', '.join(labels)}") + except Exception as e: + logging.warning(f"Failed to add labels: {e}") + logging.info("Opened PR #{}".format(number)) pull_request_resolved = ghstack.diff.PullRequestResolved( diff --git a/src/ghstack/test_prelude.py b/src/ghstack/test_prelude.py index a46d5a9..567fadb 100644 --- a/src/ghstack/test_prelude.py +++ b/src/ghstack/test_prelude.py @@ -51,6 +51,8 @@ "get_sh", "get_upstream_sh", "get_github", + "get_pr_reviewers", + "get_pr_labels", "tick", "captured_output", ] @@ -194,6 +196,8 @@ def gh_submit( base: Optional[str] = None, revs: Sequence[str] = (), stack: bool = True, + reviewer: Optional[str] = None, + label: Optional[str] = None, ) -> List[ghstack.submit.DiffMeta]: self = CTX r = ghstack.submit.main( @@ -214,6 +218,8 @@ def gh_submit( revs=revs, stack=stack, check_invariants=True, + reviewer=reviewer, + label=label, ) self.check_global_github_invariants(self.direct) return r @@ -386,6 +392,26 @@ def is_direct() -> bool: return CTX.direct +def get_github() -> ghstack.github_fake.FakeGitHubEndpoint: + return CTX.github + + +def get_pr_reviewers(pr_number: int) -> List[str]: + """Get the reviewers for a PR number.""" + github = get_github() + repo = github.state.repository("pytorch", "pytorch") + pr = github.state.pull_request(repo, ghstack.github_fake.GitHubNumber(pr_number)) + return pr.reviewers + + +def get_pr_labels(pr_number: int) -> List[str]: + """Get the labels for a PR number.""" + github = get_github() + repo = github.state.repository("pytorch", "pytorch") + pr = github.state.pull_request(repo, ghstack.github_fake.GitHubNumber(pr_number)) + return pr.labels + + def assert_eq(a: Any, b: Any) -> None: assert a == b, f"{a} != {b}" @@ -419,9 +445,5 @@ def get_upstream_sh() -> ghstack.shell.Shell: return CTX.upstream_sh -def get_github() -> ghstack.github.GitHubEndpoint: - return CTX.github - - def tick() -> None: CTX.sh.test_tick() diff --git a/test/submit/cli_reviewer_and_label.py.test b/test/submit/cli_reviewer_and_label.py.test new file mode 100644 index 0000000..117ba2b --- /dev/null +++ b/test/submit/cli_reviewer_and_label.py.test @@ -0,0 +1,84 @@ +from ghstack.test_prelude import * + +init_test() + +# Create first commit with one set of reviewers/labels +commit("A") +(A,) = gh_submit("Initial commit", reviewer="reviewer1", label="bug") + +# Verify first PR has correct reviewers and labels +assert_eq(get_pr_reviewers(500), ["reviewer1"]) +assert_eq(get_pr_labels(500), ["bug"]) + +# Create second commit with different reviewers/labels +commit("B") +(A2, B) = gh_submit( + "Add B", reviewer="reviewer2,reviewer3", label="enhancement,priority-high" +) + +# Verify second PR has correct reviewers and labels +assert_eq(get_pr_reviewers(501), ["reviewer2", "reviewer3"]) +assert_eq(get_pr_labels(501), ["enhancement", "priority-high"]) + +if is_direct(): + assert_github_state( + """\ + [O] #500 Commit A (gh/ezyang/1/head -> master) + + This is commit A + + * 9cb2ede Initial commit + + [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/1/head) + + This is commit B + + * d012f5c Add B + + Repository state: + + * d012f5c (gh/ezyang/2/next, gh/ezyang/2/head) + | Add B + * 9cb2ede (gh/ezyang/1/next, gh/ezyang/1/head) + | Initial commit + * dc8bfe4 (HEAD -> master) + Initial commit + """ + ) +else: + assert_github_state( + """\ + [O] #500 Commit A (gh/ezyang/1/head -> gh/ezyang/1/base) + + Stack: + * #501 + * __->__ #500 + + This is commit A + + * 62602bd Initial commit + + [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/2/base) + + Stack: + * __->__ #501 + * #500 + + This is commit B + + * afd9aaf Add B + + Repository state: + + * afd9aaf (gh/ezyang/2/head) + | Add B + * 5d2de57 (gh/ezyang/2/base) + | Add B (base update) + | * 62602bd (gh/ezyang/1/head) + | | Initial commit + | * 5956f18 (gh/ezyang/1/base) + |/ Initial commit (base update) + * dc8bfe4 (HEAD -> master) + Initial commit + """, + ) diff --git a/test/submit/reviewer_and_label.py.test b/test/submit/reviewer_and_label.py.test new file mode 100644 index 0000000..57fe7f4 --- /dev/null +++ b/test/submit/reviewer_and_label.py.test @@ -0,0 +1,52 @@ +from ghstack.test_prelude import * + +init_test() + +commit("A") +(A,) = gh_submit( + "Initial commit", reviewer="reviewer1,reviewer2", label="bug,enhancement" +) + +# Verify reviewers and labels were added +assert_eq(get_pr_reviewers(500), ["reviewer1", "reviewer2"]) +assert_eq(get_pr_labels(500), ["bug", "enhancement"]) + +if is_direct(): + assert_github_state( + """\ + [O] #500 Commit A (gh/ezyang/1/head -> master) + + This is commit A + + * 9cb2ede Initial commit + + Repository state: + + * 9cb2ede (gh/ezyang/1/next, gh/ezyang/1/head) + | Initial commit + * dc8bfe4 (HEAD -> master) + Initial commit + """ + ) +else: + assert_github_state( + """\ + [O] #500 Commit A (gh/ezyang/1/head -> gh/ezyang/1/base) + + Stack: + * __->__ #500 + + This is commit A + + * 62602bd Initial commit + + Repository state: + + * 62602bd (gh/ezyang/1/head) + | Initial commit + * 5956f18 (gh/ezyang/1/base) + | Initial commit (base update) + * dc8bfe4 (HEAD -> master) + Initial commit + """, + )