diff --git a/src/ghstack/github_utils.py b/src/ghstack/github_utils.py index aa186b3..7d343d8 100644 --- a/src/ghstack/github_utils.py +++ b/src/ghstack/github_utils.py @@ -123,8 +123,19 @@ def get_github_repo_info( } +# Matches GitHub PR URLs like: +# https://github.com/owner/repo/pull/123 +# https://github.com/owner/repo/pull/123/ +# https://github.com/owner/repo/pull/123/files +# https://github.com/owner/repo/pull/123/commits RE_PR_URL = re.compile( - r"^https://(?P[^/]+)/(?P[^/]+)/(?P[^/]+)/pull/(?P[0-9]+)/?$" + r"^https://(?P[^/]+)/(?P[^/]+)/(?P[^/]+)/pull/(?P[0-9]+)(?:/.*)?$" +) + +# Matches PyTorch HUD URLs like: +# https://hud.pytorch.org/pr/169404 +RE_PYTORCH_HUD_URL = re.compile( + r"^https://hud\.pytorch\.org/pr/(?P[0-9]+)/?$" ) GitHubPullRequestParams = TypedDict( @@ -144,6 +155,17 @@ def parse_pull_request( sh: Optional[ghstack.shell.Shell] = None, remote_name: Optional[str] = None, ) -> GitHubPullRequestParams: + # Check for PyTorch HUD URL first (hud.pytorch.org/pr/NUMBER) + hud_match = RE_PYTORCH_HUD_URL.match(pull_request) + if hud_match: + number = int(hud_match.group("number")) + return { + "github_url": "github.com", + "owner": "pytorch", + "name": "pytorch", + "number": number, + } + m = RE_PR_URL.match(pull_request) if not m: # We can reconstruct the URL if just a PR number is passed diff --git a/test_github_utils.py b/test_github_utils.py new file mode 100644 index 0000000..f7d0917 --- /dev/null +++ b/test_github_utils.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 + +import unittest + +import ghstack.github_utils + + +class TestParsePullRequest(unittest.TestCase): + def test_github_url_basic(self) -> None: + result = ghstack.github_utils.parse_pull_request( + "https://github.com/pytorch/pytorch/pull/169404" + ) + self.assertEqual(result["github_url"], "github.com") + self.assertEqual(result["owner"], "pytorch") + self.assertEqual(result["name"], "pytorch") + self.assertEqual(result["number"], 169404) + + def test_github_url_trailing_slash(self) -> None: + result = ghstack.github_utils.parse_pull_request( + "https://github.com/pytorch/pytorch/pull/169404/" + ) + self.assertEqual(result["github_url"], "github.com") + self.assertEqual(result["owner"], "pytorch") + self.assertEqual(result["name"], "pytorch") + self.assertEqual(result["number"], 169404) + + def test_github_url_files_suffix(self) -> None: + result = ghstack.github_utils.parse_pull_request( + "https://github.com/pytorch/pytorch/pull/169404/files" + ) + self.assertEqual(result["github_url"], "github.com") + self.assertEqual(result["owner"], "pytorch") + self.assertEqual(result["name"], "pytorch") + self.assertEqual(result["number"], 169404) + + def test_github_url_commits_suffix(self) -> None: + result = ghstack.github_utils.parse_pull_request( + "https://github.com/pytorch/pytorch/pull/169404/commits" + ) + self.assertEqual(result["github_url"], "github.com") + self.assertEqual(result["owner"], "pytorch") + self.assertEqual(result["name"], "pytorch") + self.assertEqual(result["number"], 169404) + + def test_github_url_commits_with_sha(self) -> None: + result = ghstack.github_utils.parse_pull_request( + "https://github.com/pytorch/pytorch/pull/169404/commits/abc123def" + ) + self.assertEqual(result["github_url"], "github.com") + self.assertEqual(result["owner"], "pytorch") + self.assertEqual(result["name"], "pytorch") + self.assertEqual(result["number"], 169404) + + def test_pytorch_hud_url_basic(self) -> None: + result = ghstack.github_utils.parse_pull_request( + "https://hud.pytorch.org/pr/169404" + ) + self.assertEqual(result["github_url"], "github.com") + self.assertEqual(result["owner"], "pytorch") + self.assertEqual(result["name"], "pytorch") + self.assertEqual(result["number"], 169404) + + def test_pytorch_hud_url_trailing_slash(self) -> None: + result = ghstack.github_utils.parse_pull_request( + "https://hud.pytorch.org/pr/169404/" + ) + self.assertEqual(result["github_url"], "github.com") + self.assertEqual(result["owner"], "pytorch") + self.assertEqual(result["name"], "pytorch") + self.assertEqual(result["number"], 169404) + + def test_different_owner_repo(self) -> None: + result = ghstack.github_utils.parse_pull_request( + "https://github.com/facebook/react/pull/12345" + ) + self.assertEqual(result["github_url"], "github.com") + self.assertEqual(result["owner"], "facebook") + self.assertEqual(result["name"], "react") + self.assertEqual(result["number"], 12345) + + def test_invalid_url_raises(self) -> None: + with self.assertRaises(RuntimeError): + ghstack.github_utils.parse_pull_request("not-a-valid-url") + + def test_invalid_hud_url_raises(self) -> None: + with self.assertRaises(RuntimeError): + ghstack.github_utils.parse_pull_request("https://hud.pytorch.org/not-pr/123") + + +if __name__ == "__main__": + unittest.main()