From be043c5e321a5e45aa2a03db545bb0b480ac8940 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Mon, 22 Dec 2025 22:18:05 -0800 Subject: [PATCH 1/3] Add label and reviewer functionality to submit and ghstackrc --- lint.jso | 2 + lint.json | 1 + src/ghstack/cli.py | 16 +++++ src/ghstack/config.py | 16 +++++ src/ghstack/github_fake.py | 8 +++ src/ghstack/submit.py | 32 ++++++++++ src/ghstack/test_prelude.py | 4 ++ test/submit/cli_reviewer_and_label.py.test | 74 ++++++++++++++++++++++ test/submit/reviewer_and_label.py.test | 46 ++++++++++++++ 9 files changed, 199 insertions(+) create mode 100644 lint.jso create mode 100644 lint.json create mode 100644 test/submit/cli_reviewer_and_label.py.test create mode 100644 test/submit/reviewer_and_label.py.test diff --git a/lint.jso b/lint.jso new file mode 100644 index 0000000..e23a0f9 --- /dev/null +++ b/lint.jso @@ -0,0 +1,2 @@ +{"path":"/home/oulgen/dev/ghstack/src/ghstack/submit.py","line":1458,"char":37,"code":"FLAKE8","severity":"warning","name":"E741","description":"ambiguous variable name 'l'\nSee https://www.flake8rules.com/rules/E741.html"} +{"path":"/home/oulgen/dev/ghstack/src/ghstack/github_fake.py","line":null,"char":null,"code":"UFMT","severity":"warning","name":"format","description":"Run `lintrunner -a` to apply this patch.","original":"#!/usr/bin/env python3\n\nimport os.path\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any, cast, Dict, List, NewType, Optional, Sequence\n\nimport graphql\nfrom typing_extensions import TypedDict\n\nimport ghstack.diff\nimport ghstack.github\nimport ghstack.shell\n\nGraphQLId = NewType(\"GraphQLId\", str)\nGitHubNumber = NewType(\"GitHubNumber\", int)\nGitObjectID = NewType(\"GitObjectID\", str)\n\n# https://stackoverflow.com/a/55250601\nSetDefaultBranchInput = TypedDict(\n \"SetDefaultBranchInput\",\n {\n \"name\": str,\n \"default_branch\": str,\n },\n)\n\nUpdatePullRequestInput = TypedDict(\n \"UpdatePullRequestInput\",\n {\n \"base\": Optional[str],\n \"title\": Optional[str],\n \"body\": Optional[str],\n },\n)\n\nCreatePullRequestInput = TypedDict(\n \"CreatePullRequestInput\",\n {\n \"base\": str,\n \"head\": str,\n \"title\": str,\n \"body\": str,\n \"maintainer_can_modify\": bool,\n },\n)\n\nCreateIssueCommentInput = TypedDict(\n \"CreateIssueCommentInput\",\n {\"body\": str},\n)\n\nCreateIssueCommentPayload = TypedDict(\n \"CreateIssueCommentPayload\",\n {\n \"id\": int,\n },\n)\n\nUpdateIssueCommentInput = TypedDict(\n \"UpdateIssueCommentInput\",\n {\"body\": str},\n)\n\nCreatePullRequestPayload = TypedDict(\n \"CreatePullRequestPayload\",\n {\n \"number\": int,\n },\n)\n\n\n# The \"database\" for our mock instance\nclass GitHubState:\n repositories: Dict[GraphQLId, \"Repository\"]\n pull_requests: Dict[GraphQLId, \"PullRequest\"]\n # This is very inefficient but whatever\n issue_comments: Dict[GraphQLId, \"IssueComment\"]\n _next_id: int\n # These are indexed by repo id\n _next_pull_request_number: Dict[GraphQLId, int]\n _next_issue_comment_full_database_id: Dict[GraphQLId, int]\n root: \"Root\"\n upstream_sh: Optional[ghstack.shell.Shell]\n\n def repository(self, owner: str, name: str) -> \"Repository\":\n nameWithOwner = \"{}/{}\".format(owner, name)\n for r in self.repositories.values():\n if r.nameWithOwner == nameWithOwner:\n return r\n raise RuntimeError(\"unknown repository {}\".format(nameWithOwner))\n\n def pull_request(self, repo: \"Repository\", number: GitHubNumber) -> \"PullRequest\":\n for pr in self.pull_requests.values():\n if repo.id == pr._repository and pr.number == number:\n return pr\n raise RuntimeError(\n \"unrecognized pull request #{} in repository {}\".format(\n number, repo.nameWithOwner\n )\n )\n\n def issue_comment(self, repo: \"Repository\", comment_id: int) -> \"IssueComment\":\n for comment in self.issue_comments.values():\n if repo.id == comment._repository and comment.fullDatabaseId == comment_id:\n return comment\n raise RuntimeError(\n f\"unrecognized issue comment {comment_id} in repository {repo.nameWithOwner}\"\n )\n\n def next_id(self) -> GraphQLId:\n r = GraphQLId(str(self._next_id))\n self._next_id += 1\n return r\n\n def next_pull_request_number(self, repo_id: GraphQLId) -> GitHubNumber:\n r = GitHubNumber(self._next_pull_request_number[repo_id])\n self._next_pull_request_number[repo_id] += 1\n return r\n\n def next_issue_comment_full_database_id(self, repo_id: GraphQLId) -> int:\n r = self._next_issue_comment_full_database_id[repo_id]\n self._next_issue_comment_full_database_id[repo_id] += 1\n return r\n\n def push_hook(self, refs: Sequence[str]) -> None:\n # updated_refs = set(refs)\n # for pr in self.pull_requests:\n # # TODO: this assumes only origin repository\n # # if pr.headRefName in updated_refs:\n # # pr.headRef =\n # pass\n pass\n\n def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:\n repo = self.repository(pr_resolved.owner, pr_resolved.repo)\n pr = self.pull_request(repo, GitHubNumber(pr_resolved.number))\n pr.closed = True\n # TODO: model merged too\n\n def __init__(self, upstream_sh: Optional[ghstack.shell.Shell]) -> None:\n self.repositories = {}\n self.pull_requests = {}\n self.issue_comments = {}\n self._next_id = 5000\n self._next_pull_request_number = {}\n self._next_issue_comment_full_database_id = {}\n self.root = Root()\n\n # Populate it with the most important repo ;)\n repo = Repository(\n id=GraphQLId(\"1000\"),\n name=\"pytorch\",\n nameWithOwner=\"pytorch/pytorch\",\n isFork=False,\n defaultBranchRef=None,\n )\n self.repositories[GraphQLId(\"1000\")] = repo\n self._next_pull_request_number[GraphQLId(\"1000\")] = 500\n self._next_issue_comment_full_database_id[GraphQLId(\"1000\")] = 1500\n\n self.upstream_sh = upstream_sh\n if self.upstream_sh is not None:\n # Setup upstream Git repository representing the\n # pytorch/pytorch repository in the directory specified\n # by upstream_sh. This is useful because some GitHub API\n # operations depend on repository state (e.g., what\n # the headRef is at the time a PR is created), so\n # we need this information\n self.upstream_sh.git(\"init\", \"--bare\", \"-b\", \"master\")\n tree = self.upstream_sh.git(\"write-tree\")\n commit = self.upstream_sh.git(\"commit-tree\", tree, input=\"Initial commit\")\n self.upstream_sh.git(\"branch\", \"-f\", \"master\", commit)\n\n # We only update this when a PATCH changes the default\n # branch; hopefully that's fine? In any case, it should\n # work for now since currently we only ever access the name\n # of the default branch rather than other parts of its ref.\n repo.defaultBranchRef = repo._make_ref(self, \"master\")\n\n\n@dataclass\nclass Node:\n id: GraphQLId\n\n\nGraphQLResolveInfo = Any # for now\n\n\ndef github_state(info: GraphQLResolveInfo) -> GitHubState:\n context = info.context\n assert isinstance(context, GitHubState)\n return context\n\n\n@dataclass\nclass Repository(Node):\n name: str\n nameWithOwner: str\n isFork: bool\n defaultBranchRef: Optional[\"Ref\"]\n\n def pullRequest(\n self, info: GraphQLResolveInfo, number: GitHubNumber\n ) -> \"PullRequest\":\n return github_state(info).pull_request(self, number)\n\n def pullRequests(self, info: GraphQLResolveInfo) -> \"PullRequestConnection\":\n return PullRequestConnection(\n nodes=list(\n filter(\n lambda pr: self == pr.repository(info),\n github_state(info).pull_requests.values(),\n )\n )\n )\n\n # TODO: This should take which repository the ref is in\n # This only works if you have upstream_sh\n def _make_ref(self, state: GitHubState, refName: str) -> \"Ref\":\n # TODO: Probably should preserve object identity here when\n # you call this with refName/oid that are the same\n assert state.upstream_sh\n gitObject = GitObject(\n id=state.next_id(),\n # TODO: this upstream_sh hardcode wrong, but ok for now\n # because we only have one repo\n oid=GitObjectID(state.upstream_sh.git(\"rev-parse\", refName)),\n _repository=self.id,\n )\n ref = Ref(\n id=state.next_id(),\n name=refName,\n _repository=self.id,\n target=gitObject,\n )\n return ref\n\n\n@dataclass\nclass GitObject(Node):\n oid: GitObjectID\n _repository: GraphQLId\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass Ref(Node):\n name: str\n _repository: GraphQLId\n target: GitObject\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass PullRequest(Node):\n baseRef: Optional[Ref]\n baseRefName: str\n body: str\n closed: bool\n headRef: Optional[Ref]\n headRefName: str\n # headRepository: Optional[Repository]\n # maintainerCanModify: bool\n number: GitHubNumber\n _repository: GraphQLId # cycle breaker\n # state: PullRequestState\n title: str\n url: str\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass IssueComment(Node):\n body: str\n fullDatabaseId: int\n _repository: GraphQLId\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass PullRequestConnection:\n nodes: List[PullRequest]\n\n\nclass Root:\n def repository(self, info: GraphQLResolveInfo, owner: str, name: str) -> Repository:\n return github_state(info).repository(owner, name)\n\n def node(self, info: GraphQLResolveInfo, id: GraphQLId) -> Node:\n if id in github_state(info).repositories:\n return github_state(info).repositories[id]\n elif id in github_state(info).pull_requests:\n return github_state(info).pull_requests[id]\n elif id in github_state(info).issue_comments:\n return github_state(info).issue_comments[id]\n else:\n raise RuntimeError(\"unknown id {}\".format(id))\n\n\nwith open(\n os.path.join(os.path.dirname(__file__), \"github_schema.graphql\"), encoding=\"utf-8\"\n) as f:\n GITHUB_SCHEMA = graphql.build_schema(f.read())\n\n\n# Ummm. I thought there would be a way to stick these on the objects\n# themselves (in the same way resolvers can be put on resolvers) but\n# after a quick read of default_resolve_type_fn it doesn't look like\n# we ever actually look to value for type of information. This is\n# pretty clunky lol.\ndef set_is_type_of(name: str, cls: Any) -> None:\n # Can't use a type ignore on the next line because fbcode\n # and us don't agree that it's necessary hmm.\n o: Any = GITHUB_SCHEMA.get_type(name)\n o.is_type_of = lambda obj, info: isinstance(obj, cls)\n\n\nset_is_type_of(\"Repository\", Repository)\nset_is_type_of(\"PullRequest\", PullRequest)\nset_is_type_of(\"IssueComment\", IssueComment)\n\n\nclass FakeGitHubEndpoint(ghstack.github.GitHubEndpoint):\n state: GitHubState\n\n def __init__(self, upstream_sh: Optional[ghstack.shell.Shell] = None) -> None:\n self.state = GitHubState(upstream_sh)\n\n def graphql(self, query: str, **kwargs: Any) -> Any:\n r = graphql.graphql_sync(\n schema=GITHUB_SCHEMA,\n source=query,\n root_value=self.state.root,\n context_value=self.state,\n variable_values=kwargs,\n )\n if r.errors:\n # The GraphQL implementation loses all the stack traces!!!\n # D: You can 'recover' them by deleting the\n # 'except Exception as error' from GraphQL-core-next; need\n # to file a bug report\n raise RuntimeError(\n \"GraphQL query failed with errors:\\n\\n{}\".format(\n \"\\n\".join(str(e) for e in r.errors)\n )\n )\n # The top-level object isn't indexable by strings, but\n # everything underneath is, oddly enough\n return {\"data\": r.data}\n\n def push_hook(self, refNames: Sequence[str]) -> None:\n self.state.push_hook(refNames)\n\n def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:\n self.state.notify_merged(pr_resolved)\n\n def _create_pull(\n self, owner: str, name: str, input: CreatePullRequestInput\n ) -> CreatePullRequestPayload:\n state = self.state\n id = state.next_id()\n repo = state.repository(owner, name)\n number = state.next_pull_request_number(repo.id)\n baseRef = None\n headRef = None\n # TODO: When we support forks, this needs rewriting to stop\n # hard coded the repo we opened the pull request on\n if state.upstream_sh:\n baseRef = repo._make_ref(state, input[\"base\"])\n headRef = repo._make_ref(state, input[\"head\"])\n pr = PullRequest(\n id=id,\n _repository=repo.id,\n number=number,\n closed=False,\n url=\"https://github.com/{}/pull/{}\".format(repo.nameWithOwner, number),\n baseRef=baseRef,\n baseRefName=input[\"base\"],\n headRef=headRef,\n headRefName=input[\"head\"],\n title=input[\"title\"],\n body=input[\"body\"],\n )\n # TODO: compute files changed\n state.pull_requests[id] = pr\n # This is only a subset of what the actual REST endpoint\n # returns.\n return {\n \"number\": number,\n }\n\n # NB: This technically does have a payload, but we don't\n # use it so I didn't bother constructing it.\n def _update_pull(\n self, owner: str, name: str, number: GitHubNumber, input: UpdatePullRequestInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n pr = state.pull_request(repo, number)\n # If I say input.get('title') is not None, mypy\n # is unable to infer input['title'] is not None\n if \"title\" in input and input[\"title\"] is not None:\n pr.title = input[\"title\"]\n if \"base\" in input and input[\"base\"] is not None:\n pr.baseRefName = input[\"base\"]\n pr.baseRef = repo._make_ref(state, pr.baseRefName)\n if \"body\" in input and input[\"body\"] is not None:\n pr.body = input[\"body\"]\n\n def _create_issue_comment(\n self, owner: str, name: str, comment_id: int, input: CreateIssueCommentInput\n ) -> CreateIssueCommentPayload:\n state = self.state\n id = state.next_id()\n repo = state.repository(owner, name)\n comment_id = state.next_issue_comment_full_database_id(repo.id)\n comment = IssueComment(\n id=id,\n _repository=repo.id,\n fullDatabaseId=comment_id,\n body=input[\"body\"],\n )\n state.issue_comments[id] = comment\n # This is only a subset of what the actual REST endpoint\n # returns.\n return {\n \"id\": comment_id,\n }\n\n def _update_issue_comment(\n self, owner: str, name: str, comment_id: int, input: UpdateIssueCommentInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n comment = state.issue_comment(repo, comment_id)\n if (r := input.get(\"body\")) is not None:\n comment.body = r\n\n # NB: This may have a payload, but we don't\n # use it so I didn't bother constructing it.\n def _set_default_branch(\n self, owner: str, name: str, input: SetDefaultBranchInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n repo.defaultBranchRef = repo._make_ref(state, input[\"default_branch\"])\n\n def rest(self, method: str, path: str, **kwargs: Any) -> Any:\n if method == \"get\":\n m = re.match(r\"^repos/([^/]+)/([^/]+)/branches/([^/]+)/protection\", path)\n if m:\n # For now, pretend all branches are not protected\n raise ghstack.github.NotFoundError()\n\n elif method == \"post\":\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/pulls$\", path):\n return self._create_pull(\n m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs)\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/([^/]+)/comments\", path):\n return self._create_issue_comment(\n m.group(1),\n m.group(2),\n GitHubNumber(int(m.group(3))),\n cast(CreateIssueCommentInput, kwargs),\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/pulls/([^/]+)/requested_reviewers\", path):\n # Handle adding reviewers - just return success for testing\n return {}\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels\", path):\n # Handle adding labels - just return success for testing\n return {}\n elif method == \"patch\":\n if m := re.match(r\"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$\", path):\n owner, name, number = m.groups()\n if number is not None:\n return self._update_pull(\n owner,\n name,\n GitHubNumber(int(number)),\n cast(UpdatePullRequestInput, kwargs),\n )\n elif \"default_branch\" in kwargs:\n return self._set_default_branch(\n owner, name, cast(SetDefaultBranchInput, kwargs)\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$\", path):\n return self._update_issue_comment(\n m.group(1),\n m.group(2),\n int(m.group(3)),\n cast(UpdateIssueCommentInput, kwargs),\n )\n raise NotImplementedError(\n \"FakeGitHubEndpoint REST {} {} not implemented\".format(method.upper(), path)\n )\n","replacement":"#!/usr/bin/env python3\n\nimport os.path\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any, cast, Dict, List, NewType, Optional, Sequence\n\nimport graphql\nfrom typing_extensions import TypedDict\n\nimport ghstack.diff\nimport ghstack.github\nimport ghstack.shell\n\nGraphQLId = NewType(\"GraphQLId\", str)\nGitHubNumber = NewType(\"GitHubNumber\", int)\nGitObjectID = NewType(\"GitObjectID\", str)\n\n# https://stackoverflow.com/a/55250601\nSetDefaultBranchInput = TypedDict(\n \"SetDefaultBranchInput\",\n {\n \"name\": str,\n \"default_branch\": str,\n },\n)\n\nUpdatePullRequestInput = TypedDict(\n \"UpdatePullRequestInput\",\n {\n \"base\": Optional[str],\n \"title\": Optional[str],\n \"body\": Optional[str],\n },\n)\n\nCreatePullRequestInput = TypedDict(\n \"CreatePullRequestInput\",\n {\n \"base\": str,\n \"head\": str,\n \"title\": str,\n \"body\": str,\n \"maintainer_can_modify\": bool,\n },\n)\n\nCreateIssueCommentInput = TypedDict(\n \"CreateIssueCommentInput\",\n {\"body\": str},\n)\n\nCreateIssueCommentPayload = TypedDict(\n \"CreateIssueCommentPayload\",\n {\n \"id\": int,\n },\n)\n\nUpdateIssueCommentInput = TypedDict(\n \"UpdateIssueCommentInput\",\n {\"body\": str},\n)\n\nCreatePullRequestPayload = TypedDict(\n \"CreatePullRequestPayload\",\n {\n \"number\": int,\n },\n)\n\n\n# The \"database\" for our mock instance\nclass GitHubState:\n repositories: Dict[GraphQLId, \"Repository\"]\n pull_requests: Dict[GraphQLId, \"PullRequest\"]\n # This is very inefficient but whatever\n issue_comments: Dict[GraphQLId, \"IssueComment\"]\n _next_id: int\n # These are indexed by repo id\n _next_pull_request_number: Dict[GraphQLId, int]\n _next_issue_comment_full_database_id: Dict[GraphQLId, int]\n root: \"Root\"\n upstream_sh: Optional[ghstack.shell.Shell]\n\n def repository(self, owner: str, name: str) -> \"Repository\":\n nameWithOwner = \"{}/{}\".format(owner, name)\n for r in self.repositories.values():\n if r.nameWithOwner == nameWithOwner:\n return r\n raise RuntimeError(\"unknown repository {}\".format(nameWithOwner))\n\n def pull_request(self, repo: \"Repository\", number: GitHubNumber) -> \"PullRequest\":\n for pr in self.pull_requests.values():\n if repo.id == pr._repository and pr.number == number:\n return pr\n raise RuntimeError(\n \"unrecognized pull request #{} in repository {}\".format(\n number, repo.nameWithOwner\n )\n )\n\n def issue_comment(self, repo: \"Repository\", comment_id: int) -> \"IssueComment\":\n for comment in self.issue_comments.values():\n if repo.id == comment._repository and comment.fullDatabaseId == comment_id:\n return comment\n raise RuntimeError(\n f\"unrecognized issue comment {comment_id} in repository {repo.nameWithOwner}\"\n )\n\n def next_id(self) -> GraphQLId:\n r = GraphQLId(str(self._next_id))\n self._next_id += 1\n return r\n\n def next_pull_request_number(self, repo_id: GraphQLId) -> GitHubNumber:\n r = GitHubNumber(self._next_pull_request_number[repo_id])\n self._next_pull_request_number[repo_id] += 1\n return r\n\n def next_issue_comment_full_database_id(self, repo_id: GraphQLId) -> int:\n r = self._next_issue_comment_full_database_id[repo_id]\n self._next_issue_comment_full_database_id[repo_id] += 1\n return r\n\n def push_hook(self, refs: Sequence[str]) -> None:\n # updated_refs = set(refs)\n # for pr in self.pull_requests:\n # # TODO: this assumes only origin repository\n # # if pr.headRefName in updated_refs:\n # # pr.headRef =\n # pass\n pass\n\n def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:\n repo = self.repository(pr_resolved.owner, pr_resolved.repo)\n pr = self.pull_request(repo, GitHubNumber(pr_resolved.number))\n pr.closed = True\n # TODO: model merged too\n\n def __init__(self, upstream_sh: Optional[ghstack.shell.Shell]) -> None:\n self.repositories = {}\n self.pull_requests = {}\n self.issue_comments = {}\n self._next_id = 5000\n self._next_pull_request_number = {}\n self._next_issue_comment_full_database_id = {}\n self.root = Root()\n\n # Populate it with the most important repo ;)\n repo = Repository(\n id=GraphQLId(\"1000\"),\n name=\"pytorch\",\n nameWithOwner=\"pytorch/pytorch\",\n isFork=False,\n defaultBranchRef=None,\n )\n self.repositories[GraphQLId(\"1000\")] = repo\n self._next_pull_request_number[GraphQLId(\"1000\")] = 500\n self._next_issue_comment_full_database_id[GraphQLId(\"1000\")] = 1500\n\n self.upstream_sh = upstream_sh\n if self.upstream_sh is not None:\n # Setup upstream Git repository representing the\n # pytorch/pytorch repository in the directory specified\n # by upstream_sh. This is useful because some GitHub API\n # operations depend on repository state (e.g., what\n # the headRef is at the time a PR is created), so\n # we need this information\n self.upstream_sh.git(\"init\", \"--bare\", \"-b\", \"master\")\n tree = self.upstream_sh.git(\"write-tree\")\n commit = self.upstream_sh.git(\"commit-tree\", tree, input=\"Initial commit\")\n self.upstream_sh.git(\"branch\", \"-f\", \"master\", commit)\n\n # We only update this when a PATCH changes the default\n # branch; hopefully that's fine? In any case, it should\n # work for now since currently we only ever access the name\n # of the default branch rather than other parts of its ref.\n repo.defaultBranchRef = repo._make_ref(self, \"master\")\n\n\n@dataclass\nclass Node:\n id: GraphQLId\n\n\nGraphQLResolveInfo = Any # for now\n\n\ndef github_state(info: GraphQLResolveInfo) -> GitHubState:\n context = info.context\n assert isinstance(context, GitHubState)\n return context\n\n\n@dataclass\nclass Repository(Node):\n name: str\n nameWithOwner: str\n isFork: bool\n defaultBranchRef: Optional[\"Ref\"]\n\n def pullRequest(\n self, info: GraphQLResolveInfo, number: GitHubNumber\n ) -> \"PullRequest\":\n return github_state(info).pull_request(self, number)\n\n def pullRequests(self, info: GraphQLResolveInfo) -> \"PullRequestConnection\":\n return PullRequestConnection(\n nodes=list(\n filter(\n lambda pr: self == pr.repository(info),\n github_state(info).pull_requests.values(),\n )\n )\n )\n\n # TODO: This should take which repository the ref is in\n # This only works if you have upstream_sh\n def _make_ref(self, state: GitHubState, refName: str) -> \"Ref\":\n # TODO: Probably should preserve object identity here when\n # you call this with refName/oid that are the same\n assert state.upstream_sh\n gitObject = GitObject(\n id=state.next_id(),\n # TODO: this upstream_sh hardcode wrong, but ok for now\n # because we only have one repo\n oid=GitObjectID(state.upstream_sh.git(\"rev-parse\", refName)),\n _repository=self.id,\n )\n ref = Ref(\n id=state.next_id(),\n name=refName,\n _repository=self.id,\n target=gitObject,\n )\n return ref\n\n\n@dataclass\nclass GitObject(Node):\n oid: GitObjectID\n _repository: GraphQLId\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass Ref(Node):\n name: str\n _repository: GraphQLId\n target: GitObject\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass PullRequest(Node):\n baseRef: Optional[Ref]\n baseRefName: str\n body: str\n closed: bool\n headRef: Optional[Ref]\n headRefName: str\n # headRepository: Optional[Repository]\n # maintainerCanModify: bool\n number: GitHubNumber\n _repository: GraphQLId # cycle breaker\n # state: PullRequestState\n title: str\n url: str\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass IssueComment(Node):\n body: str\n fullDatabaseId: int\n _repository: GraphQLId\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass PullRequestConnection:\n nodes: List[PullRequest]\n\n\nclass Root:\n def repository(self, info: GraphQLResolveInfo, owner: str, name: str) -> Repository:\n return github_state(info).repository(owner, name)\n\n def node(self, info: GraphQLResolveInfo, id: GraphQLId) -> Node:\n if id in github_state(info).repositories:\n return github_state(info).repositories[id]\n elif id in github_state(info).pull_requests:\n return github_state(info).pull_requests[id]\n elif id in github_state(info).issue_comments:\n return github_state(info).issue_comments[id]\n else:\n raise RuntimeError(\"unknown id {}\".format(id))\n\n\nwith open(\n os.path.join(os.path.dirname(__file__), \"github_schema.graphql\"), encoding=\"utf-8\"\n) as f:\n GITHUB_SCHEMA = graphql.build_schema(f.read())\n\n\n# Ummm. I thought there would be a way to stick these on the objects\n# themselves (in the same way resolvers can be put on resolvers) but\n# after a quick read of default_resolve_type_fn it doesn't look like\n# we ever actually look to value for type of information. This is\n# pretty clunky lol.\ndef set_is_type_of(name: str, cls: Any) -> None:\n # Can't use a type ignore on the next line because fbcode\n # and us don't agree that it's necessary hmm.\n o: Any = GITHUB_SCHEMA.get_type(name)\n o.is_type_of = lambda obj, info: isinstance(obj, cls)\n\n\nset_is_type_of(\"Repository\", Repository)\nset_is_type_of(\"PullRequest\", PullRequest)\nset_is_type_of(\"IssueComment\", IssueComment)\n\n\nclass FakeGitHubEndpoint(ghstack.github.GitHubEndpoint):\n state: GitHubState\n\n def __init__(self, upstream_sh: Optional[ghstack.shell.Shell] = None) -> None:\n self.state = GitHubState(upstream_sh)\n\n def graphql(self, query: str, **kwargs: Any) -> Any:\n r = graphql.graphql_sync(\n schema=GITHUB_SCHEMA,\n source=query,\n root_value=self.state.root,\n context_value=self.state,\n variable_values=kwargs,\n )\n if r.errors:\n # The GraphQL implementation loses all the stack traces!!!\n # D: You can 'recover' them by deleting the\n # 'except Exception as error' from GraphQL-core-next; need\n # to file a bug report\n raise RuntimeError(\n \"GraphQL query failed with errors:\\n\\n{}\".format(\n \"\\n\".join(str(e) for e in r.errors)\n )\n )\n # The top-level object isn't indexable by strings, but\n # everything underneath is, oddly enough\n return {\"data\": r.data}\n\n def push_hook(self, refNames: Sequence[str]) -> None:\n self.state.push_hook(refNames)\n\n def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:\n self.state.notify_merged(pr_resolved)\n\n def _create_pull(\n self, owner: str, name: str, input: CreatePullRequestInput\n ) -> CreatePullRequestPayload:\n state = self.state\n id = state.next_id()\n repo = state.repository(owner, name)\n number = state.next_pull_request_number(repo.id)\n baseRef = None\n headRef = None\n # TODO: When we support forks, this needs rewriting to stop\n # hard coded the repo we opened the pull request on\n if state.upstream_sh:\n baseRef = repo._make_ref(state, input[\"base\"])\n headRef = repo._make_ref(state, input[\"head\"])\n pr = PullRequest(\n id=id,\n _repository=repo.id,\n number=number,\n closed=False,\n url=\"https://github.com/{}/pull/{}\".format(repo.nameWithOwner, number),\n baseRef=baseRef,\n baseRefName=input[\"base\"],\n headRef=headRef,\n headRefName=input[\"head\"],\n title=input[\"title\"],\n body=input[\"body\"],\n )\n # TODO: compute files changed\n state.pull_requests[id] = pr\n # This is only a subset of what the actual REST endpoint\n # returns.\n return {\n \"number\": number,\n }\n\n # NB: This technically does have a payload, but we don't\n # use it so I didn't bother constructing it.\n def _update_pull(\n self, owner: str, name: str, number: GitHubNumber, input: UpdatePullRequestInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n pr = state.pull_request(repo, number)\n # If I say input.get('title') is not None, mypy\n # is unable to infer input['title'] is not None\n if \"title\" in input and input[\"title\"] is not None:\n pr.title = input[\"title\"]\n if \"base\" in input and input[\"base\"] is not None:\n pr.baseRefName = input[\"base\"]\n pr.baseRef = repo._make_ref(state, pr.baseRefName)\n if \"body\" in input and input[\"body\"] is not None:\n pr.body = input[\"body\"]\n\n def _create_issue_comment(\n self, owner: str, name: str, comment_id: int, input: CreateIssueCommentInput\n ) -> CreateIssueCommentPayload:\n state = self.state\n id = state.next_id()\n repo = state.repository(owner, name)\n comment_id = state.next_issue_comment_full_database_id(repo.id)\n comment = IssueComment(\n id=id,\n _repository=repo.id,\n fullDatabaseId=comment_id,\n body=input[\"body\"],\n )\n state.issue_comments[id] = comment\n # This is only a subset of what the actual REST endpoint\n # returns.\n return {\n \"id\": comment_id,\n }\n\n def _update_issue_comment(\n self, owner: str, name: str, comment_id: int, input: UpdateIssueCommentInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n comment = state.issue_comment(repo, comment_id)\n if (r := input.get(\"body\")) is not None:\n comment.body = r\n\n # NB: This may have a payload, but we don't\n # use it so I didn't bother constructing it.\n def _set_default_branch(\n self, owner: str, name: str, input: SetDefaultBranchInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n repo.defaultBranchRef = repo._make_ref(state, input[\"default_branch\"])\n\n def rest(self, method: str, path: str, **kwargs: Any) -> Any:\n if method == \"get\":\n m = re.match(r\"^repos/([^/]+)/([^/]+)/branches/([^/]+)/protection\", path)\n if m:\n # For now, pretend all branches are not protected\n raise ghstack.github.NotFoundError()\n\n elif method == \"post\":\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/pulls$\", path):\n return self._create_pull(\n m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs)\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/([^/]+)/comments\", path):\n return self._create_issue_comment(\n m.group(1),\n m.group(2),\n GitHubNumber(int(m.group(3))),\n cast(CreateIssueCommentInput, kwargs),\n )\n if m := re.match(\n r\"^repos/([^/]+)/([^/]+)/pulls/([^/]+)/requested_reviewers\", path\n ):\n # Handle adding reviewers - just return success for testing\n return {}\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels\", path):\n # Handle adding labels - just return success for testing\n return {}\n elif method == \"patch\":\n if m := re.match(r\"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$\", path):\n owner, name, number = m.groups()\n if number is not None:\n return self._update_pull(\n owner,\n name,\n GitHubNumber(int(number)),\n cast(UpdatePullRequestInput, kwargs),\n )\n elif \"default_branch\" in kwargs:\n return self._set_default_branch(\n owner, name, cast(SetDefaultBranchInput, kwargs)\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$\", path):\n return self._update_issue_comment(\n m.group(1),\n m.group(2),\n int(m.group(3)),\n cast(UpdateIssueCommentInput, kwargs),\n )\n raise NotImplementedError(\n \"FakeGitHubEndpoint REST {} {} not implemented\".format(method.upper(), path)\n )\n"} diff --git a/lint.json b/lint.json new file mode 100644 index 0000000..59ab1d7 --- /dev/null +++ b/lint.json @@ -0,0 +1 @@ +{"path":"/home/oulgen/dev/ghstack/src/ghstack/submit.py","line":1458,"char":37,"code":"FLAKE8","severity":"warning","name":"E741","description":"ambiguous variable name 'l'\nSee https://www.flake8rules.com/rules/E741.html"} diff --git a/src/ghstack/cli.py b/src/ghstack/cli.py index 82c8cbc..1a7f4f7 100644 --- a/src/ghstack/cli.py +++ b/src/ghstack/cli.py @@ -263,6 +263,18 @@ def status(pull_request: str) -> None: "With --no-stack, we support only non-range identifiers, and will submit each commit " "listed in the command line.", ) +@click.option( + "--reviewer", + default=None, + help="Comma-separated list of GitHub usernames to add as reviewers to new PRs " + "(overrides .ghstackrc setting)", +) +@click.option( + "--label", + default=None, + help="Comma-separated list of labels to add to new PRs " + "(overrides .ghstackrc setting)", +) @click.option( "--direct/--no-direct", "direct_opt", @@ -286,6 +298,8 @@ def submit( base: Optional[str], revs: Tuple[str, ...], stack: bool, + reviewer: Optional[str], + label: Optional[str], ) -> None: """ Submit or update a PR stack @@ -307,6 +321,8 @@ def submit( revs=revs, stack=stack, direct_opt=direct_opt, + reviewer=reviewer if reviewer is not None else config.reviewer, + label=label if label is not None else config.label, ) diff --git a/src/ghstack/config.py b/src/ghstack/config.py index ff78fe9..735ca10 100644 --- a/src/ghstack/config.py +++ b/src/ghstack/config.py @@ -107,6 +107,10 @@ def get_gh_cli_credentials( ("github_url", str), # Name of the upstream remote ("remote_name", str), + # Default reviewers to add to new pull requests (comma-separated usernames) + ("reviewer", Optional[str]), + # Default labels to add to new pull requests (comma-separated labels) + ("label", Optional[str]), ], ) @@ -287,6 +291,16 @@ def read_config( else: remote_name = "origin" + if config.has_option("ghstack", "reviewer"): + reviewer = config.get("ghstack", "reviewer") + else: + reviewer = None + + if config.has_option("ghstack", "label"): + label = config.get("ghstack", "label") + else: + label = None + if write_back: with open(config_path, "w") as f: config.write(f) @@ -302,6 +316,8 @@ def read_config( default_project_dir=default_project_dir, github_url=github_url, remote_name=remote_name, + reviewer=reviewer, + label=label, ) logging.debug(f"conf = {conf}") return conf diff --git a/src/ghstack/github_fake.py b/src/ghstack/github_fake.py index d82e969..2b25d8e 100644 --- a/src/ghstack/github_fake.py +++ b/src/ghstack/github_fake.py @@ -473,6 +473,14 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: GitHubNumber(int(m.group(3))), cast(CreateIssueCommentInput, kwargs), ) + if m := re.match( + r"^repos/([^/]+)/([^/]+)/pulls/([^/]+)/requested_reviewers", path + ): + # Handle adding reviewers - just return success for testing + return {} + if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels", path): + # Handle adding labels - just return success for testing + return {} elif method == "patch": if m := re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path): owner, name, number = m.groups() diff --git a/src/ghstack/submit.py b/src/ghstack/submit.py index f7910c5..b9568b9 100644 --- a/src/ghstack/submit.py +++ b/src/ghstack/submit.py @@ -343,6 +343,12 @@ class Submitter: # merged. If None, infer whether or not the PR should be direct or not. direct_opt: Optional[bool] = None + # Default reviewers to add to new pull requests (comma-separated usernames) + reviewer: Optional[str] = None + + # Default labels to add to new pull requests (comma-separated labels) + label: Optional[str] = None + # ~~~~~~~~~~~~~~~~~~~~~~~~ # Computed in post init @@ -1434,6 +1440,32 @@ def _create_pull_request( ) comment_id = rc["id"] + # Add reviewers if specified + if self.reviewer: + reviewers = [r.strip() for r in self.reviewer.split(",") if r.strip()] + if reviewers: + try: + self.github.post( + f"repos/{self.repo_owner}/{self.repo_name}/pulls/{number}/requested_reviewers", + reviewers=reviewers, + ) + logging.info(f"Added reviewers: {', '.join(reviewers)}") + except Exception as e: + logging.warning(f"Failed to add reviewers: {e}") + + # Add labels if specified + if self.label: + labels = [label.strip() for label in self.label.split(",") if label.strip()] + if labels: + try: + self.github.post( + f"repos/{self.repo_owner}/{self.repo_name}/issues/{number}/labels", + labels=labels, + ) + logging.info(f"Added labels: {', '.join(labels)}") + except Exception as e: + logging.warning(f"Failed to add labels: {e}") + logging.info("Opened PR #{}".format(number)) pull_request_resolved = ghstack.diff.PullRequestResolved( diff --git a/src/ghstack/test_prelude.py b/src/ghstack/test_prelude.py index a46d5a9..0861db8 100644 --- a/src/ghstack/test_prelude.py +++ b/src/ghstack/test_prelude.py @@ -194,6 +194,8 @@ def gh_submit( base: Optional[str] = None, revs: Sequence[str] = (), stack: bool = True, + reviewer: Optional[str] = None, + label: Optional[str] = None, ) -> List[ghstack.submit.DiffMeta]: self = CTX r = ghstack.submit.main( @@ -214,6 +216,8 @@ def gh_submit( revs=revs, stack=stack, check_invariants=True, + reviewer=reviewer, + label=label, ) self.check_global_github_invariants(self.direct) return r diff --git a/test/submit/cli_reviewer_and_label.py.test b/test/submit/cli_reviewer_and_label.py.test new file mode 100644 index 0000000..50964a7 --- /dev/null +++ b/test/submit/cli_reviewer_and_label.py.test @@ -0,0 +1,74 @@ +from ghstack.test_prelude import * + +init_test() + +# Create first commit with one set of reviewers/labels +commit("A") +(A,) = gh_submit("Initial commit", reviewer="reviewer1", label="bug") + +# Create second commit with different reviewers/labels +commit("B") +(A2, B) = gh_submit("Add B", reviewer="reviewer2,reviewer3", label="enhancement,priority-high") + +if is_direct(): + assert_github_state( + """\ + [O] #500 Commit A (gh/ezyang/1/head -> master) + + This is commit A + + * 9cb2ede Initial commit + + [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/1/head) + + This is commit B + + * d012f5c Add B + + Repository state: + + * d012f5c (gh/ezyang/2/next, gh/ezyang/2/head) + | Add B + * 9cb2ede (gh/ezyang/1/next, gh/ezyang/1/head) + | Initial commit + * dc8bfe4 (HEAD -> master) + Initial commit + """ + ) +else: + assert_github_state( + """\ + [O] #500 Commit A (gh/ezyang/1/head -> gh/ezyang/1/base) + + Stack: + * #501 + * __->__ #500 + + This is commit A + + * 62602bd Initial commit + + [O] #501 Commit B (gh/ezyang/2/head -> gh/ezyang/2/base) + + Stack: + * __->__ #501 + * #500 + + This is commit B + + * afd9aaf Add B + + Repository state: + + * afd9aaf (gh/ezyang/2/head) + | Add B + * 5d2de57 (gh/ezyang/2/base) + | Add B (base update) + | * 62602bd (gh/ezyang/1/head) + | | Initial commit + | * 5956f18 (gh/ezyang/1/base) + |/ Initial commit (base update) + * dc8bfe4 (HEAD -> master) + Initial commit + """, + ) diff --git a/test/submit/reviewer_and_label.py.test b/test/submit/reviewer_and_label.py.test new file mode 100644 index 0000000..392de7a --- /dev/null +++ b/test/submit/reviewer_and_label.py.test @@ -0,0 +1,46 @@ +from ghstack.test_prelude import * + +init_test() + +commit("A") +(A,) = gh_submit("Initial commit", reviewer="reviewer1,reviewer2", label="bug,enhancement") + +if is_direct(): + assert_github_state( + """\ + [O] #500 Commit A (gh/ezyang/1/head -> master) + + This is commit A + + * 9cb2ede Initial commit + + Repository state: + + * 9cb2ede (gh/ezyang/1/next, gh/ezyang/1/head) + | Initial commit + * dc8bfe4 (HEAD -> master) + Initial commit + """ + ) +else: + assert_github_state( + """\ + [O] #500 Commit A (gh/ezyang/1/head -> gh/ezyang/1/base) + + Stack: + * __->__ #500 + + This is commit A + + * 62602bd Initial commit + + Repository state: + + * 62602bd (gh/ezyang/1/head) + | Initial commit + * 5956f18 (gh/ezyang/1/base) + | Initial commit (base update) + * dc8bfe4 (HEAD -> master) + Initial commit + """, + ) From 987c2bcadb6d8d181476c177119ad61d4cf0aab5 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Mon, 22 Dec 2025 22:23:45 -0800 Subject: [PATCH 2/3] del bad files --- lint.jso | 2 -- lint.json | 1 - 2 files changed, 3 deletions(-) delete mode 100644 lint.jso delete mode 100644 lint.json diff --git a/lint.jso b/lint.jso deleted file mode 100644 index e23a0f9..0000000 --- a/lint.jso +++ /dev/null @@ -1,2 +0,0 @@ -{"path":"/home/oulgen/dev/ghstack/src/ghstack/submit.py","line":1458,"char":37,"code":"FLAKE8","severity":"warning","name":"E741","description":"ambiguous variable name 'l'\nSee https://www.flake8rules.com/rules/E741.html"} -{"path":"/home/oulgen/dev/ghstack/src/ghstack/github_fake.py","line":null,"char":null,"code":"UFMT","severity":"warning","name":"format","description":"Run `lintrunner -a` to apply this patch.","original":"#!/usr/bin/env python3\n\nimport os.path\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any, cast, Dict, List, NewType, Optional, Sequence\n\nimport graphql\nfrom typing_extensions import TypedDict\n\nimport ghstack.diff\nimport ghstack.github\nimport ghstack.shell\n\nGraphQLId = NewType(\"GraphQLId\", str)\nGitHubNumber = NewType(\"GitHubNumber\", int)\nGitObjectID = NewType(\"GitObjectID\", str)\n\n# https://stackoverflow.com/a/55250601\nSetDefaultBranchInput = TypedDict(\n \"SetDefaultBranchInput\",\n {\n \"name\": str,\n \"default_branch\": str,\n },\n)\n\nUpdatePullRequestInput = TypedDict(\n \"UpdatePullRequestInput\",\n {\n \"base\": Optional[str],\n \"title\": Optional[str],\n \"body\": Optional[str],\n },\n)\n\nCreatePullRequestInput = TypedDict(\n \"CreatePullRequestInput\",\n {\n \"base\": str,\n \"head\": str,\n \"title\": str,\n \"body\": str,\n \"maintainer_can_modify\": bool,\n },\n)\n\nCreateIssueCommentInput = TypedDict(\n \"CreateIssueCommentInput\",\n {\"body\": str},\n)\n\nCreateIssueCommentPayload = TypedDict(\n \"CreateIssueCommentPayload\",\n {\n \"id\": int,\n },\n)\n\nUpdateIssueCommentInput = TypedDict(\n \"UpdateIssueCommentInput\",\n {\"body\": str},\n)\n\nCreatePullRequestPayload = TypedDict(\n \"CreatePullRequestPayload\",\n {\n \"number\": int,\n },\n)\n\n\n# The \"database\" for our mock instance\nclass GitHubState:\n repositories: Dict[GraphQLId, \"Repository\"]\n pull_requests: Dict[GraphQLId, \"PullRequest\"]\n # This is very inefficient but whatever\n issue_comments: Dict[GraphQLId, \"IssueComment\"]\n _next_id: int\n # These are indexed by repo id\n _next_pull_request_number: Dict[GraphQLId, int]\n _next_issue_comment_full_database_id: Dict[GraphQLId, int]\n root: \"Root\"\n upstream_sh: Optional[ghstack.shell.Shell]\n\n def repository(self, owner: str, name: str) -> \"Repository\":\n nameWithOwner = \"{}/{}\".format(owner, name)\n for r in self.repositories.values():\n if r.nameWithOwner == nameWithOwner:\n return r\n raise RuntimeError(\"unknown repository {}\".format(nameWithOwner))\n\n def pull_request(self, repo: \"Repository\", number: GitHubNumber) -> \"PullRequest\":\n for pr in self.pull_requests.values():\n if repo.id == pr._repository and pr.number == number:\n return pr\n raise RuntimeError(\n \"unrecognized pull request #{} in repository {}\".format(\n number, repo.nameWithOwner\n )\n )\n\n def issue_comment(self, repo: \"Repository\", comment_id: int) -> \"IssueComment\":\n for comment in self.issue_comments.values():\n if repo.id == comment._repository and comment.fullDatabaseId == comment_id:\n return comment\n raise RuntimeError(\n f\"unrecognized issue comment {comment_id} in repository {repo.nameWithOwner}\"\n )\n\n def next_id(self) -> GraphQLId:\n r = GraphQLId(str(self._next_id))\n self._next_id += 1\n return r\n\n def next_pull_request_number(self, repo_id: GraphQLId) -> GitHubNumber:\n r = GitHubNumber(self._next_pull_request_number[repo_id])\n self._next_pull_request_number[repo_id] += 1\n return r\n\n def next_issue_comment_full_database_id(self, repo_id: GraphQLId) -> int:\n r = self._next_issue_comment_full_database_id[repo_id]\n self._next_issue_comment_full_database_id[repo_id] += 1\n return r\n\n def push_hook(self, refs: Sequence[str]) -> None:\n # updated_refs = set(refs)\n # for pr in self.pull_requests:\n # # TODO: this assumes only origin repository\n # # if pr.headRefName in updated_refs:\n # # pr.headRef =\n # pass\n pass\n\n def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:\n repo = self.repository(pr_resolved.owner, pr_resolved.repo)\n pr = self.pull_request(repo, GitHubNumber(pr_resolved.number))\n pr.closed = True\n # TODO: model merged too\n\n def __init__(self, upstream_sh: Optional[ghstack.shell.Shell]) -> None:\n self.repositories = {}\n self.pull_requests = {}\n self.issue_comments = {}\n self._next_id = 5000\n self._next_pull_request_number = {}\n self._next_issue_comment_full_database_id = {}\n self.root = Root()\n\n # Populate it with the most important repo ;)\n repo = Repository(\n id=GraphQLId(\"1000\"),\n name=\"pytorch\",\n nameWithOwner=\"pytorch/pytorch\",\n isFork=False,\n defaultBranchRef=None,\n )\n self.repositories[GraphQLId(\"1000\")] = repo\n self._next_pull_request_number[GraphQLId(\"1000\")] = 500\n self._next_issue_comment_full_database_id[GraphQLId(\"1000\")] = 1500\n\n self.upstream_sh = upstream_sh\n if self.upstream_sh is not None:\n # Setup upstream Git repository representing the\n # pytorch/pytorch repository in the directory specified\n # by upstream_sh. This is useful because some GitHub API\n # operations depend on repository state (e.g., what\n # the headRef is at the time a PR is created), so\n # we need this information\n self.upstream_sh.git(\"init\", \"--bare\", \"-b\", \"master\")\n tree = self.upstream_sh.git(\"write-tree\")\n commit = self.upstream_sh.git(\"commit-tree\", tree, input=\"Initial commit\")\n self.upstream_sh.git(\"branch\", \"-f\", \"master\", commit)\n\n # We only update this when a PATCH changes the default\n # branch; hopefully that's fine? In any case, it should\n # work for now since currently we only ever access the name\n # of the default branch rather than other parts of its ref.\n repo.defaultBranchRef = repo._make_ref(self, \"master\")\n\n\n@dataclass\nclass Node:\n id: GraphQLId\n\n\nGraphQLResolveInfo = Any # for now\n\n\ndef github_state(info: GraphQLResolveInfo) -> GitHubState:\n context = info.context\n assert isinstance(context, GitHubState)\n return context\n\n\n@dataclass\nclass Repository(Node):\n name: str\n nameWithOwner: str\n isFork: bool\n defaultBranchRef: Optional[\"Ref\"]\n\n def pullRequest(\n self, info: GraphQLResolveInfo, number: GitHubNumber\n ) -> \"PullRequest\":\n return github_state(info).pull_request(self, number)\n\n def pullRequests(self, info: GraphQLResolveInfo) -> \"PullRequestConnection\":\n return PullRequestConnection(\n nodes=list(\n filter(\n lambda pr: self == pr.repository(info),\n github_state(info).pull_requests.values(),\n )\n )\n )\n\n # TODO: This should take which repository the ref is in\n # This only works if you have upstream_sh\n def _make_ref(self, state: GitHubState, refName: str) -> \"Ref\":\n # TODO: Probably should preserve object identity here when\n # you call this with refName/oid that are the same\n assert state.upstream_sh\n gitObject = GitObject(\n id=state.next_id(),\n # TODO: this upstream_sh hardcode wrong, but ok for now\n # because we only have one repo\n oid=GitObjectID(state.upstream_sh.git(\"rev-parse\", refName)),\n _repository=self.id,\n )\n ref = Ref(\n id=state.next_id(),\n name=refName,\n _repository=self.id,\n target=gitObject,\n )\n return ref\n\n\n@dataclass\nclass GitObject(Node):\n oid: GitObjectID\n _repository: GraphQLId\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass Ref(Node):\n name: str\n _repository: GraphQLId\n target: GitObject\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass PullRequest(Node):\n baseRef: Optional[Ref]\n baseRefName: str\n body: str\n closed: bool\n headRef: Optional[Ref]\n headRefName: str\n # headRepository: Optional[Repository]\n # maintainerCanModify: bool\n number: GitHubNumber\n _repository: GraphQLId # cycle breaker\n # state: PullRequestState\n title: str\n url: str\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass IssueComment(Node):\n body: str\n fullDatabaseId: int\n _repository: GraphQLId\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass PullRequestConnection:\n nodes: List[PullRequest]\n\n\nclass Root:\n def repository(self, info: GraphQLResolveInfo, owner: str, name: str) -> Repository:\n return github_state(info).repository(owner, name)\n\n def node(self, info: GraphQLResolveInfo, id: GraphQLId) -> Node:\n if id in github_state(info).repositories:\n return github_state(info).repositories[id]\n elif id in github_state(info).pull_requests:\n return github_state(info).pull_requests[id]\n elif id in github_state(info).issue_comments:\n return github_state(info).issue_comments[id]\n else:\n raise RuntimeError(\"unknown id {}\".format(id))\n\n\nwith open(\n os.path.join(os.path.dirname(__file__), \"github_schema.graphql\"), encoding=\"utf-8\"\n) as f:\n GITHUB_SCHEMA = graphql.build_schema(f.read())\n\n\n# Ummm. I thought there would be a way to stick these on the objects\n# themselves (in the same way resolvers can be put on resolvers) but\n# after a quick read of default_resolve_type_fn it doesn't look like\n# we ever actually look to value for type of information. This is\n# pretty clunky lol.\ndef set_is_type_of(name: str, cls: Any) -> None:\n # Can't use a type ignore on the next line because fbcode\n # and us don't agree that it's necessary hmm.\n o: Any = GITHUB_SCHEMA.get_type(name)\n o.is_type_of = lambda obj, info: isinstance(obj, cls)\n\n\nset_is_type_of(\"Repository\", Repository)\nset_is_type_of(\"PullRequest\", PullRequest)\nset_is_type_of(\"IssueComment\", IssueComment)\n\n\nclass FakeGitHubEndpoint(ghstack.github.GitHubEndpoint):\n state: GitHubState\n\n def __init__(self, upstream_sh: Optional[ghstack.shell.Shell] = None) -> None:\n self.state = GitHubState(upstream_sh)\n\n def graphql(self, query: str, **kwargs: Any) -> Any:\n r = graphql.graphql_sync(\n schema=GITHUB_SCHEMA,\n source=query,\n root_value=self.state.root,\n context_value=self.state,\n variable_values=kwargs,\n )\n if r.errors:\n # The GraphQL implementation loses all the stack traces!!!\n # D: You can 'recover' them by deleting the\n # 'except Exception as error' from GraphQL-core-next; need\n # to file a bug report\n raise RuntimeError(\n \"GraphQL query failed with errors:\\n\\n{}\".format(\n \"\\n\".join(str(e) for e in r.errors)\n )\n )\n # The top-level object isn't indexable by strings, but\n # everything underneath is, oddly enough\n return {\"data\": r.data}\n\n def push_hook(self, refNames: Sequence[str]) -> None:\n self.state.push_hook(refNames)\n\n def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:\n self.state.notify_merged(pr_resolved)\n\n def _create_pull(\n self, owner: str, name: str, input: CreatePullRequestInput\n ) -> CreatePullRequestPayload:\n state = self.state\n id = state.next_id()\n repo = state.repository(owner, name)\n number = state.next_pull_request_number(repo.id)\n baseRef = None\n headRef = None\n # TODO: When we support forks, this needs rewriting to stop\n # hard coded the repo we opened the pull request on\n if state.upstream_sh:\n baseRef = repo._make_ref(state, input[\"base\"])\n headRef = repo._make_ref(state, input[\"head\"])\n pr = PullRequest(\n id=id,\n _repository=repo.id,\n number=number,\n closed=False,\n url=\"https://github.com/{}/pull/{}\".format(repo.nameWithOwner, number),\n baseRef=baseRef,\n baseRefName=input[\"base\"],\n headRef=headRef,\n headRefName=input[\"head\"],\n title=input[\"title\"],\n body=input[\"body\"],\n )\n # TODO: compute files changed\n state.pull_requests[id] = pr\n # This is only a subset of what the actual REST endpoint\n # returns.\n return {\n \"number\": number,\n }\n\n # NB: This technically does have a payload, but we don't\n # use it so I didn't bother constructing it.\n def _update_pull(\n self, owner: str, name: str, number: GitHubNumber, input: UpdatePullRequestInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n pr = state.pull_request(repo, number)\n # If I say input.get('title') is not None, mypy\n # is unable to infer input['title'] is not None\n if \"title\" in input and input[\"title\"] is not None:\n pr.title = input[\"title\"]\n if \"base\" in input and input[\"base\"] is not None:\n pr.baseRefName = input[\"base\"]\n pr.baseRef = repo._make_ref(state, pr.baseRefName)\n if \"body\" in input and input[\"body\"] is not None:\n pr.body = input[\"body\"]\n\n def _create_issue_comment(\n self, owner: str, name: str, comment_id: int, input: CreateIssueCommentInput\n ) -> CreateIssueCommentPayload:\n state = self.state\n id = state.next_id()\n repo = state.repository(owner, name)\n comment_id = state.next_issue_comment_full_database_id(repo.id)\n comment = IssueComment(\n id=id,\n _repository=repo.id,\n fullDatabaseId=comment_id,\n body=input[\"body\"],\n )\n state.issue_comments[id] = comment\n # This is only a subset of what the actual REST endpoint\n # returns.\n return {\n \"id\": comment_id,\n }\n\n def _update_issue_comment(\n self, owner: str, name: str, comment_id: int, input: UpdateIssueCommentInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n comment = state.issue_comment(repo, comment_id)\n if (r := input.get(\"body\")) is not None:\n comment.body = r\n\n # NB: This may have a payload, but we don't\n # use it so I didn't bother constructing it.\n def _set_default_branch(\n self, owner: str, name: str, input: SetDefaultBranchInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n repo.defaultBranchRef = repo._make_ref(state, input[\"default_branch\"])\n\n def rest(self, method: str, path: str, **kwargs: Any) -> Any:\n if method == \"get\":\n m = re.match(r\"^repos/([^/]+)/([^/]+)/branches/([^/]+)/protection\", path)\n if m:\n # For now, pretend all branches are not protected\n raise ghstack.github.NotFoundError()\n\n elif method == \"post\":\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/pulls$\", path):\n return self._create_pull(\n m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs)\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/([^/]+)/comments\", path):\n return self._create_issue_comment(\n m.group(1),\n m.group(2),\n GitHubNumber(int(m.group(3))),\n cast(CreateIssueCommentInput, kwargs),\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/pulls/([^/]+)/requested_reviewers\", path):\n # Handle adding reviewers - just return success for testing\n return {}\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels\", path):\n # Handle adding labels - just return success for testing\n return {}\n elif method == \"patch\":\n if m := re.match(r\"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$\", path):\n owner, name, number = m.groups()\n if number is not None:\n return self._update_pull(\n owner,\n name,\n GitHubNumber(int(number)),\n cast(UpdatePullRequestInput, kwargs),\n )\n elif \"default_branch\" in kwargs:\n return self._set_default_branch(\n owner, name, cast(SetDefaultBranchInput, kwargs)\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$\", path):\n return self._update_issue_comment(\n m.group(1),\n m.group(2),\n int(m.group(3)),\n cast(UpdateIssueCommentInput, kwargs),\n )\n raise NotImplementedError(\n \"FakeGitHubEndpoint REST {} {} not implemented\".format(method.upper(), path)\n )\n","replacement":"#!/usr/bin/env python3\n\nimport os.path\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any, cast, Dict, List, NewType, Optional, Sequence\n\nimport graphql\nfrom typing_extensions import TypedDict\n\nimport ghstack.diff\nimport ghstack.github\nimport ghstack.shell\n\nGraphQLId = NewType(\"GraphQLId\", str)\nGitHubNumber = NewType(\"GitHubNumber\", int)\nGitObjectID = NewType(\"GitObjectID\", str)\n\n# https://stackoverflow.com/a/55250601\nSetDefaultBranchInput = TypedDict(\n \"SetDefaultBranchInput\",\n {\n \"name\": str,\n \"default_branch\": str,\n },\n)\n\nUpdatePullRequestInput = TypedDict(\n \"UpdatePullRequestInput\",\n {\n \"base\": Optional[str],\n \"title\": Optional[str],\n \"body\": Optional[str],\n },\n)\n\nCreatePullRequestInput = TypedDict(\n \"CreatePullRequestInput\",\n {\n \"base\": str,\n \"head\": str,\n \"title\": str,\n \"body\": str,\n \"maintainer_can_modify\": bool,\n },\n)\n\nCreateIssueCommentInput = TypedDict(\n \"CreateIssueCommentInput\",\n {\"body\": str},\n)\n\nCreateIssueCommentPayload = TypedDict(\n \"CreateIssueCommentPayload\",\n {\n \"id\": int,\n },\n)\n\nUpdateIssueCommentInput = TypedDict(\n \"UpdateIssueCommentInput\",\n {\"body\": str},\n)\n\nCreatePullRequestPayload = TypedDict(\n \"CreatePullRequestPayload\",\n {\n \"number\": int,\n },\n)\n\n\n# The \"database\" for our mock instance\nclass GitHubState:\n repositories: Dict[GraphQLId, \"Repository\"]\n pull_requests: Dict[GraphQLId, \"PullRequest\"]\n # This is very inefficient but whatever\n issue_comments: Dict[GraphQLId, \"IssueComment\"]\n _next_id: int\n # These are indexed by repo id\n _next_pull_request_number: Dict[GraphQLId, int]\n _next_issue_comment_full_database_id: Dict[GraphQLId, int]\n root: \"Root\"\n upstream_sh: Optional[ghstack.shell.Shell]\n\n def repository(self, owner: str, name: str) -> \"Repository\":\n nameWithOwner = \"{}/{}\".format(owner, name)\n for r in self.repositories.values():\n if r.nameWithOwner == nameWithOwner:\n return r\n raise RuntimeError(\"unknown repository {}\".format(nameWithOwner))\n\n def pull_request(self, repo: \"Repository\", number: GitHubNumber) -> \"PullRequest\":\n for pr in self.pull_requests.values():\n if repo.id == pr._repository and pr.number == number:\n return pr\n raise RuntimeError(\n \"unrecognized pull request #{} in repository {}\".format(\n number, repo.nameWithOwner\n )\n )\n\n def issue_comment(self, repo: \"Repository\", comment_id: int) -> \"IssueComment\":\n for comment in self.issue_comments.values():\n if repo.id == comment._repository and comment.fullDatabaseId == comment_id:\n return comment\n raise RuntimeError(\n f\"unrecognized issue comment {comment_id} in repository {repo.nameWithOwner}\"\n )\n\n def next_id(self) -> GraphQLId:\n r = GraphQLId(str(self._next_id))\n self._next_id += 1\n return r\n\n def next_pull_request_number(self, repo_id: GraphQLId) -> GitHubNumber:\n r = GitHubNumber(self._next_pull_request_number[repo_id])\n self._next_pull_request_number[repo_id] += 1\n return r\n\n def next_issue_comment_full_database_id(self, repo_id: GraphQLId) -> int:\n r = self._next_issue_comment_full_database_id[repo_id]\n self._next_issue_comment_full_database_id[repo_id] += 1\n return r\n\n def push_hook(self, refs: Sequence[str]) -> None:\n # updated_refs = set(refs)\n # for pr in self.pull_requests:\n # # TODO: this assumes only origin repository\n # # if pr.headRefName in updated_refs:\n # # pr.headRef =\n # pass\n pass\n\n def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:\n repo = self.repository(pr_resolved.owner, pr_resolved.repo)\n pr = self.pull_request(repo, GitHubNumber(pr_resolved.number))\n pr.closed = True\n # TODO: model merged too\n\n def __init__(self, upstream_sh: Optional[ghstack.shell.Shell]) -> None:\n self.repositories = {}\n self.pull_requests = {}\n self.issue_comments = {}\n self._next_id = 5000\n self._next_pull_request_number = {}\n self._next_issue_comment_full_database_id = {}\n self.root = Root()\n\n # Populate it with the most important repo ;)\n repo = Repository(\n id=GraphQLId(\"1000\"),\n name=\"pytorch\",\n nameWithOwner=\"pytorch/pytorch\",\n isFork=False,\n defaultBranchRef=None,\n )\n self.repositories[GraphQLId(\"1000\")] = repo\n self._next_pull_request_number[GraphQLId(\"1000\")] = 500\n self._next_issue_comment_full_database_id[GraphQLId(\"1000\")] = 1500\n\n self.upstream_sh = upstream_sh\n if self.upstream_sh is not None:\n # Setup upstream Git repository representing the\n # pytorch/pytorch repository in the directory specified\n # by upstream_sh. This is useful because some GitHub API\n # operations depend on repository state (e.g., what\n # the headRef is at the time a PR is created), so\n # we need this information\n self.upstream_sh.git(\"init\", \"--bare\", \"-b\", \"master\")\n tree = self.upstream_sh.git(\"write-tree\")\n commit = self.upstream_sh.git(\"commit-tree\", tree, input=\"Initial commit\")\n self.upstream_sh.git(\"branch\", \"-f\", \"master\", commit)\n\n # We only update this when a PATCH changes the default\n # branch; hopefully that's fine? In any case, it should\n # work for now since currently we only ever access the name\n # of the default branch rather than other parts of its ref.\n repo.defaultBranchRef = repo._make_ref(self, \"master\")\n\n\n@dataclass\nclass Node:\n id: GraphQLId\n\n\nGraphQLResolveInfo = Any # for now\n\n\ndef github_state(info: GraphQLResolveInfo) -> GitHubState:\n context = info.context\n assert isinstance(context, GitHubState)\n return context\n\n\n@dataclass\nclass Repository(Node):\n name: str\n nameWithOwner: str\n isFork: bool\n defaultBranchRef: Optional[\"Ref\"]\n\n def pullRequest(\n self, info: GraphQLResolveInfo, number: GitHubNumber\n ) -> \"PullRequest\":\n return github_state(info).pull_request(self, number)\n\n def pullRequests(self, info: GraphQLResolveInfo) -> \"PullRequestConnection\":\n return PullRequestConnection(\n nodes=list(\n filter(\n lambda pr: self == pr.repository(info),\n github_state(info).pull_requests.values(),\n )\n )\n )\n\n # TODO: This should take which repository the ref is in\n # This only works if you have upstream_sh\n def _make_ref(self, state: GitHubState, refName: str) -> \"Ref\":\n # TODO: Probably should preserve object identity here when\n # you call this with refName/oid that are the same\n assert state.upstream_sh\n gitObject = GitObject(\n id=state.next_id(),\n # TODO: this upstream_sh hardcode wrong, but ok for now\n # because we only have one repo\n oid=GitObjectID(state.upstream_sh.git(\"rev-parse\", refName)),\n _repository=self.id,\n )\n ref = Ref(\n id=state.next_id(),\n name=refName,\n _repository=self.id,\n target=gitObject,\n )\n return ref\n\n\n@dataclass\nclass GitObject(Node):\n oid: GitObjectID\n _repository: GraphQLId\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass Ref(Node):\n name: str\n _repository: GraphQLId\n target: GitObject\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass PullRequest(Node):\n baseRef: Optional[Ref]\n baseRefName: str\n body: str\n closed: bool\n headRef: Optional[Ref]\n headRefName: str\n # headRepository: Optional[Repository]\n # maintainerCanModify: bool\n number: GitHubNumber\n _repository: GraphQLId # cycle breaker\n # state: PullRequestState\n title: str\n url: str\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass IssueComment(Node):\n body: str\n fullDatabaseId: int\n _repository: GraphQLId\n\n def repository(self, info: GraphQLResolveInfo) -> Repository:\n return github_state(info).repositories[self._repository]\n\n\n@dataclass\nclass PullRequestConnection:\n nodes: List[PullRequest]\n\n\nclass Root:\n def repository(self, info: GraphQLResolveInfo, owner: str, name: str) -> Repository:\n return github_state(info).repository(owner, name)\n\n def node(self, info: GraphQLResolveInfo, id: GraphQLId) -> Node:\n if id in github_state(info).repositories:\n return github_state(info).repositories[id]\n elif id in github_state(info).pull_requests:\n return github_state(info).pull_requests[id]\n elif id in github_state(info).issue_comments:\n return github_state(info).issue_comments[id]\n else:\n raise RuntimeError(\"unknown id {}\".format(id))\n\n\nwith open(\n os.path.join(os.path.dirname(__file__), \"github_schema.graphql\"), encoding=\"utf-8\"\n) as f:\n GITHUB_SCHEMA = graphql.build_schema(f.read())\n\n\n# Ummm. I thought there would be a way to stick these on the objects\n# themselves (in the same way resolvers can be put on resolvers) but\n# after a quick read of default_resolve_type_fn it doesn't look like\n# we ever actually look to value for type of information. This is\n# pretty clunky lol.\ndef set_is_type_of(name: str, cls: Any) -> None:\n # Can't use a type ignore on the next line because fbcode\n # and us don't agree that it's necessary hmm.\n o: Any = GITHUB_SCHEMA.get_type(name)\n o.is_type_of = lambda obj, info: isinstance(obj, cls)\n\n\nset_is_type_of(\"Repository\", Repository)\nset_is_type_of(\"PullRequest\", PullRequest)\nset_is_type_of(\"IssueComment\", IssueComment)\n\n\nclass FakeGitHubEndpoint(ghstack.github.GitHubEndpoint):\n state: GitHubState\n\n def __init__(self, upstream_sh: Optional[ghstack.shell.Shell] = None) -> None:\n self.state = GitHubState(upstream_sh)\n\n def graphql(self, query: str, **kwargs: Any) -> Any:\n r = graphql.graphql_sync(\n schema=GITHUB_SCHEMA,\n source=query,\n root_value=self.state.root,\n context_value=self.state,\n variable_values=kwargs,\n )\n if r.errors:\n # The GraphQL implementation loses all the stack traces!!!\n # D: You can 'recover' them by deleting the\n # 'except Exception as error' from GraphQL-core-next; need\n # to file a bug report\n raise RuntimeError(\n \"GraphQL query failed with errors:\\n\\n{}\".format(\n \"\\n\".join(str(e) for e in r.errors)\n )\n )\n # The top-level object isn't indexable by strings, but\n # everything underneath is, oddly enough\n return {\"data\": r.data}\n\n def push_hook(self, refNames: Sequence[str]) -> None:\n self.state.push_hook(refNames)\n\n def notify_merged(self, pr_resolved: ghstack.diff.PullRequestResolved) -> None:\n self.state.notify_merged(pr_resolved)\n\n def _create_pull(\n self, owner: str, name: str, input: CreatePullRequestInput\n ) -> CreatePullRequestPayload:\n state = self.state\n id = state.next_id()\n repo = state.repository(owner, name)\n number = state.next_pull_request_number(repo.id)\n baseRef = None\n headRef = None\n # TODO: When we support forks, this needs rewriting to stop\n # hard coded the repo we opened the pull request on\n if state.upstream_sh:\n baseRef = repo._make_ref(state, input[\"base\"])\n headRef = repo._make_ref(state, input[\"head\"])\n pr = PullRequest(\n id=id,\n _repository=repo.id,\n number=number,\n closed=False,\n url=\"https://github.com/{}/pull/{}\".format(repo.nameWithOwner, number),\n baseRef=baseRef,\n baseRefName=input[\"base\"],\n headRef=headRef,\n headRefName=input[\"head\"],\n title=input[\"title\"],\n body=input[\"body\"],\n )\n # TODO: compute files changed\n state.pull_requests[id] = pr\n # This is only a subset of what the actual REST endpoint\n # returns.\n return {\n \"number\": number,\n }\n\n # NB: This technically does have a payload, but we don't\n # use it so I didn't bother constructing it.\n def _update_pull(\n self, owner: str, name: str, number: GitHubNumber, input: UpdatePullRequestInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n pr = state.pull_request(repo, number)\n # If I say input.get('title') is not None, mypy\n # is unable to infer input['title'] is not None\n if \"title\" in input and input[\"title\"] is not None:\n pr.title = input[\"title\"]\n if \"base\" in input and input[\"base\"] is not None:\n pr.baseRefName = input[\"base\"]\n pr.baseRef = repo._make_ref(state, pr.baseRefName)\n if \"body\" in input and input[\"body\"] is not None:\n pr.body = input[\"body\"]\n\n def _create_issue_comment(\n self, owner: str, name: str, comment_id: int, input: CreateIssueCommentInput\n ) -> CreateIssueCommentPayload:\n state = self.state\n id = state.next_id()\n repo = state.repository(owner, name)\n comment_id = state.next_issue_comment_full_database_id(repo.id)\n comment = IssueComment(\n id=id,\n _repository=repo.id,\n fullDatabaseId=comment_id,\n body=input[\"body\"],\n )\n state.issue_comments[id] = comment\n # This is only a subset of what the actual REST endpoint\n # returns.\n return {\n \"id\": comment_id,\n }\n\n def _update_issue_comment(\n self, owner: str, name: str, comment_id: int, input: UpdateIssueCommentInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n comment = state.issue_comment(repo, comment_id)\n if (r := input.get(\"body\")) is not None:\n comment.body = r\n\n # NB: This may have a payload, but we don't\n # use it so I didn't bother constructing it.\n def _set_default_branch(\n self, owner: str, name: str, input: SetDefaultBranchInput\n ) -> None:\n state = self.state\n repo = state.repository(owner, name)\n repo.defaultBranchRef = repo._make_ref(state, input[\"default_branch\"])\n\n def rest(self, method: str, path: str, **kwargs: Any) -> Any:\n if method == \"get\":\n m = re.match(r\"^repos/([^/]+)/([^/]+)/branches/([^/]+)/protection\", path)\n if m:\n # For now, pretend all branches are not protected\n raise ghstack.github.NotFoundError()\n\n elif method == \"post\":\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/pulls$\", path):\n return self._create_pull(\n m.group(1), m.group(2), cast(CreatePullRequestInput, kwargs)\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/([^/]+)/comments\", path):\n return self._create_issue_comment(\n m.group(1),\n m.group(2),\n GitHubNumber(int(m.group(3))),\n cast(CreateIssueCommentInput, kwargs),\n )\n if m := re.match(\n r\"^repos/([^/]+)/([^/]+)/pulls/([^/]+)/requested_reviewers\", path\n ):\n # Handle adding reviewers - just return success for testing\n return {}\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels\", path):\n # Handle adding labels - just return success for testing\n return {}\n elif method == \"patch\":\n if m := re.match(r\"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$\", path):\n owner, name, number = m.groups()\n if number is not None:\n return self._update_pull(\n owner,\n name,\n GitHubNumber(int(number)),\n cast(UpdatePullRequestInput, kwargs),\n )\n elif \"default_branch\" in kwargs:\n return self._set_default_branch(\n owner, name, cast(SetDefaultBranchInput, kwargs)\n )\n if m := re.match(r\"^repos/([^/]+)/([^/]+)/issues/comments/([^/]+)$\", path):\n return self._update_issue_comment(\n m.group(1),\n m.group(2),\n int(m.group(3)),\n cast(UpdateIssueCommentInput, kwargs),\n )\n raise NotImplementedError(\n \"FakeGitHubEndpoint REST {} {} not implemented\".format(method.upper(), path)\n )\n"} diff --git a/lint.json b/lint.json deleted file mode 100644 index 59ab1d7..0000000 --- a/lint.json +++ /dev/null @@ -1 +0,0 @@ -{"path":"/home/oulgen/dev/ghstack/src/ghstack/submit.py","line":1458,"char":37,"code":"FLAKE8","severity":"warning","name":"E741","description":"ambiguous variable name 'l'\nSee https://www.flake8rules.com/rules/E741.html"} From 065600aacce95870cf80242a3784c914c67ac6c6 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Mon, 22 Dec 2025 22:30:34 -0800 Subject: [PATCH 3/3] improve tests --- src/ghstack/github_fake.py | 17 ++++++++++++-- src/ghstack/test_prelude.py | 26 ++++++++++++++++++---- test/submit/cli_reviewer_and_label.py.test | 12 +++++++++- test/submit/reviewer_and_label.py.test | 8 ++++++- 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/src/ghstack/github_fake.py b/src/ghstack/github_fake.py index 2b25d8e..622b6af 100644 --- a/src/ghstack/github_fake.py +++ b/src/ghstack/github_fake.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import dataclasses import os.path import re from dataclasses import dataclass @@ -271,6 +272,8 @@ class PullRequest(Node): # state: PullRequestState title: str url: str + reviewers: List[str] = dataclasses.field(default_factory=list) + labels: List[str] = dataclasses.field(default_factory=list) def repository(self, info: GraphQLResolveInfo) -> Repository: return github_state(info).repositories[self._repository] @@ -476,10 +479,20 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: if m := re.match( r"^repos/([^/]+)/([^/]+)/pulls/([^/]+)/requested_reviewers", path ): - # Handle adding reviewers - just return success for testing + # Handle adding reviewers + state = self.state + repo = state.repository(m.group(1), m.group(2)) + pr = state.pull_request(repo, GitHubNumber(int(m.group(3)))) + reviewers = kwargs.get("reviewers", []) + pr.reviewers.extend(reviewers) return {} if m := re.match(r"^repos/([^/]+)/([^/]+)/issues/([^/]+)/labels", path): - # Handle adding labels - just return success for testing + # Handle adding labels + state = self.state + repo = state.repository(m.group(1), m.group(2)) + pr = state.pull_request(repo, GitHubNumber(int(m.group(3)))) + labels = kwargs.get("labels", []) + pr.labels.extend(labels) return {} elif method == "patch": if m := re.match(r"^repos/([^/]+)/([^/]+)(?:/pulls/([^/]+))?$", path): diff --git a/src/ghstack/test_prelude.py b/src/ghstack/test_prelude.py index 0861db8..567fadb 100644 --- a/src/ghstack/test_prelude.py +++ b/src/ghstack/test_prelude.py @@ -51,6 +51,8 @@ "get_sh", "get_upstream_sh", "get_github", + "get_pr_reviewers", + "get_pr_labels", "tick", "captured_output", ] @@ -390,6 +392,26 @@ def is_direct() -> bool: return CTX.direct +def get_github() -> ghstack.github_fake.FakeGitHubEndpoint: + return CTX.github + + +def get_pr_reviewers(pr_number: int) -> List[str]: + """Get the reviewers for a PR number.""" + github = get_github() + repo = github.state.repository("pytorch", "pytorch") + pr = github.state.pull_request(repo, ghstack.github_fake.GitHubNumber(pr_number)) + return pr.reviewers + + +def get_pr_labels(pr_number: int) -> List[str]: + """Get the labels for a PR number.""" + github = get_github() + repo = github.state.repository("pytorch", "pytorch") + pr = github.state.pull_request(repo, ghstack.github_fake.GitHubNumber(pr_number)) + return pr.labels + + def assert_eq(a: Any, b: Any) -> None: assert a == b, f"{a} != {b}" @@ -423,9 +445,5 @@ def get_upstream_sh() -> ghstack.shell.Shell: return CTX.upstream_sh -def get_github() -> ghstack.github.GitHubEndpoint: - return CTX.github - - def tick() -> None: CTX.sh.test_tick() diff --git a/test/submit/cli_reviewer_and_label.py.test b/test/submit/cli_reviewer_and_label.py.test index 50964a7..117ba2b 100644 --- a/test/submit/cli_reviewer_and_label.py.test +++ b/test/submit/cli_reviewer_and_label.py.test @@ -6,9 +6,19 @@ init_test() commit("A") (A,) = gh_submit("Initial commit", reviewer="reviewer1", label="bug") +# Verify first PR has correct reviewers and labels +assert_eq(get_pr_reviewers(500), ["reviewer1"]) +assert_eq(get_pr_labels(500), ["bug"]) + # Create second commit with different reviewers/labels commit("B") -(A2, B) = gh_submit("Add B", reviewer="reviewer2,reviewer3", label="enhancement,priority-high") +(A2, B) = gh_submit( + "Add B", reviewer="reviewer2,reviewer3", label="enhancement,priority-high" +) + +# Verify second PR has correct reviewers and labels +assert_eq(get_pr_reviewers(501), ["reviewer2", "reviewer3"]) +assert_eq(get_pr_labels(501), ["enhancement", "priority-high"]) if is_direct(): assert_github_state( diff --git a/test/submit/reviewer_and_label.py.test b/test/submit/reviewer_and_label.py.test index 392de7a..57fe7f4 100644 --- a/test/submit/reviewer_and_label.py.test +++ b/test/submit/reviewer_and_label.py.test @@ -3,7 +3,13 @@ from ghstack.test_prelude import * init_test() commit("A") -(A,) = gh_submit("Initial commit", reviewer="reviewer1,reviewer2", label="bug,enhancement") +(A,) = gh_submit( + "Initial commit", reviewer="reviewer1,reviewer2", label="bug,enhancement" +) + +# Verify reviewers and labels were added +assert_eq(get_pr_reviewers(500), ["reviewer1", "reviewer2"]) +assert_eq(get_pr_labels(500), ["bug", "enhancement"]) if is_direct(): assert_github_state(