diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index f7957fa0d..247fc3d26 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -102,10 +102,11 @@ def make_cfapi_request( @lru_cache(maxsize=1) -def get_user_id(api_key: Optional[str] = None) -> Optional[str]: +def get_user_id(api_key: Optional[str] = None, *, suppress_errors: bool = False) -> Optional[str]: """Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint. :param api_key: The API key to use. If None, uses get_codeflash_api_key(). + :param suppress_errors: If True, avoid exiting on auth/version errors and return None instead. :return: The userid or None if the request fails. """ lsp_enabled = is_LSP_enabled() @@ -129,6 +130,9 @@ def get_user_id(api_key: Optional[str] = None) -> Optional[str]: if min_version and version.parse(min_version) > version.parse(__version__): msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`." console.print(f"[bold red]{msg}[/bold red]") + if suppress_errors: + logger.debug(msg) + return None if lsp_enabled: logger.debug(msg) return f"Error: {msg}" @@ -140,6 +144,9 @@ def get_user_id(api_key: Optional[str] = None) -> Optional[str]: if response.status_code == 403: error_title = "Invalid Codeflash API key. The API key you provided is not valid." + if suppress_errors: + logger.debug(error_title) + return None if lsp_enabled: return f"Error: {error_title}" msg = ( diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index cd4d94787..34b568665 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -289,11 +289,15 @@ def _handle_show_config() -> None: from codeflash.code_utils.config_parser import parse_config_file config, config_file_path = parse_config_file() - status = "Saved config" + is_file_backed_config = config_file_path.is_file() + status = "Saved config" if is_file_backed_config else "Auto-detected (zero-config)" console.print() console.print(f"[bold]Codeflash Configuration[/bold] ({status})") - console.print(f"[dim]Config file: {config_file_path}[/dim]") + if is_file_backed_config: + console.print(f"[dim]Config file: {config_file_path}[/dim]") + else: + console.print(f"[dim]Config source: {config_file_path}[/dim]") console.print() table = Table(show_header=True, header_style="bold cyan") diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index bd44cb761..40de9b6e6 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -44,7 +44,7 @@ from argparse import Namespace -def init_codeflash() -> None: +def init_codeflash(*, skip_confirm: bool = False, skip_api_key: bool = False) -> None: try: welcome_panel = Panel( Text( @@ -63,34 +63,46 @@ def init_codeflash() -> None: project_language = detect_project_language() if project_language == ProjectLanguage.GO: - init_go_project() + init_go_project(skip_confirm=skip_confirm, skip_api_key=skip_api_key) return if project_language == ProjectLanguage.JAVA: - init_java_project() + init_java_project(skip_confirm=skip_confirm, skip_api_key=skip_api_key) return if project_language in (ProjectLanguage.JAVASCRIPT, ProjectLanguage.TYPESCRIPT): - init_js_project(project_language) + init_js_project(project_language, skip_confirm=skip_confirm, skip_api_key=skip_api_key) return # Python project flow - did_add_new_key = prompt_api_key() - - should_modify, config = should_modify_pyproject_toml() + did_add_new_key = False if skip_api_key else prompt_api_key() + git_remote = "origin" + should_modify, config = should_modify_pyproject_toml(skip_confirm=skip_confirm) git_remote = config.get("git_remote", "origin") if config else "origin" if should_modify: - setup_info: CLISetupInfo = collect_setup_info() - git_remote = setup_info.git_remote - configured = configure_pyproject_toml(setup_info) - if not configured: - apologize_and_exit() + if skip_confirm: + from codeflash.setup import detect_project, write_config + + detected = detect_project() + configured, message = write_config(detected) + if configured: + click.echo(message) + click.echo() + else: + click.echo(message) + apologize_and_exit() + else: + setup_info = collect_setup_info() + git_remote = setup_info.git_remote + configured = configure_pyproject_toml(setup_info) + if not configured: + apologize_and_exit() install_github_app(git_remote) - install_github_actions(override_formatter_check=True) + install_github_actions(override_formatter_check=True, skip_confirm=skip_confirm) install_vscode_extension() diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index 2baa88dcf..8761b3a59 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -41,8 +41,43 @@ DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG + +def _configure_stdio_for_unicode_safety() -> None: + """Avoid hard failures when the active console encoding can't represent Unicode.""" + for stream in (sys.stdout, sys.stderr): + reconfigure = getattr(stream, "reconfigure", None) + if callable(reconfigure): + with contextlib.suppress(OSError, ValueError): + reconfigure(errors="replace") + + +def _can_encode(text: str) -> bool: + encoding = getattr(sys.stdout, "encoding", None) or "utf-8" + try: + text.encode(encoding) + except UnicodeEncodeError: + return False + return True + + +_configure_stdio_for_unicode_safety() + console = Console(highlighter=NullHighlighter()) +_original_console_rule = console.rule + + +def _safe_console_rule(title: str = "", *args: object, **kwargs: object) -> None: + if "characters" not in kwargs and not _can_encode("─"): + kwargs["characters"] = "-" + if title and not _can_encode(title): + encoding = getattr(sys.stdout, "encoding", None) or "utf-8" + title = title.encode(encoding, errors="replace").decode(encoding, errors="replace") + _original_console_rule(title, *args, **kwargs) + + +console.rule = _safe_console_rule # type: ignore[method-assign] + if is_LSP_enabled() or is_subagent_mode(): console.quiet = True diff --git a/codeflash/cli_cmds/github_workflow.py b/codeflash/cli_cmds/github_workflow.py index ec2637bb5..4f9431657 100644 --- a/codeflash/cli_cmds/github_workflow.py +++ b/codeflash/cli_cmds/github_workflow.py @@ -34,7 +34,7 @@ class DependencyManager(Enum): UNKNOWN = auto() -def install_github_actions(override_formatter_check: bool = False) -> None: +def install_github_actions(override_formatter_check: bool = False, *, skip_confirm: bool = False) -> None: try: config, _config_file_path = parse_config_file(override_formatter_check=override_formatter_check) @@ -100,12 +100,15 @@ def install_github_actions(override_formatter_check: bool = False) -> None: console.print(benchmark_panel) console.print() - benchmark_questions = [ - inquirer.Confirm("benchmark_mode", message="Run GitHub Actions in benchmark mode?", default=True) - ] + if skip_confirm: + benchmark_mode = True + else: + benchmark_questions = [ + inquirer.Confirm("benchmark_mode", message="Run GitHub Actions in benchmark mode?", default=True) + ] - benchmark_answers = inquirer.prompt(benchmark_questions, theme=CodeflashTheme()) - benchmark_mode = benchmark_answers["benchmark_mode"] if benchmark_answers else False + benchmark_answers = inquirer.prompt(benchmark_questions, theme=CodeflashTheme()) + benchmark_mode = benchmark_answers["benchmark_mode"] if benchmark_answers else False # Show prompt only if workflow doesn't exist locally actions_panel = Panel( @@ -121,26 +124,28 @@ def install_github_actions(override_formatter_check: bool = False) -> None: console.print(actions_panel) console.print() - creation_questions = [ - inquirer.Confirm( - "confirm_creation", - message="Set up GitHub Actions for continuous optimization? We'll open a pull request with the workflow file.", - default=True, - ) - ] + if skip_confirm: + confirm_creation = True + else: + creation_questions = [ + inquirer.Confirm( + "confirm_creation", + message="Set up GitHub Actions for continuous optimization? We'll open a pull request with the workflow file.", + default=True, + ) + ] + + creation_answers = inquirer.prompt(creation_questions, theme=CodeflashTheme()) + confirm_creation = bool(creation_answers and creation_answers["confirm_creation"]) - creation_answers = inquirer.prompt(creation_questions, theme=CodeflashTheme()) - if not creation_answers or not creation_answers["confirm_creation"]: + if not confirm_creation: skip_panel = Panel( Text("⏩️ Skipping GitHub Actions setup.", style="yellow"), title="⏩️ Skipped", border_style="yellow" ) console.print(skip_panel) ph("cli-github-workflow-skipped") return - ph( - "cli-github-optimization-confirm-workflow-creation", - {"confirm_creation": creation_answers["confirm_creation"]}, - ) + ph("cli-github-optimization-confirm-workflow-creation", {"confirm_creation": confirm_creation}) # Generate workflow content AFTER user confirmation logger.info("[github_workflow.py:install_github_actions] User confirmed, generating workflow content...") @@ -423,6 +428,11 @@ def install_github_actions(override_formatter_check: bool = False) -> None: f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}" ) + if skip_confirm: + click.echo("Add your CODEFLASH_API_KEY as a GitHub secret before running this workflow.") + ph("cli-github-workflow-created") + return + # Show GitHub secrets setup panel (needed in both cases - PR created via API or local file) try: existing_api_key = get_codeflash_api_key() @@ -555,8 +565,11 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str: def detect_project_language_for_workflow(project_root: Path) -> str: """Detect the primary language of the project for workflow generation. - Returns: 'python', 'javascript', 'typescript', or 'java' + Returns: 'python', 'javascript', 'typescript', 'java', or 'go' """ + if (project_root / "go.mod").exists(): + return "go" + # Check for Java build tools first (pom.xml or build.gradle) if ( (project_root / "pom.xml").exists() @@ -693,9 +706,9 @@ def generate_dynamic_workflow_content( # Detect project language project_language = detect_project_language_for_workflow(Path.cwd()) - # For JavaScript/TypeScript and Java projects, use static template customization + # For JavaScript/TypeScript, Java, and Go projects, use static template customization # (AI-generated steps are currently Python-only) - if project_language in ("javascript", "typescript", "java"): + if project_language in ("javascript", "typescript", "java", "go"): return customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode) # Python project - try AI-generated steps @@ -824,6 +837,9 @@ def customize_codeflash_yaml_content( if project_language in ("javascript", "typescript"): return _customize_js_workflow_content(optimize_yml_content, git_root, benchmark_mode) + if project_language == "go": + return _customize_go_workflow_content(optimize_yml_content, git_root, benchmark_mode) + # Python project (default) return _customize_python_workflow_content(optimize_yml_content, git_root, benchmark_mode) @@ -948,3 +964,38 @@ def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path) # Install dependencies command install_deps = get_java_dependency_installation_commands(build_tool) return optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps) + + +def _customize_go_workflow_content(optimize_yml_content: str, git_root: Path, benchmark_mode: bool = False) -> str: + """Customize workflow content for Go projects.""" + from codeflash.cli_cmds.init_go import get_go_dependency_installation_commands, get_go_runtime_setup_steps + + project_root = Path.cwd() + + if project_root == git_root: + working_dir = "" + else: + rel_path = str(project_root.relative_to(git_root)) + working_dir = f"""defaults: + run: + working-directory: ./{rel_path}""" + + optimize_yml_content = optimize_yml_content.replace("Optimize new Python code", "Optimize new Go code") + optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir) + + python_setup = get_dependency_manager_installation_string(DependencyManager.PIP) + go_setup = get_go_runtime_setup_steps() + setup_runtime = f"""{python_setup} + {go_setup}""" + optimize_yml_content = optimize_yml_content.replace("{{ setup_runtime_environment }}", setup_runtime) + + install_deps = f"""| + python -m pip install --upgrade pip + pip install codeflash + {get_go_dependency_installation_commands()}""" + optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps) + + codeflash_cmd = "codeflash" + if benchmark_mode: + codeflash_cmd += " --benchmark" + return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd) diff --git a/codeflash/cli_cmds/init_auth.py b/codeflash/cli_cmds/init_auth.py index d6618daa4..7338e3e23 100644 --- a/codeflash/cli_cmds/init_auth.py +++ b/codeflash/cli_cmds/init_auth.py @@ -147,7 +147,17 @@ def enter_api_key_and_save_to_rc() -> None: os.environ["CODEFLASH_API_KEY"] = api_key +def _skip_github_app_installation(owner: str, repo: str) -> None: + click.echo( + f"Skipping Codeflash GitHub app installation for {owner}/{repo}.{LF}" + "Codeflash setup will continue, but PR creation will stay disabled until you install the app later." + f"{LF}In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}" + ) + + def install_github_app(git_remote: str) -> None: + from rich.prompt import Confirm + try: git_repo = git.Repo(search_parent_directories=True) except git.InvalidGitRepositoryError: @@ -167,6 +177,17 @@ def install_github_app(git_remote: str) -> None: else: try: + should_install = Confirm.ask( + "Do you want to install the Codeflash GitHub app now? You can skip this and continue setup, " + "but Codeflash won't be able to create PRs until the app is installed.", + default=True, + show_default=True, + console=console, + ) + if not should_install: + _skip_github_app_installation(owner, repo) + return + click.prompt( f"Finally, you'll need to install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}" f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}" @@ -188,11 +209,7 @@ def install_github_app(git_remote: str) -> None: count = 2 while not is_github_app_installed_on_repo(owner, repo, suppress_errors=True): if count == 0: - click.echo( - f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}" - f"You won't be able to create PRs with Codeflash until you install the app.{LF}" - f"In the meantime you can make local only optimizations by using the '--no-pr' flag with codeflash.{LF}" - ) + _skip_github_app_installation(owner, repo) break click.prompt( f"❌ It looks like the Codeflash GitHub App is not installed on the repository {owner}/{repo}.{LF}" @@ -207,3 +224,4 @@ def install_github_app(git_remote: str) -> None: except (KeyboardInterrupt, EOFError, click.exceptions.Abort): # leave empty line for the next prompt to be properly rendered click.echo() + _skip_github_app_installation(owner, repo) diff --git a/codeflash/cli_cmds/init_config.py b/codeflash/cli_cmds/init_config.py index bb10d6593..c6d316ab6 100644 --- a/codeflash/cli_cmds/init_config.py +++ b/codeflash/cli_cmds/init_config.py @@ -142,7 +142,7 @@ def is_valid_pyproject_toml(pyproject_toml_path: Union[str, Path]) -> tuple[bool return True, config, "" -def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]: +def should_modify_pyproject_toml(*, skip_confirm: bool = False) -> tuple[bool, dict[str, Any] | None]: """Check if the current directory contains a valid pyproject.toml file with codeflash config. If it does, ask the user if they want to re-configure it. @@ -160,6 +160,9 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]: # needs to be re-configured return True, None + if skip_confirm: + return False, config + return Confirm.ask( "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, diff --git a/codeflash/cli_cmds/init_go.py b/codeflash/cli_cmds/init_go.py index 032072231..9992cdce0 100644 --- a/codeflash/cli_cmds/init_go.py +++ b/codeflash/cli_cmds/init_go.py @@ -40,7 +40,7 @@ def _get_theme() -> Any: return CodeflashTheme() -def init_go_project() -> None: +def init_go_project(*, skip_confirm: bool = False, skip_api_key: bool = False) -> None: from codeflash.cli_cmds.github_workflow import install_github_actions from codeflash.cli_cmds.init_auth import install_github_app, prompt_api_key @@ -54,14 +54,14 @@ def init_go_project() -> None: console.print(lang_panel) console.print() - did_add_new_key = prompt_api_key() + did_add_new_key = False if skip_api_key else prompt_api_key() - setup_info = collect_go_setup_info() + setup_info = collect_go_setup_info(skip_confirm=skip_confirm) git_remote = setup_info.git_remote or "origin" install_github_app(git_remote) - install_github_actions(override_formatter_check=True) + install_github_actions(override_formatter_check=True, skip_confirm=skip_confirm) usage_table = Table(show_header=False, show_lines=False, border_style="dim") usage_table.add_column("Command", style="cyan") @@ -95,7 +95,7 @@ def init_go_project() -> None: sys.exit(0) -def collect_go_setup_info() -> GoSetupInfo: +def collect_go_setup_info(*, skip_confirm: bool = False) -> GoSetupInfo: from codeflash.cli_cmds.init_config import ask_for_telemetry @@ -129,6 +129,10 @@ def collect_go_setup_info() -> GoSetupInfo: console.print(detection_panel) console.print() + if skip_confirm: + git_remote = _get_git_remote_for_setup(skip_confirm=True) + return GoSetupInfo(git_remote=git_remote, disable_telemetry=False) + git_remote = _get_git_remote_for_setup() disable_telemetry = not ask_for_telemetry() @@ -136,7 +140,7 @@ def collect_go_setup_info() -> GoSetupInfo: return GoSetupInfo(git_remote=git_remote, disable_telemetry=disable_telemetry) -def _get_git_remote_for_setup() -> str: +def _get_git_remote_for_setup(*, skip_confirm: bool = False) -> str: try: repo = Repo(Path.cwd(), search_parent_directories=True) git_remotes = get_git_remotes(repo) @@ -145,6 +149,8 @@ def _get_git_remote_for_setup() -> str: if len(git_remotes) == 1: return git_remotes[0] + if skip_confirm: + return "origin" if "origin" in git_remotes else git_remotes[0] git_panel = Panel( Text( diff --git a/codeflash/cli_cmds/init_java.py b/codeflash/cli_cmds/init_java.py index eb01002fa..4d110751e 100644 --- a/codeflash/cli_cmds/init_java.py +++ b/codeflash/cli_cmds/init_java.py @@ -147,7 +147,7 @@ def detect_java_test_framework(project_root: Path) -> str: return "junit5" # Default to JUnit 5 -def init_java_project() -> None: +def init_java_project(*, skip_confirm: bool = False, skip_api_key: bool = False) -> None: """Initialize Codeflash for a Java project.""" from codeflash.cli_cmds.github_workflow import install_github_actions from codeflash.cli_cmds.init_auth import install_github_app, prompt_api_key @@ -162,15 +162,15 @@ def init_java_project() -> None: console.print(lang_panel) console.print() - did_add_new_key = prompt_api_key() + did_add_new_key = False if skip_api_key else prompt_api_key() - should_modify, _config = should_modify_java_config() + should_modify, _config = should_modify_java_config(skip_confirm=skip_confirm) # Default git remote git_remote = "origin" if should_modify: - setup_info = collect_java_setup_info() + setup_info = collect_java_setup_info(skip_confirm=skip_confirm) git_remote = setup_info.git_remote or "origin" configured = configure_java_project(setup_info) if not configured: @@ -178,7 +178,7 @@ def init_java_project() -> None: install_github_app(git_remote) - install_github_actions(override_formatter_check=True) + install_github_actions(override_formatter_check=True, skip_confirm=skip_confirm) # Show completion message usage_table = Table(show_header=False, show_lines=False, border_style="dim") @@ -213,7 +213,7 @@ def init_java_project() -> None: sys.exit(0) -def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]: +def should_modify_java_config(*, skip_confirm: bool = False) -> tuple[bool, dict[str, Any] | None]: """Check if the project already has Codeflash config.""" from rich.prompt import Confirm @@ -226,6 +226,8 @@ def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]: strategy = get_config_strategy(project_root) existing = strategy.read_codeflash_properties(project_root) if existing: + if skip_confirm: + return False, None return Confirm.ask( "A Codeflash config already exists. Do you want to re-configure it?", default=False, show_default=True ), None @@ -235,7 +237,7 @@ def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]: return True, None -def collect_java_setup_info() -> JavaSetupInfo: +def collect_java_setup_info(*, skip_confirm: bool = False) -> JavaSetupInfo: """Collect setup information for Java projects.""" from rich.prompt import Confirm @@ -271,6 +273,10 @@ def collect_java_setup_info() -> JavaSetupInfo: console.print(detection_panel) console.print() + if skip_confirm: + git_remote = _get_git_remote_for_setup(skip_confirm=True) + return JavaSetupInfo(git_remote=git_remote, disable_telemetry=False) + # Ask if user wants to change any settings module_root_override = None test_root_override = None @@ -382,7 +388,7 @@ def _prompt_custom_directory(dir_type: str) -> str: console.print() -def _get_git_remote_for_setup() -> str: +def _get_git_remote_for_setup(*, skip_confirm: bool = False) -> str: """Get git remote for project setup.""" try: repo = Repo(Path.cwd(), search_parent_directories=True) @@ -392,6 +398,8 @@ def _get_git_remote_for_setup() -> str: if len(git_remotes) == 1: return git_remotes[0] + if skip_confirm: + return "origin" if "origin" in git_remotes else git_remotes[0] git_panel = Panel( Text( diff --git a/codeflash/cli_cmds/init_javascript.py b/codeflash/cli_cmds/init_javascript.py index 20f76d249..6e7f407f3 100644 --- a/codeflash/cli_cmds/init_javascript.py +++ b/codeflash/cli_cmds/init_javascript.py @@ -28,6 +28,7 @@ from codeflash.code_utils.compat import LF from codeflash.code_utils.git_utils import get_git_remotes from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell +from codeflash.languages.javascript.command_utils import resolve_node_command from codeflash.telemetry.posthog_cf import ph @@ -201,22 +202,22 @@ def get_package_install_command(project_root: Path, package: str, dev: bool = Tr pkg_manager = determine_js_package_manager(project_root) if pkg_manager == JsPackageManager.PNPM: - cmd = ["pnpm", "add", package] + cmd = [resolve_node_command("pnpm"), "add", package] if dev: cmd.append("--save-dev") return cmd if pkg_manager == JsPackageManager.YARN: - cmd = ["yarn", "add", package] + cmd = [resolve_node_command("yarn"), "add", package] if dev: cmd.append("--dev") return cmd if pkg_manager == JsPackageManager.BUN: - cmd = ["bun", "add", package] + cmd = [resolve_node_command("bun"), "add", package] if dev: cmd.append("--dev") return cmd # Default to npm - cmd = ["npm", "install", package] + cmd = [resolve_node_command("npm"), "install", package] if dev: cmd.append("--save-dev") return cmd @@ -257,7 +258,7 @@ def init_js_project(language: ProjectLanguage, *, skip_confirm: bool = False, sk install_github_app(git_remote) - install_github_actions(override_formatter_check=True) + install_github_actions(override_formatter_check=True, skip_confirm=skip_confirm) # Show completion message usage_table = Table(show_header=False, show_lines=False, border_style="dim") diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 87960fa3f..a94e6cbf3 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -21,7 +21,8 @@ def _try_parse_go_config() -> tuple[dict[str, Any], Path] | None: "language": "go", "module_root": module_root, "tests_root": module_root, - "pytest_cmd": "pytest", + "pytest_cmd": "go test ./...", + "test_framework": "go-test", "git_remote": "origin", "disable_telemetry": False, "disable_imports_sorting": False, diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 4bfd96104..9e8351ce9 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -60,6 +60,10 @@ def apply_formatter_cmds( for command in cmds: formatter_cmd_list = shlex.split(command, posix=os.name != "nt") formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list] + if str(lang_support.language) in ("javascript", "typescript"): + from codeflash.languages.javascript.command_utils import resolve_node_command_list + + formatter_cmd_list = resolve_node_command_list(formatter_cmd_list) try: result = subprocess.run(formatter_cmd_list, capture_output=True, check=False) if result.returncode == 0: diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index fdac43c25..46299b64f 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import contextlib import os import random import warnings @@ -773,8 +772,10 @@ def was_function_previously_optimized( # already_optimized_count = 0 owner = None repo = None - with contextlib.suppress(git.exc.InvalidGitRepositoryError): - owner, repo = get_repo_owner_and_name() + try: + owner, repo = get_repo_owner_and_name(git_remote=getattr(args, "git_remote", "origin")) + except Exception as exc: + logger.debug("Skipping previous optimization lookup because repository metadata is unavailable: %s", exc) pr_number = get_pr_number() diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py index aaa7c8e3f..963310b6a 100644 --- a/codeflash/languages/golang/support.py +++ b/codeflash/languages/golang/support.py @@ -42,10 +42,22 @@ @register_language class GoSupport(LanguageSupport): def __init__(self) -> None: - self._analyzer = GoAnalyzer() + self._analyzer: GoAnalyzer | None = None self._go_version: str | None = None self._go_version_detected = False + def _get_analyzer(self) -> GoAnalyzer: + if self._analyzer is None: + try: + self._analyzer = GoAnalyzer() + except ModuleNotFoundError as exc: + msg = ( + "Go support requires the tree-sitter Go parser. " + "Reinstall or sync Codeflash dependencies to enable Go optimization." + ) + raise ModuleNotFoundError(msg) from exc + return self._analyzer + @property def language(self) -> Language: return Language.GO @@ -94,7 +106,7 @@ def function_optimizer_class(self) -> type: def discover_functions( self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None ) -> list[FunctionToOptimize]: - return discover_functions_from_source(source, file_path, filter_criteria, self._analyzer) + return discover_functions_from_source(source, file_path, filter_criteria, self._get_analyzer()) def discover_tests( self, test_root: Path, source_functions: Sequence[FunctionToOptimize] @@ -102,7 +114,7 @@ def discover_tests( return _discover_tests(test_root, source_functions) def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: - return self._analyzer.validate_syntax(source) + return self._get_analyzer().validate_syntax(source) def parse_test_xml( self, test_xml_file_path: Path, test_files: Any, test_config: Any, run_result: Any = None @@ -112,14 +124,14 @@ def parse_test_xml( return parse_go_test_output(test_xml_file_path, test_files, test_config, run_result) def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: - return _extract_context(function, project_root, module_root, self._analyzer) + return _extract_context(function, project_root, module_root, self._get_analyzer()) def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: try: source = function.file_path.read_text(encoding="utf-8") except Exception: return [] - return _find_helpers(source, function, self._analyzer) + return _find_helpers(source, function, self._get_analyzer()) def find_references( self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 100 @@ -127,7 +139,7 @@ def find_references( return [] def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: - return _replace_func(source, function, new_source, self._analyzer) + return _replace_func(source, function, new_source, self._get_analyzer()) def format_code(self, source: str, file_path: Path | None = None) -> str: return format_go_code(source, file_path) @@ -136,7 +148,7 @@ def normalize_code(self, source: str) -> str: return normalize_go_code(source) def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str: - return _add_globals(optimized_code, original_source, self._analyzer) + return _add_globals(optimized_code, original_source, self._get_analyzer()) def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str: return str(source_file) @@ -146,7 +158,7 @@ def prepare_module( ) -> tuple[dict[Path, Any], None] | None: from codeflash.models.models import ValidCode - if not self._analyzer.validate_syntax(module_code): + if not self._get_analyzer().validate_syntax(module_code): return None validated: dict[Path, ValidCode] = { module_path: ValidCode(source_code=module_code, normalized_code=normalize_go_code(module_code)) @@ -298,7 +310,7 @@ def add_runtime_comments( return test_source def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str: - return _remove_tests(test_source, functions_to_remove, self._analyzer) + return _remove_tests(test_source, functions_to_remove, self._get_analyzer()) def add_runtime_comments_to_generated_tests( self, diff --git a/codeflash/languages/java/build_config_strategy.py b/codeflash/languages/java/build_config_strategy.py index 793c75162..cfca00bd2 100644 --- a/codeflash/languages/java/build_config_strategy.py +++ b/codeflash/languages/java/build_config_strategy.py @@ -346,11 +346,16 @@ def parse_java_project_config(project_root: Path) -> dict[str, Any] | None: Returns None if no Java build tool is detected. """ from codeflash.languages.java.build_tools import BuildTool, detect_build_tool, find_source_root, find_test_root + from codeflash.languages.java.config import detect_java_project build_tool = detect_build_tool(project_root) if build_tool == BuildTool.UNKNOWN: return None + detected_project = detect_java_project(project_root) + test_framework = detected_project.test_framework if detected_project is not None else "junit5" + test_command = "mvn test" if build_tool == BuildTool.MAVEN else "./gradlew test" + try: strategy = get_config_strategy(project_root) user_config = strategy.read_codeflash_properties(project_root) @@ -379,7 +384,8 @@ def parse_java_project_config(project_root: Path) -> dict[str, Any] | None: if "testsRoot" in user_config else (test_root or (default_test if default_test.is_dir() else project_root)) ), - "pytest_cmd": "pytest", + "pytest_cmd": test_command, + "test_framework": test_framework, "git_remote": user_config.get("gitRemote", "origin"), "disable_telemetry": user_config.get("disableTelemetry", "false").lower() == "true", "disable_imports_sorting": False, diff --git a/codeflash/languages/javascript/command_utils.py b/codeflash/languages/javascript/command_utils.py new file mode 100644 index 000000000..00b338448 --- /dev/null +++ b/codeflash/languages/javascript/command_utils.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import shutil +import sys +from pathlib import Path + + +def resolve_node_command(command: str) -> str: + """Resolve a Node ecosystem executable to a subprocess-safe path. + + On Windows, tools such as npm, npx, pnpm, yarn, and bun are often exposed + as .cmd wrappers. Passing the bare command name to subprocess can fail even + when the command is on PATH, so resolve it to the concrete executable path + first. + """ + candidates = [command] + if sys.platform == "win32" and not Path(command).suffix: + candidates.extend([f"{command}.cmd", f"{command}.bat", f"{command}.exe"]) + + for candidate in candidates: + resolved = shutil.which(candidate) + if resolved: + return resolved + + return command + + +def resolve_node_command_list(command: list[str]) -> list[str]: + if not command: + return command + return [resolve_node_command(command[0]), *command[1:]] diff --git a/codeflash/languages/javascript/mocha_runner.py b/codeflash/languages/javascript/mocha_runner.py index 59a6f7067..14e60a70f 100644 --- a/codeflash/languages/javascript/mocha_runner.py +++ b/codeflash/languages/javascript/mocha_runner.py @@ -19,6 +19,7 @@ from codeflash.cli_cmds.init_javascript import get_package_install_command from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args +from codeflash.languages.javascript.command_utils import resolve_node_command_list if TYPE_CHECKING: from codeflash.models.models import TestFiles @@ -263,7 +264,7 @@ def _build_mocha_behavioral_command( Command list for subprocess execution. """ - cmd = ["npx", "mocha", "--reporter", "json", "--jobs", "1", "--exit"] + cmd = resolve_node_command_list(["npx", "mocha", "--reporter", "json", "--jobs", "1", "--exit"]) if timeout: cmd.extend(["--timeout", str(timeout * 1000)]) @@ -289,7 +290,7 @@ def _build_mocha_benchmarking_command( Command list for subprocess execution. """ - cmd = ["npx", "mocha", "--reporter", "json", "--jobs", "1", "--exit"] + cmd = resolve_node_command_list(["npx", "mocha", "--reporter", "json", "--jobs", "1", "--exit"]) if timeout: cmd.extend(["--timeout", str(timeout * 1000)]) @@ -315,7 +316,7 @@ def _build_mocha_line_profile_command( Command list for subprocess execution. """ - cmd = ["npx", "mocha", "--reporter", "json", "--jobs", "1", "--exit"] + cmd = resolve_node_command_list(["npx", "mocha", "--reporter", "json", "--jobs", "1", "--exit"]) if timeout: cmd.extend(["--timeout", str(timeout * 1000)]) diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index d96fcd80c..4404470ad 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -24,6 +24,7 @@ TestInfo, TestResult, ) +from codeflash.languages.javascript.command_utils import resolve_node_command, resolve_node_command_list from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage, get_analyzer_for_file from codeflash.languages.registry import register_language from codeflash.models.models import FunctionParent @@ -1566,7 +1567,7 @@ def format_code(self, source: str, file_path: Path | None = None) -> str: stdin_filepath = str(file_path.name) if file_path else f"file{self.default_file_extension}" result = subprocess.run( - ["npx", "prettier", "--stdin-filepath", stdin_filepath], + resolve_node_command_list(["npx", "prettier", "--stdin-filepath", stdin_filepath]), check=False, input=source, capture_output=True, @@ -1606,15 +1607,17 @@ def run_tests( # Build Jest command test_pattern = "|".join(str(f) for f in test_files) - cmd = [ - "npx", - "jest", - "--reporters=default", - "--reporters=jest-junit", - f"--testPathPattern={test_pattern}", - "--runInBand", # Sequential for deterministic timing - "--forceExit", - ] + cmd = resolve_node_command_list( + [ + "npx", + "jest", + "--reporters=default", + "--reporters=jest-junit", + f"--testPathPattern={test_pattern}", + "--runInBand", # Sequential for deterministic timing + "--forceExit", + ] + ) test_env = env.copy() test_env["JEST_JUNIT_OUTPUT_FILE"] = str(junit_xml) @@ -2314,10 +2317,12 @@ def verify_requirements(self, project_root: Path, test_framework: str = "jest") """ errors: list[SetupError] = [] + node_cmd = resolve_node_command("node") + npm_cmd = resolve_node_command("npm") # Check Node.js try: - result = subprocess.run(["node", "--version"], check=False, capture_output=True, text=True, timeout=10) + result = subprocess.run([node_cmd, "--version"], check=False, capture_output=True, text=True, timeout=10) if result.returncode != 0: errors.append( SetupError( @@ -2336,7 +2341,7 @@ def verify_requirements(self, project_root: Path, test_framework: str = "jest") # Check npm try: - result = subprocess.run(["npm", "--version"], check=False, capture_output=True, text=True, timeout=10) + result = subprocess.run([npm_cmd, "--version"], check=False, capture_output=True, text=True, timeout=10) if result.returncode != 0: errors.append( SetupError("npm is not available. Please ensure npm is installed with Node.js.", should_abort=True) @@ -2378,7 +2383,9 @@ def verify_requirements(self, project_root: Path, test_framework: str = "jest") def _detect_node_version(self) -> None: """Detect and cache the Node.js runtime version.""" try: - result = subprocess.run(["node", "--version"], check=False, capture_output=True, text=True, timeout=10) + result = subprocess.run( + [resolve_node_command("node"), "--version"], check=False, capture_output=True, text=True, timeout=10 + ) if result.returncode == 0 and result.stdout.strip(): self._language_version = result.stdout.strip().lstrip("v") except Exception: @@ -2407,7 +2414,7 @@ def ensure_runtime_environment(self, project_root: Path) -> bool: try: result = subprocess.run( - ["npm", "install", "--save-dev", "codeflash"], + resolve_node_command_list(["npm", "install", "--save-dev", "codeflash"]), check=False, cwd=project_root, capture_output=True, diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index 70e59d7f8..2ba4ea0ea 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -18,6 +18,7 @@ from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.config_consts import STABILITY_CENTER_TOLERANCE, STABILITY_SPREAD_TOLERANCE from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args +from codeflash.languages.javascript.command_utils import resolve_node_command_list if TYPE_CHECKING: from codeflash.models.models import TestFiles @@ -794,14 +795,16 @@ def run_jest_behavioral_tests( coverage_json_path = coverage_dir / "coverage-final.json" if enable_coverage else None # Build Jest command - jest_cmd = [ - "npx", - "jest", - "--reporters=default", - f"--reporters={CODEFLASH_JEST_REPORTER}", - "--runInBand", # Run tests serially for consistent timing - "--forceExit", - ] + jest_cmd = resolve_node_command_list( + [ + "npx", + "jest", + "--reporters=default", + f"--reporters={CODEFLASH_JEST_REPORTER}", + "--runInBand", # Run tests serially for consistent timing + "--forceExit", + ] + ) # Add Jest config if found - needed for TypeScript transformation # Uses codeflash-compatible config if project has bundler moduleResolution @@ -1048,15 +1051,17 @@ def run_jest_benchmarking_tests( logger.debug(f"Jest {jest_major_version} detected - using loop-runner for batched looping") # Build Jest command for performance tests - jest_cmd = [ - "npx", - "jest", - "--reporters=default", - f"--reporters={CODEFLASH_JEST_REPORTER}", - "--runInBand", # Ensure serial execution - "--forceExit", - "--runner=codeflash/loop-runner", # Use custom loop runner for in-process looping - ] + jest_cmd = resolve_node_command_list( + [ + "npx", + "jest", + "--reporters=default", + f"--reporters={CODEFLASH_JEST_REPORTER}", + "--runInBand", # Ensure serial execution + "--forceExit", + "--runner=codeflash/loop-runner", # Use custom loop runner for in-process looping + ] + ) # Add Jest config if found - needed for TypeScript transformation # Uses codeflash-compatible config if project has bundler moduleResolution @@ -1224,14 +1229,16 @@ def run_jest_line_profile_tests( _ensure_runtime_files(effective_cwd) # Build Jest command for line profiling - simple run without benchmarking loops - jest_cmd = [ - "npx", - "jest", - "--reporters=default", - f"--reporters={CODEFLASH_JEST_REPORTER}", - "--runInBand", # Run tests serially for consistent line profiling - "--forceExit", - ] + jest_cmd = resolve_node_command_list( + [ + "npx", + "jest", + "--reporters=default", + f"--reporters={CODEFLASH_JEST_REPORTER}", + "--runInBand", # Run tests serially for consistent line profiling + "--forceExit", + ] + ) # Add Jest config if found - needed for TypeScript transformation # Uses codeflash-compatible config if project has bundler moduleResolution diff --git a/codeflash/languages/javascript/tracer_runner.py b/codeflash/languages/javascript/tracer_runner.py index 8b4ca5adb..5fcc118d7 100644 --- a/codeflash/languages/javascript/tracer_runner.py +++ b/codeflash/languages/javascript/tracer_runner.py @@ -8,6 +8,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Optional +from codeflash.languages.javascript.command_utils import resolve_node_command_list + if TYPE_CHECKING: from argparse import Namespace @@ -34,7 +36,9 @@ def find_trace_runner() -> Optional[Path]: return local_path try: - result = subprocess.run(["npm", "root", "-g"], capture_output=True, text=True, check=True) + result = subprocess.run( + resolve_node_command_list(["npm", "root", "-g"]), capture_output=True, text=True, check=True + ) global_modules = Path(result.stdout.strip()) global_path = global_modules / "codeflash" / "runtime" / "trace-runner.js" if global_path.exists(): diff --git a/codeflash/languages/javascript/vitest_runner.py b/codeflash/languages/javascript/vitest_runner.py index dcf3a2ed3..716a87853 100644 --- a/codeflash/languages/javascript/vitest_runner.py +++ b/codeflash/languages/javascript/vitest_runner.py @@ -17,6 +17,7 @@ from codeflash.cli_cmds.init_javascript import get_package_install_command from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args +from codeflash.languages.javascript.command_utils import resolve_node_command_list if TYPE_CHECKING: from codeflash.models.models import TestFiles @@ -317,15 +318,17 @@ def _build_vitest_behavioral_command( Command list for subprocess execution. """ - cmd = [ - "npx", - "vitest", - "run", # Single execution (not watch mode) - "--reporter=default", - "--reporter=junit", - "--no-file-parallelism", # Serial execution for deterministic timing - "--pool=forks", # Use child processes so timing markers flow to parent stdout - ] + cmd = resolve_node_command_list( + [ + "npx", + "vitest", + "run", # Single execution (not watch mode) + "--reporter=default", + "--reporter=junit", + "--no-file-parallelism", # Serial execution for deterministic timing + "--pool=forks", # Use child processes so timing markers flow to parent stdout + ] + ) # For monorepos with restrictive vitest configs (e.g., include: test/**/*.test.ts), # we need to create a custom config that allows all test patterns. @@ -379,15 +382,17 @@ def _build_vitest_benchmarking_command( Command list for subprocess execution. """ - cmd = [ - "npx", - "vitest", - "run", # Single execution (not watch mode) - "--reporter=default", - "--reporter=junit", - "--no-file-parallelism", # Serial execution for consistent benchmarking - "--pool=forks", # Use child processes so timing markers flow to parent stdout - ] + cmd = resolve_node_command_list( + [ + "npx", + "vitest", + "run", # Single execution (not watch mode) + "--reporter=default", + "--reporter=junit", + "--no-file-parallelism", # Serial execution for consistent benchmarking + "--pool=forks", # Use child processes so timing markers flow to parent stdout + ] + ) # Use codeflash vitest config to override restrictive include patterns if project_root: @@ -805,15 +810,17 @@ def run_vitest_line_profile_tests( _ensure_runtime_files(effective_cwd) # Build Vitest command for line profiling - simple run without benchmarking loops - vitest_cmd = [ - "npx", - "vitest", - "run", - "--reporter=default", - "--reporter=junit", - "--no-file-parallelism", # Serial execution for consistent line profiling - "--pool=forks", # Use child processes so timing markers flow to parent stdout - ] + vitest_cmd = resolve_node_command_list( + [ + "npx", + "vitest", + "run", + "--reporter=default", + "--reporter=junit", + "--no-file-parallelism", # Serial execution for consistent line profiling + "--pool=forks", # Use child processes so timing markers flow to parent stdout + ] + ) # Use codeflash vitest config to override restrictive include patterns if effective_cwd: diff --git a/codeflash/main.py b/codeflash/main.py index 21ed2c5f3..9b85f0db2 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -31,7 +31,8 @@ def main() -> None: from codeflash.cli_cmds.cli import parse_args - if "--help" in sys.argv[1:] or "-h" in sys.argv[1:]: + help_requested = "--help" in sys.argv[1:] or "-h" in sys.argv[1:] + if help_requested: print_codeflash_banner() args = parse_args() @@ -51,7 +52,8 @@ def main() -> None: # Compare command only needs its own imports if args.command == "compare": - print_codeflash_banner() + if not help_requested: + print_codeflash_banner() from codeflash.cli_cmds.cmd_compare import run_compare run_compare(args) @@ -68,7 +70,8 @@ def main() -> None: from codeflash.telemetry import posthog_cf from codeflash.telemetry.sentry import init_sentry - print_codeflash_banner() + if not help_requested: + print_codeflash_banner() check_for_newer_minor_version() if args.command: @@ -82,11 +85,13 @@ def main() -> None: if args.command == "init": from codeflash.cli_cmds.cmd_init import init_codeflash - init_codeflash() + init_codeflash( + skip_confirm=getattr(args, "yes", False), skip_api_key=bool(os.environ.get("CODEFLASH_API_KEY")) + ) elif args.command == "init-actions": from codeflash.cli_cmds.github_workflow import install_github_actions - install_github_actions() + install_github_actions(skip_confirm=getattr(args, "yes", False)) elif args.command == "vscode-install": from codeflash.cli_cmds.extension import install_vscode_extension diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 917a413e1..3bde8eb4d 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -6,7 +6,7 @@ import time from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.api.cfapi import send_completion_email @@ -55,6 +55,30 @@ def _extract_java_package_from_path(file_path: Path) -> str | None: return None +def _install_optimizer_signal_handlers( + signal_module: object, signal_handler: Callable[[int, object], None] +) -> dict[object, object]: + original_handlers: dict[object, object] = {} + for signal_name in ("SIGTERM", "SIGHUP", "SIGQUIT", "SIGPIPE"): + signum = getattr(signal_module, signal_name, None) + if signum is None: + continue + try: + original_handlers[signum] = signal_module.getsignal(signum) + signal_module.signal(signum, signal_handler) + except (AttributeError, OSError, RuntimeError, ValueError) as exc: + logger.debug("Skipping unsupported signal %s: %s", signal_name, exc) + return original_handlers + + +def _restore_optimizer_signal_handlers(signal_module: object, original_handlers: dict[object, object]) -> None: + for signum, original_handler in original_handlers.items(): + try: + signal_module.signal(signum, original_handler) + except (AttributeError, OSError, RuntimeError, ValueError) as exc: + logger.debug("Failed to restore signal handler for %s: %s", signum, exc) + + class Optimizer: def __init__(self, args: Namespace) -> None: self.args = args @@ -873,10 +897,6 @@ def run_with_args(args: Namespace) -> None: cleanup_stale_worktrees() optimizer = None - original_sigterm = signal.getsignal(signal.SIGTERM) - original_sighup = signal.getsignal(signal.SIGHUP) - original_sigquit = signal.getsignal(signal.SIGQUIT) - original_sigpipe = signal.getsignal(signal.SIGPIPE) def cleanup_worktree_on_exit() -> None: if optimizer and optimizer.current_worktree: @@ -889,10 +909,7 @@ def signal_handler(signum: int, frame: object) -> None: raise SystemExit(128 + signum) atexit.register(cleanup_worktree_on_exit) - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGHUP, signal_handler) - signal.signal(signal.SIGQUIT, signal_handler) - signal.signal(signal.SIGPIPE, signal_handler) + original_signal_handlers = _install_optimizer_signal_handlers(signal, signal_handler) try: optimizer = Optimizer(args) @@ -905,7 +922,4 @@ def signal_handler(signum: int, frame: object) -> None: raise SystemExit from None finally: atexit.unregister(cleanup_worktree_on_exit) - signal.signal(signal.SIGTERM, original_sigterm) - signal.signal(signal.SIGHUP, original_sighup) - signal.signal(signal.SIGQUIT, original_sigquit) - signal.signal(signal.SIGPIPE, original_sigpipe) + _restore_optimizer_signal_handlers(signal, original_signal_handlers) diff --git a/codeflash/setup/config_writer.py b/codeflash/setup/config_writer.py index b0598dcb0..aee7211b2 100644 --- a/codeflash/setup/config_writer.py +++ b/codeflash/setup/config_writer.py @@ -83,7 +83,7 @@ def _write_pyproject_toml(project_root: Path, config: CodeflashConfig) -> tuple[ doc["tool"]["codeflash"] = codeflash_table # Write back - with pyproject_path.open("w", encoding="utf8") as f: + with pyproject_path.open("w", encoding="utf8", newline="") as f: f.write(tomlkit.dumps(doc)) return True, f"Config saved to {pyproject_path}" @@ -207,7 +207,7 @@ def _remove_from_pyproject(project_root: Path) -> tuple[bool, str]: if "tool" in doc and "codeflash" in doc["tool"]: del doc["tool"]["codeflash"] - with pyproject_path.open("w", encoding="utf8") as f: + with pyproject_path.open("w", encoding="utf8", newline="") as f: f.write(tomlkit.dumps(doc)) return True, "Removed [tool.codeflash] section from pyproject.toml" diff --git a/codeflash/telemetry/posthog_cf.py b/codeflash/telemetry/posthog_cf.py index 2d336e3c2..bb0dca8f0 100644 --- a/codeflash/telemetry/posthog_cf.py +++ b/codeflash/telemetry/posthog_cf.py @@ -35,18 +35,21 @@ def ph(event: str, properties: dict[str, Any] | None = None) -> None: if _posthog is None: return - from codeflash.api.cfapi import get_user_id - from codeflash.lsp.helpers import is_subagent_mode - from codeflash.version import __version__ + from codeflash.cli_cmds.console import logger - properties = properties or {} - properties.update({"cli_version": __version__, "subagent": is_subagent_mode()}) + try: + from codeflash.api.cfapi import get_user_id + from codeflash.lsp.helpers import is_subagent_mode + from codeflash.version import __version__ - user_id = get_user_id() + properties = properties or {} + properties.update({"cli_version": __version__, "subagent": is_subagent_mode()}) - if user_id: - _posthog.capture(distinct_id=user_id, event=event, properties=properties) - else: - from codeflash.cli_cmds.console import logger + user_id = get_user_id(suppress_errors=True) - logger.debug("Failed to log event to PostHog: User ID could not be retrieved.") + if user_id: + _posthog.capture(distinct_id=user_id, event=event, properties=properties) + else: + logger.debug("Failed to log event to PostHog: User ID could not be retrieved.") + except (Exception, SystemExit) as exc: + logger.debug("Failed to log event to PostHog: %s", exc) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 834c7842d..b362160bf 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -37,6 +37,42 @@ logger = logging.getLogger(__name__) +def _build_tracer_parser(*, prog: str | None = None) -> ArgumentParser: + parser = ArgumentParser(allow_abbrev=False, prog=prog) + parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", default="codeflash.trace") + parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None) + parser.add_argument( + "--max-function-count", + help="Maximum number of inputs for one function to include in the trace.", + type=int, + default=256, + ) + parser.add_argument( + "--tracer-timeout", + help="Timeout in seconds for the tracer, if the traced code takes more than this time, then tracing stops and " + "normal execution continues.", + type=float, + default=None, + ) + parser.add_argument("-m", action="store_true", dest="module", help="Trace a library module", default=False) + parser.add_argument( + "--codeflash-config", + help="Optional path to the project's pyproject.toml file " + "with the codeflash config. Will be auto-discovered if not specified.", + default=None, + ) + parser.add_argument("--trace-only", action="store_true", help="Trace and create replay tests only, don't optimize") + parser.add_argument( + "--limit", type=int, default=None, help="Limit the number of test files to process (for -m pytest mode)" + ) + parser.add_argument( + "--language", + help="Language to trace (python, javascript, typescript). Auto-detected if not specified.", + default=None, + ) + return parser + + def _detect_non_python_language(args: Namespace | None) -> Language | None: """Detect if the project uses a non-Python language from --file or config. @@ -96,6 +132,11 @@ def _detect_non_python_language(args: Namespace | None) -> Language | None: def main(args: Namespace | None = None) -> ArgumentParser: + if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + parser = _build_tracer_parser(prog="codeflash optimize") + parser.print_help() + return parser + # For non-Python languages, detect early and route to the appropriate handler. # Java, JavaScript, and TypeScript use their own test runners (Maven/JUnit, Jest) # and should not go through Python tracing. @@ -122,38 +163,7 @@ def main(args: Namespace | None = None) -> ArgumentParser: if detected_language == Language.JAVA: return _run_java_tracer(args) - parser = ArgumentParser(allow_abbrev=False) - parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", default="codeflash.trace") - parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None) - parser.add_argument( - "--max-function-count", - help="Maximum number of inputs for one function to include in the trace.", - type=int, - default=256, - ) - parser.add_argument( - "--tracer-timeout", - help="Timeout in seconds for the tracer, if the traced code takes more than this time, then tracing stops and " - "normal execution continues.", - type=float, - default=None, - ) - parser.add_argument("-m", action="store_true", dest="module", help="Trace a library module", default=False) - parser.add_argument( - "--codeflash-config", - help="Optional path to the project's pyproject.toml file " - "with the codeflash config. Will be auto-discovered if not specified.", - default=None, - ) - parser.add_argument("--trace-only", action="store_true", help="Trace and create replay tests only, don't optimize") - parser.add_argument( - "--limit", type=int, default=None, help="Limit the number of test files to process (for -m pytest mode)" - ) - parser.add_argument( - "--language", - help="Language to trace (python, javascript, typescript). Auto-detected if not specified.", - default=None, - ) + parser = _build_tracer_parser() if args is not None: parsed_args = args diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index 9751ebc11..d5a4d8bb5 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -61,10 +61,11 @@ def generate_tests( source_file_abs = source_file.resolve().with_suffix("") test_dir_abs = test_path.resolve().parent - # Compute relative path from test directory to source file - rel_import_path = os.path.relpath(str(source_file_abs), str(test_dir_abs)) + # Compute relative path from test directory to source file. + # JavaScript import specifiers must always use forward slashes, even on Windows. + rel_import_path = os.path.relpath(str(source_file_abs), str(test_dir_abs)).replace("\\", "/") # Ensure path starts with ./ or ../ for JavaScript/TypeScript imports - if not rel_import_path.startswith("../"): + if not rel_import_path.startswith(("../", "./")): rel_import_path = f"./{rel_import_path}" # ESM requires explicit file extensions in import specifiers. # TypeScript ESM also uses .js extensions (TS resolves .js → .ts). diff --git a/tests/languages/javascript/test_vitest_runner.py b/tests/languages/javascript/test_vitest_runner.py index a1ff4b728..d5c502b94 100644 --- a/tests/languages/javascript/test_vitest_runner.py +++ b/tests/languages/javascript/test_vitest_runner.py @@ -16,6 +16,10 @@ ) +def command_name(command: str) -> str: + return Path(command).stem.lower() + + class TestFindVitestProjectRoot: """Tests for _find_vitest_project_root function.""" @@ -96,9 +100,8 @@ def test_basic_command_structure(self) -> None: cmd = _build_vitest_behavioral_command([test_file], timeout=60) - assert cmd[0] == "npx" - assert cmd[1] == "vitest" - assert cmd[2] == "run" + assert command_name(cmd[0]) == "npx" + assert cmd[1:3] == ["vitest", "run"] def test_includes_reporter_flags(self) -> None: """Should include reporter flags for JUnit output.""" @@ -174,9 +177,8 @@ def test_basic_command_structure(self) -> None: cmd = _build_vitest_benchmarking_command([test_file], timeout=60) - assert cmd[0] == "npx" - assert cmd[1] == "vitest" - assert cmd[2] == "run" + assert command_name(cmd[0]) == "npx" + assert cmd[1:3] == ["vitest", "run"] def test_includes_serial_execution(self) -> None: """Should include serial execution for consistent benchmarking.""" @@ -202,7 +204,8 @@ def test_vitest_uses_run_subcommand(self) -> None: vitest_cmd = _build_vitest_behavioral_command([test_file], timeout=60) - assert vitest_cmd[0:3] == ["npx", "vitest", "run"] + assert command_name(vitest_cmd[0]) == "npx" + assert vitest_cmd[1:3] == ["vitest", "run"] def test_vitest_uses_hyphenated_timeout(self) -> None: """Vitest uses --test-timeout, Jest uses --testTimeout (camelCase).""" diff --git a/tests/test_cfapi.py b/tests/test_cfapi.py new file mode 100644 index 000000000..663227ff0 --- /dev/null +++ b/tests/test_cfapi.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from codeflash.api.cfapi import get_user_id + + +def test_get_user_id_suppresses_invalid_key_exit(monkeypatch) -> None: + response = Mock() + response.status_code = 403 + response.reason = "Forbidden" + + exit_with_message = Mock(side_effect=AssertionError("exit_with_message should not be called")) + + monkeypatch.setattr("codeflash.api.cfapi.ensure_codeflash_api_key", Mock(return_value=True)) + monkeypatch.setattr("codeflash.api.cfapi.make_cfapi_request", Mock(return_value=response)) + monkeypatch.setattr("codeflash.api.cfapi.exit_with_message", exit_with_message) + + get_user_id.cache_clear() + try: + assert get_user_id(suppress_errors=True) is None + finally: + get_user_id.cache_clear() + + exit_with_message.assert_not_called() diff --git a/tests/test_cli_output_encoding.py b/tests/test_cli_output_encoding.py new file mode 100644 index 000000000..20b86aa1b --- /dev/null +++ b/tests/test_cli_output_encoding.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[1] +MAIN_SNIPPET = ( + f"import sys; sys.path.insert(0, {str(REPO_ROOT)!r}); " + "from codeflash.main import main; " + "main()" +) + + +def run_codeflash(args: list[str], *, cwd: Path) -> subprocess.CompletedProcess[str]: + env = os.environ.copy() + env["PYTHONIOENCODING"] = "cp1252:strict" + env.pop("PYTHONUTF8", None) + return subprocess.run( + [sys.executable, "-c", MAIN_SNIPPET, *args], + cwd=cwd, + capture_output=True, + text=True, + encoding="cp1252", + errors="replace", + env=env, + ) + + +def test_optimize_help_does_not_crash_with_cp1252_output() -> None: + result = run_codeflash(["optimize", "--help"], cwd=REPO_ROOT) + + assert result.returncode == 0, result.stderr + assert "UnicodeEncodeError" not in result.stderr + assert "--trace-only" in result.stdout + + +def test_show_config_does_not_crash_with_cp1252_output(tmp_path: Path) -> None: + project_root = tmp_path / "python_project" + project_root.mkdir() + (project_root / "demo_pkg").mkdir() + (project_root / "tests").mkdir() + (project_root / "demo_pkg" / "__init__.py").write_text("", encoding="utf-8") + (project_root / "tests" / "test_demo.py").write_text("def test_placeholder():\n assert True\n", encoding="utf-8") + (project_root / "pyproject.toml").write_text( + "[project]\nname = 'demo-project'\nversion = '0.1.0'\n", + encoding="utf-8", + ) + + subprocess.run( + ["git", "init"], + cwd=project_root, + capture_output=True, + text=True, + check=True, + ) + + result = run_codeflash(["--show-config"], cwd=project_root) + + assert result.returncode == 0, result.stderr + assert "UnicodeEncodeError" not in result.stderr + assert "Codeflash Configuration" in result.stdout + assert "Auto-detected (not saved)" in result.stdout diff --git a/tests/test_cli_show_config_zero_config.py b/tests/test_cli_show_config_zero_config.py new file mode 100644 index 000000000..7c665b20b --- /dev/null +++ b/tests/test_cli_show_config_zero_config.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.cli_cmds.cli import _handle_show_config +from codeflash.cli_cmds.console import console + + +def _capture_show_config(monkeypatch, project_root: Path) -> str: + monkeypatch.chdir(project_root) + with console.capture() as capture: + _handle_show_config() + return capture.get() + + +def test_show_config_reports_zero_config_java_projects(monkeypatch, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text( + """ + + 4.0.0 + com.example + demo + 1.0.0 + + + org.junit.jupiter + junit-jupiter + 5.10.0 + test + + + +""", + encoding="utf-8", + ) + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "src" / "test" / "java").mkdir(parents=True) + + output = _capture_show_config(monkeypatch, tmp_path) + + assert "Codeflash Configuration" in output + assert "Auto-detected (zero-config)" in output + assert "Config source:" in output + assert "Config file:" not in output + assert "junit5" in output + + +def test_show_config_reports_zero_config_go_projects(monkeypatch, tmp_path: Path) -> None: + (tmp_path / "go.mod").write_text("module example.com/demo\n\ngo 1.21\n", encoding="utf-8") + (tmp_path / "main.go").write_text("package main\n\nfunc main() {}\n", encoding="utf-8") + + output = _capture_show_config(monkeypatch, tmp_path) + + assert "Codeflash Configuration" in output + assert "Auto-detected (zero-config)" in output + assert "Config source:" in output + assert "Config file:" not in output + assert "go-test" in output diff --git a/tests/test_formatter.py b/tests/test_formatter.py index bd3fd7226..f235953d8 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -6,10 +6,10 @@ import pytest from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports +from codeflash.code_utils.formatter import apply_formatter_cmds, format_code, format_generated_code, sort_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeString, CodeStringsMarkdown from codeflash.languages.function_optimizer import FunctionOptimizer +from codeflash.models.models import CodeString, CodeStringsMarkdown from codeflash.verification.verification_utils import TestConfig @@ -1394,6 +1394,30 @@ def test_format_generated_code_unicode(): assert "Hello, 世界! 🌍" in result +def test_apply_formatter_cmds_resolves_node_wrappers_for_javascript(tmp_path: Path): + """JavaScript formatter commands should resolve npx/npm wrappers before subprocess execution.""" + from unittest.mock import MagicMock, patch + + js_file = tmp_path / "test.js" + js_file.write_text("const value = 1;\n", encoding="utf-8") + resolved_cmd = [r"C:\nvm4w\nodejs\npx.cmd", "prettier", "--write", js_file.as_posix()] + + with ( + patch( + "codeflash.languages.javascript.command_utils.resolve_node_command_list", return_value=resolved_cmd + ) as mock_resolve, + patch("codeflash.code_utils.formatter.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run, + ): + _, formatted_code, changed = apply_formatter_cmds( + ["npx prettier --write $file"], js_file, test_dir_str=None, print_status=False + ) + + mock_resolve.assert_called_once_with(["npx", "prettier", "--write", js_file.as_posix()]) + mock_run.assert_called_once_with(resolved_cmd, capture_output=True, check=False) + assert changed is True + assert formatted_code == js_file.read_text(encoding="utf-8") + + def test_format_generated_code_uses_correct_extension_for_javascript(): """Test that format_generated_code creates temp files with .js extension for JavaScript code.""" from unittest.mock import patch diff --git a/tests/test_functions_to_optimize.py b/tests/test_functions_to_optimize.py new file mode 100644 index 000000000..88c957881 --- /dev/null +++ b/tests/test_functions_to_optimize.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from argparse import Namespace +from pathlib import Path +from unittest.mock import Mock + +from codeflash.discovery.functions_to_optimize import was_function_previously_optimized + + +def test_was_function_previously_optimized_ignores_missing_git_remote(monkeypatch) -> None: + function_to_optimize = Mock(file_path=Path("example.py"), qualified_name="sorter") + code_context = Mock(hashing_code_context_hash="hash-123") + check_optimization_status = Mock() + + monkeypatch.setattr("codeflash.discovery.functions_to_optimize.is_LSP_enabled", Mock(return_value=False)) + monkeypatch.setattr("codeflash.discovery.functions_to_optimize.is_subagent_mode", Mock(return_value=False)) + monkeypatch.setattr("codeflash.discovery.functions_to_optimize.get_pr_number", Mock(return_value=123)) + monkeypatch.setattr( + "codeflash.discovery.functions_to_optimize.get_repo_owner_and_name", + Mock(side_effect=ValueError("Remote named 'origin' didn't exist")), + ) + monkeypatch.setattr( + "codeflash.discovery.functions_to_optimize.is_function_being_optimized_again", + check_optimization_status, + ) + + result = was_function_previously_optimized(function_to_optimize, code_context, Namespace(no_pr=False)) + + assert result is False + check_optimization_status.assert_not_called() diff --git a/tests/test_github_workflow.py b/tests/test_github_workflow.py new file mode 100644 index 000000000..083f463e0 --- /dev/null +++ b/tests/test_github_workflow.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +from git import Repo + +from codeflash.cli_cmds.github_workflow import install_github_actions + + +def test_install_github_actions_skip_confirm_uses_defaults_without_prompts(tmp_path: Path, monkeypatch) -> None: + monkeypatch.chdir(tmp_path) + Repo.init(tmp_path) + (tmp_path / "tests").mkdir() + (tmp_path / "pyproject.toml").write_text( + """ +[tool.codeflash] +module-root = "." +tests-root = "tests" +formatter-cmds = ["disabled"] +""".strip(), + encoding="utf-8", + ) + + with ( + patch("codeflash.cli_cmds.github_workflow.inquirer.prompt") as mock_prompt, + patch("codeflash.cli_cmds.github_workflow.get_current_branch", return_value="main"), + patch("codeflash.cli_cmds.github_workflow.get_repo_owner_and_name", side_effect=RuntimeError("no remote")), + ): + install_github_actions(skip_confirm=True) + + mock_prompt.assert_not_called() + assert (tmp_path / ".github" / "workflows" / "codeflash.yaml").exists() + + +def test_install_github_actions_skip_confirm_supports_go_projects(tmp_path: Path, monkeypatch) -> None: + monkeypatch.chdir(tmp_path) + Repo.init(tmp_path) + (tmp_path / "go.mod").write_text("module example.com/demo\n\ngo 1.21\n", encoding="utf-8") + + with ( + patch("codeflash.cli_cmds.github_workflow.inquirer.prompt") as mock_prompt, + patch("codeflash.cli_cmds.github_workflow.get_current_branch", return_value="main"), + patch("codeflash.cli_cmds.github_workflow.get_repo_owner_and_name", side_effect=RuntimeError("no remote")), + ): + install_github_actions(skip_confirm=True) + + workflow_path = tmp_path / ".github" / "workflows" / "codeflash.yaml" + workflow_text = workflow_path.read_text(encoding="utf-8") + + mock_prompt.assert_not_called() + assert workflow_path.exists() + assert "Optimize new Go code" in workflow_text + assert "actions/setup-go@v5" in workflow_text + assert "go mod download" in workflow_text diff --git a/tests/test_help_banner.py b/tests/test_help_banner.py index 27d749790..73f1cc211 100644 --- a/tests/test_help_banner.py +++ b/tests/test_help_banner.py @@ -22,3 +22,14 @@ def test_help_short_flag_displays_logo() -> None: ) assert result.returncode == 0 assert "codeflash.ai" in result.stdout + + +def test_optimize_help_displays_logo_once() -> None: + result = subprocess.run( + [sys.executable, "-c", "from codeflash.main import main; main()", "optimize", "--help"], + capture_output=True, + text=True, + encoding="utf-8", + ) + assert result.returncode == 0 + assert result.stdout.count("codeflash.ai") == 1 diff --git a/tests/test_init_auth.py b/tests/test_init_auth.py new file mode 100644 index 000000000..31c6f9d51 --- /dev/null +++ b/tests/test_init_auth.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch + +from codeflash.cli_cmds.init_auth import install_github_app + + +def _echo_messages(mock: Mock) -> list[str]: + return [str(call.args[0]) for call in mock.call_args_list if call.args] + + +def _mock_repo_context(monkeypatch) -> None: + git_repo = object() + monkeypatch.setattr("codeflash.cli_cmds.init_auth.git.Repo", Mock(return_value=git_repo)) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.get_git_remotes", Mock(return_value=["origin"])) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.get_repo_owner_and_name", Mock(return_value=("octocat", "demo"))) + + +def test_install_github_app_allows_user_to_skip(monkeypatch) -> None: + _mock_repo_context(monkeypatch) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.is_github_app_installed_on_repo", Mock(return_value=False)) + echo = Mock() + prompt = Mock() + launch = Mock() + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.echo", echo) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.prompt", prompt) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.launch", launch) + + with patch("rich.prompt.Confirm.ask", return_value=False) as confirm_ask: + install_github_app("origin") + + confirm_ask.assert_called_once() + prompt.assert_not_called() + launch.assert_not_called() + assert any("Skipping Codeflash GitHub app installation for octocat/demo." in message for message in _echo_messages(echo)) + + +def test_install_github_app_skips_on_noninteractive_abort(monkeypatch) -> None: + _mock_repo_context(monkeypatch) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.is_github_app_installed_on_repo", Mock(return_value=False)) + echo = Mock() + prompt = Mock() + launch = Mock() + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.echo", echo) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.prompt", prompt) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.launch", launch) + + with patch("rich.prompt.Confirm.ask", side_effect=EOFError): + install_github_app("origin") + + prompt.assert_not_called() + launch.assert_not_called() + assert any("Skipping Codeflash GitHub app installation for octocat/demo." in message for message in _echo_messages(echo)) + + +def test_install_github_app_still_runs_install_flow_when_confirmed(monkeypatch) -> None: + _mock_repo_context(monkeypatch) + monkeypatch.setattr( + "codeflash.cli_cmds.init_auth.is_github_app_installed_on_repo", + Mock(side_effect=[False, True]), + ) + echo = Mock() + prompt = Mock(return_value="") + launch = Mock() + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.echo", echo) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.prompt", prompt) + monkeypatch.setattr("codeflash.cli_cmds.init_auth.click.launch", launch) + + with patch("rich.prompt.Confirm.ask", return_value=True) as confirm_ask: + install_github_app("origin") + + confirm_ask.assert_called_once() + launch.assert_called_once_with("https://github.com/apps/codeflash-ai/installations/select_target") + assert prompt.call_count == 2 diff --git a/tests/test_init_java_go.py b/tests/test_init_java_go.py new file mode 100644 index 000000000..0fa08168c --- /dev/null +++ b/tests/test_init_java_go.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock, patch + +from codeflash.cli_cmds.init_go import collect_go_setup_info +from codeflash.cli_cmds.init_java import collect_java_setup_info, should_modify_java_config + + +def test_collect_go_setup_info_skip_confirm_uses_defaults(tmp_path: Path, monkeypatch) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "go.mod").write_text("module example.com/demo\n\ngo 1.21\n", encoding="utf-8") + + get_git_remote = Mock(return_value="origin") + monkeypatch.setattr("codeflash.cli_cmds.init_go._get_git_remote_for_setup", get_git_remote) + monkeypatch.setattr( + "codeflash.cli_cmds.init_config.ask_for_telemetry", + Mock(side_effect=AssertionError("ask_for_telemetry should not be called")), + ) + + with patch("codeflash.cli_cmds.init_go.inquirer") as mock_inquirer: + setup_info = collect_go_setup_info(skip_confirm=True) + + mock_inquirer.prompt.assert_not_called() + get_git_remote.assert_called_once_with(skip_confirm=True) + assert setup_info.git_remote == "origin" + assert setup_info.disable_telemetry is False + + +def test_collect_java_setup_info_skip_confirm_uses_defaults(tmp_path: Path, monkeypatch) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "build.gradle").write_text("plugins { id 'java' }\n", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "src" / "test" / "java").mkdir(parents=True) + + get_git_remote = Mock(return_value="origin") + monkeypatch.setattr("codeflash.cli_cmds.init_java._get_git_remote_for_setup", get_git_remote) + monkeypatch.setattr( + "codeflash.cli_cmds.init_config.ask_for_telemetry", + Mock(side_effect=AssertionError("ask_for_telemetry should not be called")), + ) + + with patch("codeflash.cli_cmds.init_java.inquirer") as mock_inquirer: + setup_info = collect_java_setup_info(skip_confirm=True) + + mock_inquirer.prompt.assert_not_called() + get_git_remote.assert_called_once_with(skip_confirm=True) + assert setup_info.module_root_override is None + assert setup_info.test_root_override is None + assert setup_info.formatter_override is None + assert setup_info.git_remote == "origin" + assert setup_info.disable_telemetry is False + + +def test_should_modify_java_config_skip_confirm_skips_reconfigure_prompt(tmp_path: Path, monkeypatch) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "build.gradle").write_text("plugins { id 'java' }\n", encoding="utf-8") + (tmp_path / "gradle.properties").write_text("codeflash.moduleRoot=src/main/java\n", encoding="utf-8") + + with patch("rich.prompt.Confirm.ask", side_effect=AssertionError("Confirm.ask should not be called")): + should_modify, config = should_modify_java_config(skip_confirm=True) + + assert should_modify is False + assert config is None diff --git a/tests/test_init_javascript.py b/tests/test_init_javascript.py index 194a3ed8c..acc959579 100644 --- a/tests/test_init_javascript.py +++ b/tests/test_init_javascript.py @@ -1,6 +1,7 @@ """Tests for JavaScript/TypeScript project initialization and package manager detection.""" import json +import tempfile from pathlib import Path from unittest.mock import patch @@ -17,9 +18,15 @@ @pytest.fixture -def tmp_project(tmp_path: Path) -> Path: - """Create a temporary project directory.""" - return tmp_path +def tmp_project() -> Path: + """Create a temporary project directory with a deterministic parent chain.""" + with tempfile.TemporaryDirectory(dir=Path.cwd()) as tmp_dir: + yield Path(tmp_dir) + + +def assert_install_command(actual: list[str], executable: str, expected_args: list[str]) -> None: + assert Path(actual[0]).stem.lower() == executable + assert actual[1:] == expected_args class TestDetermineJsPackageManager: @@ -206,7 +213,7 @@ def test_npm_install_command(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "codeflash", dev=True) - assert result == ["npm", "install", "codeflash", "--save-dev"] + assert_install_command(result, "npm", ["install", "codeflash", "--save-dev"]) def test_npm_install_command_non_dev(self, tmp_project: Path) -> None: """Should return npm install command without --save-dev when dev=False.""" @@ -215,7 +222,7 @@ def test_npm_install_command_non_dev(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "codeflash", dev=False) - assert result == ["npm", "install", "codeflash"] + assert_install_command(result, "npm", ["install", "codeflash"]) def test_pnpm_add_command(self, tmp_project: Path) -> None: """Should return pnpm add command for pnpm projects.""" @@ -224,7 +231,7 @@ def test_pnpm_add_command(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "codeflash", dev=True) - assert result == ["pnpm", "add", "codeflash", "--save-dev"] + assert_install_command(result, "pnpm", ["add", "codeflash", "--save-dev"]) def test_pnpm_add_command_non_dev(self, tmp_project: Path) -> None: """Should return pnpm add command without --save-dev when dev=False.""" @@ -233,7 +240,7 @@ def test_pnpm_add_command_non_dev(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "codeflash", dev=False) - assert result == ["pnpm", "add", "codeflash"] + assert_install_command(result, "pnpm", ["add", "codeflash"]) def test_yarn_add_command(self, tmp_project: Path) -> None: """Should return yarn add command for yarn projects.""" @@ -242,7 +249,7 @@ def test_yarn_add_command(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "codeflash", dev=True) - assert result == ["yarn", "add", "codeflash", "--dev"] + assert_install_command(result, "yarn", ["add", "codeflash", "--dev"]) def test_yarn_add_command_non_dev(self, tmp_project: Path) -> None: """Should return yarn add command without --dev when dev=False.""" @@ -251,7 +258,7 @@ def test_yarn_add_command_non_dev(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "codeflash", dev=False) - assert result == ["yarn", "add", "codeflash"] + assert_install_command(result, "yarn", ["add", "codeflash"]) def test_bun_add_command(self, tmp_project: Path) -> None: """Should return bun add command for bun projects.""" @@ -260,7 +267,7 @@ def test_bun_add_command(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "codeflash", dev=True) - assert result == ["bun", "add", "codeflash", "--dev"] + assert_install_command(result, "bun", ["add", "codeflash", "--dev"]) def test_bun_add_command_non_dev(self, tmp_project: Path) -> None: """Should return bun add command without --dev when dev=False.""" @@ -269,14 +276,14 @@ def test_bun_add_command_non_dev(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "codeflash", dev=False) - assert result == ["bun", "add", "codeflash"] + assert_install_command(result, "bun", ["add", "codeflash"]) def test_defaults_to_npm_for_unknown(self, tmp_project: Path) -> None: """Should default to npm for unknown package manager.""" # No lockfile, no package.json - unknown package manager result = get_package_install_command(tmp_project, "codeflash", dev=True) - assert result == ["npm", "install", "codeflash", "--save-dev"] + assert_install_command(result, "npm", ["install", "codeflash", "--save-dev"]) def test_different_package_name(self, tmp_project: Path) -> None: """Should work with different package names.""" @@ -285,7 +292,7 @@ def test_different_package_name(self, tmp_project: Path) -> None: result = get_package_install_command(tmp_project, "typescript", dev=True) - assert result == ["pnpm", "add", "typescript", "--save-dev"] + assert_install_command(result, "pnpm", ["add", "typescript", "--save-dev"]) class TestShouldModifySkipConfirm: diff --git a/tests/test_init_yes.py b/tests/test_init_yes.py new file mode 100644 index 000000000..2ed287526 --- /dev/null +++ b/tests/test_init_yes.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from argparse import Namespace +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from codeflash.cli_cmds.cmd_init import init_codeflash +from codeflash.cli_cmds.init_config import should_modify_pyproject_toml +from codeflash.cli_cmds.init_javascript import ProjectLanguage +from codeflash.main import main + + +def _args(command: str) -> Namespace: + return Namespace(command=command, yes=True, config_file=None, verify_setup=False) + + +def test_main_passes_yes_and_api_key_state_to_init(monkeypatch) -> None: + init_codeflash = Mock() + + monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test-key") + monkeypatch.setattr("codeflash.main.print_codeflash_banner", Mock()) + monkeypatch.setattr("codeflash.cli_cmds.cli.parse_args", Mock(return_value=_args("init"))) + monkeypatch.setattr("codeflash.code_utils.version_check.check_for_newer_minor_version", Mock()) + monkeypatch.setattr("codeflash.telemetry.sentry.init_sentry", Mock()) + monkeypatch.setattr("codeflash.telemetry.posthog_cf.initialize_posthog", Mock()) + monkeypatch.setattr("codeflash.cli_cmds.cmd_init.init_codeflash", init_codeflash) + + main() + + init_codeflash.assert_called_once_with(skip_confirm=True, skip_api_key=True) + + +def test_main_passes_yes_to_init_actions(monkeypatch) -> None: + install_github_actions = Mock() + + monkeypatch.setattr("codeflash.main.print_codeflash_banner", Mock()) + monkeypatch.setattr("codeflash.cli_cmds.cli.parse_args", Mock(return_value=_args("init-actions"))) + monkeypatch.setattr("codeflash.code_utils.version_check.check_for_newer_minor_version", Mock()) + monkeypatch.setattr("codeflash.telemetry.sentry.init_sentry", Mock()) + monkeypatch.setattr("codeflash.telemetry.posthog_cf.initialize_posthog", Mock()) + monkeypatch.setattr("codeflash.cli_cmds.github_workflow.install_github_actions", install_github_actions) + + main() + + install_github_actions.assert_called_once_with(skip_confirm=True) + + +def test_should_modify_pyproject_toml_skip_confirm_skips_reconfigure_prompt( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "src").mkdir() + (tmp_path / "tests").mkdir() + (tmp_path / "pyproject.toml").write_text( + '[tool.codeflash]\nmodule-root = "src"\ntests-root = "tests"\ngit-remote = "upstream"\n', + encoding="utf-8", + ) + + with patch("rich.prompt.Confirm.ask", side_effect=AssertionError("Confirm.ask should not be called")): + should_modify, config = should_modify_pyproject_toml(skip_confirm=True) + + assert should_modify is False + assert config is not None + assert config["git_remote"] == "upstream" + + +def test_init_codeflash_skip_confirm_reuses_existing_python_config( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "src").mkdir() + (tmp_path / "tests").mkdir() + (tmp_path / "pyproject.toml").write_text( + '[tool.codeflash]\nmodule-root = "src"\ntests-root = "tests"\ngit-remote = "upstream"\n', + encoding="utf-8", + ) + + install_github_app = Mock() + install_github_actions = Mock() + install_vscode_extension = Mock() + detect_project = Mock(side_effect=AssertionError("detect_project should not be called")) + write_config = Mock(side_effect=AssertionError("write_config should not be called")) + exit_mock = Mock(side_effect=SystemExit(0)) + + monkeypatch.setattr("codeflash.cli_cmds.cmd_init.detect_project_language", Mock(return_value=ProjectLanguage.PYTHON)) + monkeypatch.setattr("codeflash.cli_cmds.cmd_init.install_github_app", install_github_app) + monkeypatch.setattr("codeflash.cli_cmds.cmd_init.install_github_actions", install_github_actions) + monkeypatch.setattr("codeflash.cli_cmds.cmd_init.install_vscode_extension", install_vscode_extension) + monkeypatch.setattr("codeflash.cli_cmds.cmd_init.console.print", Mock()) + monkeypatch.setattr("codeflash.cli_cmds.cmd_init.ph", Mock()) + monkeypatch.setattr("codeflash.cli_cmds.cmd_init.sys.exit", exit_mock) + monkeypatch.setattr("codeflash.setup.detect_project", detect_project) + monkeypatch.setattr("codeflash.setup.write_config", write_config) + + with pytest.raises(SystemExit) as exc_info: + init_codeflash(skip_confirm=True, skip_api_key=True) + + assert exc_info.value.code == 0 + install_github_app.assert_called_once_with("upstream") + install_github_actions.assert_called_once_with(override_formatter_check=True, skip_confirm=True) + install_vscode_extension.assert_called_once() + detect_project.assert_not_called() + write_config.assert_not_called() + exit_mock.assert_called_once_with(0) diff --git a/tests/test_languages/test_golang/test_config.py b/tests/test_languages/test_golang/test_config.py index c42e3cada..6a7f906a7 100644 --- a/tests/test_languages/test_golang/test_config.py +++ b/tests/test_languages/test_golang/test_config.py @@ -2,6 +2,7 @@ from pathlib import Path +from codeflash.code_utils.config_parser import parse_config_file from codeflash.languages.golang.config import detect_go_project, is_go_project FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" / "go_project" @@ -45,3 +46,19 @@ def test_without_go_files(self, tmp_path: Path) -> None: def test_with_go_files_no_mod(self, tmp_path: Path) -> None: (tmp_path / "main.go").write_text("package main\n", encoding="utf-8") assert is_go_project(tmp_path) is True + + +class TestParseGoConfig: + def test_parse_config_file_uses_go_test_metadata(self, tmp_path: Path, monkeypatch) -> None: + (tmp_path / "go.mod").write_text("module example.com/minimal\n\ngo 1.21\n", encoding="utf-8") + (tmp_path / "main.go").write_text("package main\n\nfunc main() {}\n", encoding="utf-8") + + monkeypatch.chdir(tmp_path) + config, config_path = parse_config_file() + + assert config_path == tmp_path + assert config["language"] == "go" + assert config["module_root"] == str(tmp_path.resolve()) + assert config["tests_root"] == str(tmp_path.resolve()) + assert config["pytest_cmd"] == "go test ./..." + assert config["test_framework"] == "go-test" diff --git a/tests/test_languages/test_golang/test_support.py b/tests/test_languages/test_golang/test_support.py index 5c415c78e..d5e3cb222 100644 --- a/tests/test_languages/test_golang/test_support.py +++ b/tests/test_languages/test_golang/test_support.py @@ -1,6 +1,9 @@ from __future__ import annotations from pathlib import Path +from unittest.mock import Mock + +import pytest from codeflash.languages.golang.support import GoSupport from codeflash.languages.language_enum import Language @@ -45,6 +48,17 @@ def test_dir_excludes(self) -> None: assert "vendor" in support.dir_excludes assert "testdata" in support.dir_excludes + def test_analyzer_initialization_is_lazy(self, monkeypatch) -> None: + analyzer_cls = Mock(side_effect=ModuleNotFoundError("tree_sitter_go missing")) + support = GoSupport() + + monkeypatch.setattr("codeflash.languages.golang.support.GoAnalyzer", analyzer_cls) + + with pytest.raises(ModuleNotFoundError, match="tree-sitter Go parser"): + support.validate_syntax("package main\n\nfunc main() {}\n") + + analyzer_cls.assert_called_once_with() + class TestGoSupportRegistration: def test_lookup_by_language_enum(self) -> None: diff --git a/tests/test_languages/test_java/test_build_config_strategy.py b/tests/test_languages/test_java/test_build_config_strategy.py index 15effd60b..87df4032f 100644 --- a/tests/test_languages/test_java/test_build_config_strategy.py +++ b/tests/test_languages/test_java/test_build_config_strategy.py @@ -386,7 +386,25 @@ def test_raises_for_unknown(self, tmp_path: Path) -> None: class TestParseJavaProjectConfig: def test_standard_maven_project(self, tmp_path: Path) -> None: - (tmp_path / "pom.xml").write_text("", encoding="utf-8") + (tmp_path / "pom.xml").write_text( + """ + + 4.0.0 + com.example + demo + 1.0.0 + + + org.junit.jupiter + junit-jupiter + 5.10.0 + test + + + +""", + encoding="utf-8", + ) src = tmp_path / "src" / "main" / "java" src.mkdir(parents=True) test = tmp_path / "src" / "test" / "java" @@ -397,9 +415,26 @@ def test_standard_maven_project(self, tmp_path: Path) -> None: assert config["language"] == "java" assert config["module_root"] == str(src) assert config["tests_root"] == str(test) + assert config["pytest_cmd"] == "mvn test" + assert config["test_framework"] == "junit5" def test_standard_gradle_project(self, tmp_path: Path) -> None: - (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "build.gradle").write_text( + """ +plugins { + id 'java' +} + +dependencies { + testImplementation 'org.junit.jupiter:junit-jupiter:5.10.0' +} + +test { + useJUnitPlatform() +} +""", + encoding="utf-8", + ) src = tmp_path / "src" / "main" / "java" src.mkdir(parents=True) test = tmp_path / "src" / "test" / "java" @@ -408,6 +443,8 @@ def test_standard_gradle_project(self, tmp_path: Path) -> None: config = parse_java_project_config(tmp_path) assert config is not None assert config["language"] == "java" + assert config["pytest_cmd"] == "./gradlew test" + assert config["test_framework"] == "junit5" def test_returns_none_for_non_java(self, tmp_path: Path) -> None: assert parse_java_project_config(tmp_path) is None diff --git a/tests/test_languages/test_javascript_requirements.py b/tests/test_languages/test_javascript_requirements.py index efefda228..00ba2967e 100644 --- a/tests/test_languages/test_javascript_requirements.py +++ b/tests/test_languages/test_javascript_requirements.py @@ -4,7 +4,7 @@ """ import json -from pathlib import Path +from pathlib import Path, PureWindowsPath from unittest.mock import MagicMock, patch import pytest @@ -15,6 +15,12 @@ class TestVerifyRequirements: """Tests for JavaScriptSupport.verify_requirements().""" + @staticmethod + def _command_name(command: str) -> str: + if "\\" in command: + return PureWindowsPath(command).stem.lower() + return Path(command).stem.lower() + @pytest.fixture def js_support(self): """Create a JavaScriptSupport instance.""" @@ -98,9 +104,10 @@ def test_verify_requirements_fails_without_npm(self, js_support, project_with_je """Test verification fails when npm is not available.""" def mock_run_side_effect(cmd, **kwargs): - if cmd[0] == "node": + command_name = self._command_name(cmd[0]) + if command_name == "node": return MagicMock(returncode=0) - if cmd[0] == "npm": + if command_name == "npm": raise FileNotFoundError("npm not found") return MagicMock(returncode=0) @@ -111,6 +118,26 @@ def mock_run_side_effect(cmd, **kwargs): npm_error_found = any("npm" in error.message for error in errors) assert npm_error_found is True + def test_verify_requirements_accepts_windows_cmd_wrappers(self, js_support, project_with_jest): + resolved_commands = {"node": r"C:\nvm4w\nodejs\node.exe", "npm": r"C:\nvm4w\nodejs\npm.cmd"} + + def mock_run_side_effect(cmd, **kwargs): + command_name = self._command_name(cmd[0]) + assert cmd[0] == resolved_commands[command_name] + return MagicMock(returncode=0) + + with ( + patch( + "codeflash.languages.javascript.support.resolve_node_command", + side_effect=lambda command: resolved_commands[command], + ), + patch("subprocess.run", side_effect=mock_run_side_effect), + ): + success, errors = js_support.verify_requirements(project_with_jest, "jest") + + assert success is True + assert errors == [] + def test_verify_requirements_fails_without_node_modules(self, js_support, project_without_node_modules): """Test verification fails when node_modules doesn't exist.""" with patch("subprocess.run") as mock_run: diff --git a/tests/test_languages/test_mocha_runner.py b/tests/test_languages/test_mocha_runner.py index 15e470d9b..96968ed87 100644 --- a/tests/test_languages/test_mocha_runner.py +++ b/tests/test_languages/test_mocha_runner.py @@ -329,7 +329,7 @@ def test_basic_command(self): test_file.write_text("// test") cmd = _build_mocha_behavioral_command(test_files=[test_file]) - assert "npx" in cmd + assert Path(cmd[0]).stem.lower() == "npx" assert "mocha" in cmd assert "--reporter" in cmd assert "json" in cmd @@ -371,7 +371,7 @@ def test_basic_command(self): test_file.write_text("// test") cmd = _build_mocha_benchmarking_command(test_files=[test_file]) - assert "npx" in cmd + assert Path(cmd[0]).stem.lower() == "npx" assert "mocha" in cmd assert "--exit" in cmd @@ -398,7 +398,7 @@ def test_basic_command(self): test_file.write_text("// test") cmd = _build_mocha_line_profile_command(test_files=[test_file]) - assert "npx" in cmd + assert Path(cmd[0]).stem.lower() == "npx" assert "mocha" in cmd assert "--exit" in cmd diff --git a/tests/test_optimize_help.py b/tests/test_optimize_help.py new file mode 100644 index 000000000..dcfcf0598 --- /dev/null +++ b/tests/test_optimize_help.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import sys +from argparse import Namespace +from unittest.mock import Mock + +from codeflash import tracer + + +def test_optimize_help_shows_tracer_help_for_javascript_projects(monkeypatch, tmp_path, capsys) -> None: + (tmp_path / "package.json").write_text('{"name": "js-app"}', encoding="utf-8") + + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(sys, "argv", ["codeflash", "--help"]) + + tracer.main(Namespace(file=None)) + + captured = capsys.readouterr() + assert "--only-functions" in captured.out + assert "Sub-commands" not in captured.out + + +def test_optimize_help_does_not_start_java_tracing(monkeypatch, tmp_path, capsys) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + run_java_tracer = Mock() + + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(sys, "argv", ["codeflash", "--help"]) + monkeypatch.setattr("codeflash.tracer._run_java_tracer", run_java_tracer) + + tracer.main(Namespace(file=None)) + + captured = capsys.readouterr() + run_java_tracer.assert_not_called() + assert "--only-functions" in captured.out + assert "No Java command provided" not in captured.out diff --git a/tests/test_optimizer_signals.py b/tests/test_optimizer_signals.py new file mode 100644 index 000000000..526b384d6 --- /dev/null +++ b/tests/test_optimizer_signals.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import signal +from argparse import Namespace +from unittest.mock import Mock, call + +from codeflash.optimization import optimizer as optimizer_module + + +def test_run_with_args_skips_unavailable_optional_signals(monkeypatch) -> None: + cleanup_stale_worktrees = Mock() + optimizer_instance = Mock(current_worktree=None) + optimizer_class = Mock(return_value=optimizer_instance) + getsignal = Mock(return_value="original-handler") + signal_calls: list[tuple[object, object]] = [] + + def fake_signal(signum: object, handler: object) -> None: + signal_calls.append((signum, handler)) + + monkeypatch.setattr(optimizer_module, "cleanup_stale_worktrees", cleanup_stale_worktrees) + monkeypatch.setattr(optimizer_module, "Optimizer", optimizer_class) + monkeypatch.setattr(signal, "getsignal", getsignal) + monkeypatch.setattr(signal, "signal", fake_signal) + monkeypatch.delattr(signal, "SIGHUP", raising=False) + monkeypatch.delattr(signal, "SIGQUIT", raising=False) + monkeypatch.delattr(signal, "SIGPIPE", raising=False) + + optimizer_module.run_with_args(Namespace()) + + cleanup_stale_worktrees.assert_called_once_with() + optimizer_class.assert_called_once() + optimizer_instance.run.assert_called_once_with() + assert getsignal.call_args_list == [call(signal.SIGTERM)] + assert [signum for signum, _ in signal_calls] == [signal.SIGTERM, signal.SIGTERM] diff --git a/tests/test_posthog_cf.py b/tests/test_posthog_cf.py new file mode 100644 index 000000000..80f204370 --- /dev/null +++ b/tests/test_posthog_cf.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from unittest.mock import Mock + +from codeflash.telemetry import posthog_cf + + +def test_initialize_posthog_ignores_user_lookup_failures(monkeypatch) -> None: + fake_client = Mock() + fake_client.log = Mock() + + monkeypatch.setattr("posthog.Posthog", lambda *args, **kwargs: fake_client) + monkeypatch.setattr("codeflash.api.cfapi.get_user_id", Mock(side_effect=SystemExit(1))) + monkeypatch.setattr(posthog_cf, "_posthog", None) + + posthog_cf.initialize_posthog(enabled=True) + + fake_client.capture.assert_not_called() + + +def test_ph_ignores_capture_failures(monkeypatch) -> None: + fake_client = Mock() + fake_client.capture.side_effect = RuntimeError("capture failed") + + monkeypatch.setattr(posthog_cf, "_posthog", fake_client) + monkeypatch.setattr("codeflash.api.cfapi.get_user_id", Mock(return_value="user-123")) + + posthog_cf.ph("cli-test-event") + + fake_client.capture.assert_called_once() + + +def test_ph_uses_silent_user_lookup(monkeypatch) -> None: + fake_client = Mock() + get_user_id = Mock(return_value=None) + + monkeypatch.setattr(posthog_cf, "_posthog", fake_client) + monkeypatch.setattr("codeflash.api.cfapi.get_user_id", get_user_id) + + posthog_cf.ph("cli-test-event") + + get_user_id.assert_called_once_with(suppress_errors=True) diff --git a/tests/test_setup/test_config.py b/tests/test_setup/test_config.py index 6d870be90..662e94e1c 100644 --- a/tests/test_setup/test_config.py +++ b/tests/test_setup/test_config.py @@ -208,6 +208,27 @@ def test_updates_existing_codeflash_section(self, tmp_path): assert data["tool"]["codeflash"]["module-root"] == "new" assert data["tool"]["codeflash"]["tests-root"] == "new_tests" + def test_preserves_crlf_newlines_for_existing_pyproject(self, tmp_path): + """Should not introduce doubled carriage returns when preserving CRLF files.""" + pyproject_path = tmp_path / "pyproject.toml" + pyproject_path.write_bytes(b'[project]\r\nname = "myapp"\r\n\r\n[tool.ruff]\r\nline-length = 120\r\n') + + config = CodeflashConfig(language="python", module_root="src") + + success, message = _write_pyproject_toml(tmp_path, config) + + assert success is True + assert message == f"Config saved to {pyproject_path}" + + content = pyproject_path.read_bytes() + assert b"\r\r\n" not in content + assert b"\r\n" in content + + data = tomlkit.parse(content) + assert data["project"]["name"] == "myapp" + assert data["tool"]["ruff"]["line-length"] == 120 + assert data["tool"]["codeflash"]["module-root"] == "src" + class TestWritePackageJson: """Tests for writing to package.json.""" diff --git a/tests/verification/test_verifier_path_handling.py b/tests/verification/test_verifier_path_handling.py index 2b5ffb772..6a24aa883 100644 --- a/tests/verification/test_verifier_path_handling.py +++ b/tests/verification/test_verifier_path_handling.py @@ -8,10 +8,13 @@ """ from pathlib import Path +from unittest.mock import MagicMock, patch import pytest from codeflash.code_utils.code_utils import module_name_from_file_path +from codeflash.models.function_types import FunctionToOptimize +from codeflash.verification.verifier import generate_tests class TestVerifierPathHandling: @@ -53,3 +56,54 @@ def test_module_name_from_file_path_with_fallback_succeeds(self) -> None: # After fallback, we should have a valid path assert test_module_path == "test_foo.test.ts" + + def test_generate_tests_uses_forward_slashes_for_javascript_module_paths(self, tmp_path: Path) -> None: + """Generated JS import paths should stay valid on Windows by using forward slashes.""" + project_root = tmp_path / "project" + source_dir = project_root / "src" + source_dir.mkdir(parents=True) + + source_file = source_dir / "async_utils.js" + source_file.write_text("export async function processItemsSequential() {}", encoding="utf-8") + + generated_tests_dir = source_dir / "__tests__" / "codeflash-generated" + generated_tests_dir.mkdir(parents=True) + test_path = generated_tests_dir / "test_processItemsSequential__unit_test_0.test.js" + test_perf_path = generated_tests_dir / "test_processItemsSequential__perf_test_0.test.js" + + function_to_optimize = FunctionToOptimize( + function_name="processItemsSequential", file_path=source_file, language="javascript" + ) + test_cfg = MagicMock(tests_project_rootdir=project_root / "tests", test_framework="jest") + ai_client = MagicMock() + ai_client.generate_regression_tests.return_value = ("generated", "behavior", "perf", None) + + mock_support = MagicMock() + mock_support.detect_module_system.return_value = "esm" + mock_support.language_version = None + mock_support.process_generated_test_strings.side_effect = lambda **kwargs: ( + kwargs["generated_test_source"], + kwargs["instrumented_behavior_test_source"], + kwargs["instrumented_perf_test_source"], + ) + + with patch("codeflash.verification.verifier.current_language_support", return_value=mock_support): + result = generate_tests( + aiservice_client=ai_client, + source_code_being_tested=source_file.read_text(encoding="utf-8"), + function_to_optimize=function_to_optimize, + helper_function_names=[], + module_path=source_file, + test_cfg=test_cfg, + test_timeout=30, + function_trace_id="trace-id", + test_index=0, + test_path=test_path, + test_perf_path=test_perf_path, + ) + + assert result is not None + module_path = ai_client.generate_regression_tests.call_args.kwargs["module_path"] + assert module_path == "../../async_utils.js" + assert "\\" not in module_path + assert not module_path.startswith("./..")