diff --git a/codeflash/code_utils/git_worktree_utils.py b/codeflash/code_utils/git_worktree_utils.py index 3dcba708e..1e7738763 100644 --- a/codeflash/code_utils/git_worktree_utils.py +++ b/codeflash/code_utils/git_worktree_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import configparser +import os import shutil import stat import subprocess @@ -65,6 +66,10 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]: repository.git.worktree("add", "-d", str(worktree_dir)) + # Write PID file so stale worktrees can be detected after SIGKILL + pid_file = worktree_dir / ".codeflash.pid" + pid_file.write_text(str(os.getpid()), encoding="utf-8") + # Get uncommitted diff from the original repo repository.git.add("-N", ".") # add the index for untracked files to be included in the diff exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off @@ -119,6 +124,36 @@ def remove_worktree(worktree_dir: Path) -> None: logger.exception(f"Failed to remove worktree: {worktree_dir}") +def is_process_alive(pid: int) -> bool: + try: + os.kill(pid, 0) + except ProcessLookupError: + return False + except PermissionError: + return True # process exists but we can't signal it + return True + + +def cleanup_stale_worktrees() -> None: + """Remove worktrees left behind by killed processes (e.g. SIGKILL).""" + if not worktree_dirs.exists(): + return + for entry in worktree_dirs.iterdir(): + if not entry.is_dir(): + continue + pid_file = entry / ".codeflash.pid" + if pid_file.exists(): + try: + pid = int(pid_file.read_text(encoding="utf-8").strip()) + except (ValueError, OSError): + pid = None + if pid is not None and is_process_alive(pid): + continue # worktree is still in use + # No PID file or owning process is dead — stale worktree + logger.info(f"Removing stale worktree: {entry}") + remove_worktree(entry) + + def create_diff_patch_from_worktree( worktree_dir: Path, files: list[Path], fto_name: Optional[str] = None ) -> Optional[Path]: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 21fe83ff2..82ae5520a 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -22,6 +22,7 @@ from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir from codeflash.code_utils.git_worktree_utils import ( + cleanup_stale_worktrees, create_detached_worktree, create_diff_patch_from_worktree, create_worktree_snapshot_commit, @@ -735,7 +736,33 @@ def mirror_path(path: Path, src_root: Path, dest_root: Path) -> Path: def run_with_args(args: Namespace) -> None: + import atexit + import signal + + cleanup_stale_worktrees() + optimizer = None + original_sigterm = signal.getsignal(signal.SIGTERM) + original_sighup = signal.getsignal(signal.SIGHUP) + original_sigquit = signal.getsignal(signal.SIGQUIT) + original_sigpipe = signal.getsignal(signal.SIGPIPE) + + def cleanup_worktree_on_exit() -> None: + if optimizer and optimizer.current_worktree: + remove_worktree(optimizer.current_worktree) + + def signal_handler(signum: int, frame: object) -> None: + logger.warning(f"Signal {signum} received. Cleaning up worktree and exiting…") + if optimizer: + optimizer.cleanup_temporary_paths() + raise SystemExit(128 + signum) + + atexit.register(cleanup_worktree_on_exit) + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) + signal.signal(signal.SIGQUIT, signal_handler) + signal.signal(signal.SIGPIPE, signal_handler) + try: optimizer = Optimizer(args) optimizer.run() @@ -745,3 +772,9 @@ def run_with_args(args: Namespace) -> None: optimizer.cleanup_temporary_paths() raise SystemExit from None + finally: + atexit.unregister(cleanup_worktree_on_exit) + signal.signal(signal.SIGTERM, original_sigterm) + signal.signal(signal.SIGHUP, original_sighup) + signal.signal(signal.SIGQUIT, original_sigquit) + signal.signal(signal.SIGPIPE, original_sigpipe)