Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions codeflash/code_utils/git_worktree_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import configparser
import os
import shutil
import stat
import subprocess
Expand Down Expand Up @@ -65,6 +66,10 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:

repository.git.worktree("add", "-d", str(worktree_dir))

# Write PID file so stale worktrees can be detected after SIGKILL
pid_file = worktree_dir / ".codeflash.pid"
pid_file.write_text(str(os.getpid()), encoding="utf-8")

# Get uncommitted diff from the original repo
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
Expand Down Expand Up @@ -119,6 +124,36 @@ def remove_worktree(worktree_dir: Path) -> None:
logger.exception(f"Failed to remove worktree: {worktree_dir}")


def is_process_alive(pid: int) -> bool:
    """Return True if a process with the given PID currently exists.

    The check is non-destructive: it never delivers a real signal to, or
    terminates, the target process.

    Args:
        pid: Process ID to probe, typically read back from a PID file.

    Returns:
        True if the process exists (including when it exists but belongs to
        another user), False if it does not exist or ``pid`` cannot name a
        single process.
    """
    # On POSIX, pid 0 / -1 / < -1 address process groups or "every process",
    # so os.kill() would succeed and a corrupt PID file would look alive.
    if pid <= 0:
        return False
    # On Windows, os.kill() with any signal other than the CTRL_* events calls
    # TerminateProcess, i.e. it would kill the process we only want to probe.
    if os.name == "nt":
        return _is_windows_process_alive(pid)
    try:
        os.kill(pid, 0)  # signal 0: existence/permission check only
    except ProcessLookupError:
        return False
    except PermissionError:
        return True  # process exists but we can't signal it
    except OverflowError:
        return False  # value does not fit in pid_t, so no such process
    return True


def _is_windows_process_alive(pid: int) -> bool:
    """Probe a PID on Windows through the Win32 API without terminating it."""
    import ctypes
    from ctypes import wintypes

    process_query_limited_information = 0x1000
    error_access_denied = 5
    still_active = 259  # exit code reported while a process is running

    if pid > 0xFFFFFFFF:
        return False  # Windows process IDs are 32-bit DWORDs

    kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
    # Declare signatures explicitly so 64-bit handles are not truncated.
    kernel32.OpenProcess.argtypes = (wintypes.DWORD, wintypes.BOOL, wintypes.DWORD)
    kernel32.OpenProcess.restype = wintypes.HANDLE
    kernel32.GetExitCodeProcess.argtypes = (wintypes.HANDLE, ctypes.POINTER(wintypes.DWORD))
    kernel32.GetExitCodeProcess.restype = wintypes.BOOL
    kernel32.CloseHandle.argtypes = (wintypes.HANDLE,)
    kernel32.CloseHandle.restype = wintypes.BOOL

    handle = kernel32.OpenProcess(process_query_limited_information, False, pid)
    if not handle:
        # Access denied means the process exists but is not ours to inspect.
        return ctypes.get_last_error() == error_access_denied
    try:
        exit_code = wintypes.DWORD()
        if not kernel32.GetExitCodeProcess(handle, ctypes.byref(exit_code)):
            return True  # cannot tell; err on the side of "still in use"
        return exit_code.value == still_active
    finally:
        kernel32.CloseHandle(handle)


def cleanup_stale_worktrees() -> None:
    """Delete worktree directories whose owning process is no longer running.

    Every sub-directory of ``worktree_dirs`` is inspected. A directory is kept
    only when its ``.codeflash.pid`` file holds a parseable PID for which
    ``is_process_alive`` returns True; otherwise it is logged and handed to
    ``remove_worktree``. Does nothing when ``worktree_dirs`` does not exist.
    """
    if not worktree_dirs.exists():
        return
    for candidate in worktree_dirs.iterdir():
        if not candidate.is_dir():
            continue
        marker = candidate / ".codeflash.pid"
        owner_pid = None
        if marker.exists():
            try:
                owner_pid = int(marker.read_text(encoding="utf-8").strip())
            except (ValueError, OSError):
                # Unreadable or corrupt marker: treat the owner as unknown.
                owner_pid = None
        # Missing marker, unknown owner, or dead owner all mean "stale".
        if owner_pid is None or not is_process_alive(owner_pid):
            logger.info(f"Removing stale worktree: {candidate}")
            remove_worktree(candidate)


def create_diff_patch_from_worktree(
worktree_dir: Path, files: list[Path], fto_name: Optional[str] = None
) -> Optional[Path]:
Expand Down
33 changes: 33 additions & 0 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
from codeflash.code_utils.git_worktree_utils import (
cleanup_stale_worktrees,
create_detached_worktree,
create_diff_patch_from_worktree,
create_worktree_snapshot_commit,
Expand Down Expand Up @@ -735,7 +736,33 @@ def mirror_path(path: Path, src_root: Path, dest_root: Path) -> Path:


def run_with_args(args: Namespace) -> None:
import atexit
import signal

cleanup_stale_worktrees()

optimizer = None
original_sigterm = signal.getsignal(signal.SIGTERM)
original_sighup = signal.getsignal(signal.SIGHUP)
original_sigquit = signal.getsignal(signal.SIGQUIT)
original_sigpipe = signal.getsignal(signal.SIGPIPE)

def cleanup_worktree_on_exit() -> None:
if optimizer and optimizer.current_worktree:
remove_worktree(optimizer.current_worktree)

def signal_handler(signum: int, frame: object) -> None:
logger.warning(f"Signal {signum} received. Cleaning up worktree and exiting…")
if optimizer:
optimizer.cleanup_temporary_paths()
raise SystemExit(128 + signum)

atexit.register(cleanup_worktree_on_exit)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGHUP, signal_handler)
signal.signal(signal.SIGQUIT, signal_handler)
signal.signal(signal.SIGPIPE, signal_handler)

try:
optimizer = Optimizer(args)
optimizer.run()
Expand All @@ -745,3 +772,9 @@ def run_with_args(args: Namespace) -> None:
optimizer.cleanup_temporary_paths()

raise SystemExit from None
finally:
atexit.unregister(cleanup_worktree_on_exit)
signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGHUP, original_sighup)
signal.signal(signal.SIGQUIT, original_sigquit)
signal.signal(signal.SIGPIPE, original_sigpipe)
Loading