diff --git a/src/ghstack/config.py b/src/ghstack/config.py index 72a91d0..ff78fe9 100644 --- a/src/ghstack/config.py +++ b/src/ghstack/config.py @@ -5,8 +5,10 @@ import logging import os import re +import shutil +import subprocess from pathlib import Path -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, Tuple import requests @@ -15,6 +17,71 @@ DEFAULT_GHSTACKRC_PATH = Path.home() / ".ghstackrc" GHSTACKRC_PATH_VAR = "GHSTACKRC_PATH" + +def is_gh_cli_available() -> bool: + """Check if the GitHub CLI (gh) is available in PATH.""" + return shutil.which("gh") is not None + + +def get_gh_cli_credentials( + github_url: str = "github.com", +) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """ + Extract credentials from the GitHub CLI if available and authenticated. + + Args: + github_url: The GitHub host to get credentials for. + + Returns: + A tuple of (token, username, url) or (None, None, None) if unavailable. + """ + if not is_gh_cli_available(): + return None, None, None + + try: + # Check if gh is authenticated for this host + auth_status = subprocess.run( + ["gh", "auth", "status", "-h", github_url], + capture_output=True, + text=True, + ) + if auth_status.returncode != 0: + logging.debug(f"gh CLI not authenticated for {github_url}") + return None, None, None + + # Get the token + token_result = subprocess.run( + ["gh", "auth", "token", "-h", github_url], + capture_output=True, + text=True, + ) + if token_result.returncode != 0: + logging.debug("Failed to get token from gh CLI") + return None, None, None + token = token_result.stdout.strip() + if not token: + return None, None, None + + # Get the username using gh api + username_result = subprocess.run( + ["gh", "api", "user", "-q", ".login", "--hostname", github_url], + capture_output=True, + text=True, + ) + username = None + if username_result.returncode == 0: + username = username_result.stdout.strip() + + logging.debug( + f"Successfully retrieved credentials from gh CLI for {github_url}" + ) + return token, username, github_url + + except Exception as e: + logging.debug(f"Error getting credentials from gh CLI: {e}") + return None, None, None + + Config = NamedTuple( "Config", [ @@ -97,6 +164,7 @@ def read_config( # Environment variable overrides config file # This envvar is legacy from ghexport days github_oauth = os.getenv("OAUTH_TOKEN") + gh_cli_username = None # Track username from gh CLI if github_oauth is not None: logging.warning( "Deprecated OAUTH_TOKEN environment variable used to populate github_oauth--" @@ -105,6 +173,17 @@ def read_config( ) if github_oauth is None and config.has_option("ghstack", "github_oauth"): github_oauth = config.get("ghstack", "github_oauth") + + # Try GitHub CLI if available and no token found yet + if github_oauth is None and request_github_token: + gh_token, gh_username, _ = get_gh_cli_credentials(github_url) + if gh_token is not None: + print(f"Using GitHub credentials from gh CLI for {github_url}") + github_oauth = gh_token + gh_cli_username = gh_username + # Don't save gh CLI credentials to config - they may change/expire + + # Fall back to device flow if still no token if github_oauth is None and request_github_token: print("Generating GitHub access token...") CLIENT_ID = "89cc88ca50efbe86907a" @@ -150,6 +229,11 @@ def read_config( github_username = None if config.has_option("ghstack", "github_username"): github_username = config.get("ghstack", "github_username") + # Use username from gh CLI if we got it + if github_username is None and gh_cli_username is not None: + github_username = gh_cli_username + # Don't save gh CLI username to config - it comes from gh CLI + # Fall back to API lookup if we have a token but no username yet if github_username is None and github_oauth is not None: request_url: str if github_url == "github.com": diff --git a/src/ghstack/github_utils.py b/src/ghstack/github_utils.py index aa186b3..5c967e6 100644 --- a/src/ghstack/github_utils.py +++ b/src/ghstack/github_utils.py @@ -138,12 +138,22 @@ def get_github_repo_info( ) +def _normalize_remote_url(remote_url: str) -> str: + """Convert SSH remote URL to HTTPS format, strip .git suffix.""" + # git@github.com:owner/repo.git -> https://github.com/owner/repo + m = re.match(r"^git@([^:]+):/?(.+?)(?:\.git)?$", remote_url) + if m: + return f"https://{m.group(1)}/{m.group(2)}" + return re.sub(r"\.git$", "", remote_url) + + def parse_pull_request( pull_request: str, *, sh: Optional[ghstack.shell.Shell] = None, remote_name: Optional[str] = None, ) -> GitHubPullRequestParams: + pull_request = pull_request.lstrip("#") m = RE_PR_URL.match(pull_request) if not m: # We can reconstruct the URL if just a PR number is passed @@ -151,7 +161,9 @@ def parse_pull_request( remote_url = sh.git("remote", "get-url", remote_name) # Do not pass the shell to avoid infinite loop try: - return parse_pull_request(remote_url + "/pull/" + pull_request) + return parse_pull_request( + _normalize_remote_url(remote_url) + "/pull/" + pull_request + ) except RuntimeError: # Fall back on original error message pass