"""Hook execution timing context manager (#1494).

Measures how long a hook body runs and hands the result to
SessionStats.record_hook_timing(), flushing immediately so the value is on
disk for the Stop hook summary's ⏱ timing report.
"""
import time
from contextlib import contextmanager
from typing import Optional


@contextmanager
def time_hook(
    hook_name: str,
    *,
    session_id: Optional[str] = None,
    data_dir: Optional[str] = None,
):
    """Time the wrapped hook body and persist the elapsed milliseconds.

    The duration is taken from the monotonic clock and is recorded whether
    the body finishes normally or raises. Exceptions raised by the body
    propagate unchanged. Any ``Exception`` raised while recording
    (including a failed import of the stats helpers) is swallowed, so
    timing can never make a hook fail.

    Args:
        hook_name: Claude Code event name (e.g. 'PostToolUse').
        session_id: Explicit session ID. When None, it is looked up with
            session_utils.get_session_id() at recording time.
        data_dir: Stats directory override (mainly for tests). Only passed
            to SessionStats when it is not None.

    Yields:
        None. The context manager exposes no value to the ``with`` body.

    Usage::

        with time_hook("PostToolUse"):
            # ... hook logic ...
    """
    began_at = time.monotonic()
    try:
        yield
    finally:
        # Recording lives in ``finally`` so a failing hook is still timed.
        try:
            elapsed_ms = (time.monotonic() - began_at) * 1000

            # Resolve lazily: the lookup is only needed when the caller
            # did not pin a session explicitly.
            resolved_session = session_id
            if resolved_session is None:
                from session_utils import get_session_id

                resolved_session = get_session_id()

            from stats import SessionStats

            # Leave data_dir out entirely unless overridden, so
            # SessionStats applies its own default location.
            overrides = {} if data_dir is None else {"data_dir": data_dir}
            recorder = SessionStats(session_id=resolved_session, **overrides)
            recorder.record_hook_timing(hook_name, elapsed_ms)
            recorder.flush()
        except Exception:
            pass  # Never block tool execution
- """ - if hook_name not in self._active: - raise ValueError(f"No active timer for hook: {hook_name}") - elapsed_ms = (time.monotonic() - self._active.pop(hook_name)) * 1000 - if hook_name not in self._timings: - self._timings[hook_name] = [] - self._timings[hook_name].append(elapsed_ms) - return elapsed_ms - - def get_stats(self) -> Dict[str, Dict[str, float]]: - """Compute statistics for all tracked hooks. - - Returns: - Dict mapping hook_name to {count, avg_ms, p95_ms, max_ms}. - """ - result: Dict[str, Dict[str, float]] = {} - for hook_name, timings in self._timings.items(): - sorted_t = sorted(timings) - count = len(sorted_t) - avg_ms = sum(sorted_t) / count - p95_idx = int(count * 0.95) - if p95_idx >= count: - p95_idx = count - 1 - result[hook_name] = { - "count": count, - "avg_ms": round(avg_ms, 2), - "p95_ms": round(sorted_t[p95_idx], 2), - "max_ms": round(sorted_t[-1], 2), - } - return result - - def get_warnings(self, timeout_ms: int = 10000) -> List[str]: - """Return warnings for hooks using >=80% of their timeout. - - Args: - timeout_ms: Hook timeout in milliseconds (default 10000). - - Returns: - List of warning strings for slow hooks. - """ - threshold = timeout_ms * 0.8 - warnings: List[str] = [] - for hook_name, timings in self._timings.items(): - max_time = max(timings) - if max_time >= threshold: - pct = max_time / timeout_ms * 100 - warnings.append( - f"Hook '{hook_name}' used {max_time:.0f}ms " - f"of {timeout_ms}ms timeout ({pct:.0f}%)" - ) - return warnings diff --git a/packages/claude-code-plugin/hooks/lib/stats.py b/packages/claude-code-plugin/hooks/lib/stats.py index 275552ec..b89b08d5 100644 --- a/packages/claude-code-plugin/hooks/lib/stats.py +++ b/packages/claude-code-plugin/hooks/lib/stats.py @@ -96,24 +96,38 @@ def record_tool_call(self, tool_name: str, success: bool = True) -> None: self.flush() def flush(self) -> None: - """Flush accumulated in-memory stats to disk.""" + """Flush accumulated in-memory stats to disk. 
def _locked_modify(self, mutator: Any) -> None:
    """Apply *mutator* to the on-disk stats as one read-modify-write (#1493).

    The stats file is opened (created if absent), locked with LOCK_EX,
    read, transformed and rewritten without releasing the lock in between.
    This prevents the lost-update race where concurrent processes each
    read the same baseline and overwrite one another.

    A file that is empty, undecodable, not valid JSON, or valid JSON that
    is not an object is treated as corrupt: the stats are re-seeded and the
    mutator is applied to the seed, still inside the lock window.

    If opening, locking or writing raises OSError, the method falls back to
    self._locked_write(). Data that was already read and merged is reused
    in that case, so existing totals are not thrown away.

    Args:
        mutator: Callable that receives the current stats dict. It may
            return the updated dict, or mutate its argument in place and
            return None.

    Raises:
        Exceptions raised by *mutator* propagate to the caller, except
        OSError, which is indistinguishable from an I/O failure here and
        takes the fallback path.

    Note: When HAS_FCNTL is False (non-Unix platforms), locking is
    skipped entirely. Concurrent flushes on such platforms may lose
    updates — this is a known limitation documented here for visibility.
    """

    def fresh_seed() -> Dict[str, Any]:
        # Built on demand so every caller gets its own nested dicts.
        return {
            "session_id": self.session_id,
            "started_at": time.time(),
            "tool_count": 0,
            "error_count": 0,
            "tool_names": {},
            "hook_timings": {},
        }

    def apply(current: Dict[str, Any]) -> Dict[str, Any]:
        # Accept both mutator styles. Without this, an in-place mutator
        # returning None would make us serialize ``null`` over the file.
        updated = mutator(current)
        return current if updated is None else updated

    merged = None
    try:
        # Pass the mode explicitly: os.open() defaults to 0o777, which
        # would create an executable file. 0o666 (before umask) matches
        # what the builtin open() produces.
        fd = os.open(self.stats_file, os.O_RDWR | os.O_CREAT, 0o666)
        with os.fdopen(fd, "r+", encoding="utf-8") as f:
            if HAS_FCNTL:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)

            # Recover from corruption while still holding the lock, so
            # another process cannot write between detection and repair.
            try:
                raw = f.read()
                data = json.loads(raw) if raw.strip() else None
            except (json.JSONDecodeError, UnicodeDecodeError):
                data = None
            if not isinstance(data, dict):
                data = fresh_seed()

            merged = apply(data)
            f.seek(0)
            f.truncate()
            json.dump(merged, f)
    except OSError:
        # Best-effort fallback. Reuse the merged data when the failure
        # happened after the read, so on-disk totals survive.
        if merged is None:
            merged = apply(fresh_seed())
        self._locked_write(merged)
""" + from hook_runtime import time_hook + with time_hook("PostToolUse"): + return _handle_post_tool_use(data) + + +def _handle_post_tool_use(data: dict): + """Core PostToolUse logic, wrapped by time_hook.""" try: from stats import SessionStats diff --git a/packages/claude-code-plugin/hooks/pre-tool-use.py b/packages/claude-code-plugin/hooks/pre-tool-use.py index 2630ab88..55bc5331 100644 --- a/packages/claude-code-plugin/hooks/pre-tool-use.py +++ b/packages/claude-code-plugin/hooks/pre-tool-use.py @@ -289,7 +289,9 @@ def _handle(data: dict) -> Optional[dict]: @safe_main def handle_pre_tool_use(data: dict) -> Optional[dict]: """Entry point for PreToolUse hook.""" - return _handle(data) + from hook_runtime import time_hook + with time_hook("PreToolUse"): + return _handle(data) if __name__ == "__main__": diff --git a/packages/claude-code-plugin/hooks/session-start.py b/packages/claude-code-plugin/hooks/session-start.py index 90715387..7aec4ed3 100644 --- a/packages/claude-code-plugin/hooks/session-start.py +++ b/packages/claude-code-plugin/hooks/session-start.py @@ -887,6 +887,14 @@ def _check_briefing_recovery() -> None: def main(): """Main entry point for the session start hook.""" + _ensure_lib_path() + from hook_runtime import time_hook + with time_hook("SessionStart"): + return _main_inner() + + +def _main_inner(): + """Core session-start logic, wrapped by time_hook.""" try: home = Path.home() hooks_dir = home / ".claude" / "hooks" diff --git a/packages/claude-code-plugin/hooks/stop.py b/packages/claude-code-plugin/hooks/stop.py index 9006b2fd..0e89e784 100644 --- a/packages/claude-code-plugin/hooks/stop.py +++ b/packages/claude-code-plugin/hooks/stop.py @@ -23,6 +23,13 @@ def handle_stop(data: dict): Finalizes session stats and returns a systemMessage summary. 
""" + from hook_runtime import time_hook + with time_hook("Stop"): + return _handle_stop(data) + + +def _handle_stop(data: dict): + """Core Stop logic, wrapped by time_hook.""" try: from stats import SessionStats diff --git a/packages/claude-code-plugin/hooks/user-prompt-submit.py b/packages/claude-code-plugin/hooks/user-prompt-submit.py index df42ec46..096181ef 100644 --- a/packages/claude-code-plugin/hooks/user-prompt-submit.py +++ b/packages/claude-code-plugin/hooks/user-prompt-submit.py @@ -51,6 +51,18 @@ def detect_mode(prompt: str) -> Optional[str]: def main(): """Main entry point for the hook.""" + # Ensure hooks/lib is importable for time_hook + _hooks_dir = os.path.dirname(os.path.abspath(__file__)) + _lib_dir = os.path.join(_hooks_dir, "lib") + if _lib_dir not in sys.path: + sys.path.insert(0, _lib_dir) + from hook_runtime import time_hook + with time_hook("UserPromptSubmit"): + return _main_inner() + + +def _main_inner(): + """Core UserPromptSubmit logic, wrapped by time_hook.""" try: # Read input from stdin input_data = json.load(sys.stdin) diff --git a/packages/claude-code-plugin/tests/test_hook_timer.py b/packages/claude-code-plugin/tests/test_hook_timer.py deleted file mode 100644 index a0440889..00000000 --- a/packages/claude-code-plugin/tests/test_hook_timer.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Tests for HookTimer — adaptive hook timeout tracking (#945).""" -import os -import sys -import time - -import pytest - -# Ensure hooks/lib is on path -_tests_dir = os.path.dirname(os.path.abspath(__file__)) -_lib_dir = os.path.join(os.path.dirname(_tests_dir), "hooks", "lib") -if _lib_dir not in sys.path: - sys.path.insert(0, _lib_dir) - -from hook_timer import HookTimer - - -@pytest.fixture -def timer(): - return HookTimer() - - -class TestStartStop: - """Test basic start/stop timing functionality.""" - - def test_start_stop_returns_elapsed_ms(self, timer): - """stop() should return elapsed time in milliseconds.""" - timer.start("my_hook") - 
time.sleep(0.05) - elapsed = timer.stop("my_hook") - assert elapsed >= 40 # at least ~40ms (sleep 50ms with tolerance) - assert elapsed < 200 # not unreasonably long - - def test_stop_without_start_raises(self, timer): - """stop() without matching start() should raise ValueError.""" - with pytest.raises(ValueError, match="No active timer"): - timer.stop("nonexistent_hook") - - def test_multiple_hooks_tracked_independently(self, timer): - """Different hooks should track independently.""" - timer.start("hook_a") - time.sleep(0.02) - timer.start("hook_b") - time.sleep(0.02) - elapsed_a = timer.stop("hook_a") - elapsed_b = timer.stop("hook_b") - # hook_a ran longer than hook_b - assert elapsed_a > elapsed_b - - def test_same_hook_multiple_times(self, timer): - """Same hook can be started/stopped multiple times, recording each.""" - timer.start("hook_a") - timer.stop("hook_a") - timer.start("hook_a") - timer.stop("hook_a") - stats = timer.get_stats() - assert stats["hook_a"]["count"] == 2 - - -class TestGetStats: - """Test statistics calculation.""" - - def test_empty_stats(self, timer): - """get_stats() on fresh timer returns empty dict.""" - assert timer.get_stats() == {} - - def test_stats_count_and_avg(self, timer): - """get_stats() should compute count and avg_ms correctly.""" - # Inject known timings for deterministic tests - timer._timings["hook_a"] = [100.0, 200.0, 300.0] - stats = timer.get_stats() - assert stats["hook_a"]["count"] == 3 - assert stats["hook_a"]["avg_ms"] == pytest.approx(200.0, abs=0.01) - - def test_stats_p95(self, timer): - """get_stats() should compute p95_ms correctly.""" - # 20 values: 10, 20, ..., 200 - timer._timings["hook_a"] = [float(i * 10) for i in range(1, 21)] - stats = timer.get_stats() - # p95 of 20 items: index 19 (0.95*20=19) -> value 200 - assert stats["hook_a"]["p95_ms"] == pytest.approx(200.0, abs=0.01) - - def test_stats_max(self, timer): - """get_stats() should compute max_ms correctly.""" - timer._timings["hook_a"] = 
class TestConcurrentFlush:
    """Regression test for the lost-update race in flush() (#1493).

    Multiple processes calling record_tool_call() + flush() against the
    same session/data_dir must not lose updates.
    """

    # One deadline shared by all workers. It is generous enough for slow
    # CI but finite, so a deadlocked file lock fails the test instead of
    # hanging the whole suite.
    _JOIN_TIMEOUT_S = 60.0

    @staticmethod
    def _worker(data_dir: str, session_id: str, n: int) -> None:
        """Record *n* "Bash" tool calls, flushing after every call."""
        import sys as _sys

        # The child sets up its own import path rather than relying on
        # sys.path edits made in the parent process.
        _lib = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "hooks",
            "lib",
        )
        if _lib not in _sys.path:
            _sys.path.insert(0, _lib)
        from stats import SessionStats as _SS

        s = _SS(session_id=session_id, data_dir=data_dir, flush_interval=10)
        for _ in range(n):
            s.record_tool_call("Bash")
            s.flush()

    def test_concurrent_flush_no_lost_updates(self, data_dir):
        """8 processes x 100 calls = 800 total. Final disk count must be 800."""
        import multiprocessing as mp

        session_id = "race-test"
        num_workers = 8
        calls_per_worker = 100
        expected = num_workers * calls_per_worker

        # Seed the stats file
        SessionStats(session_id=session_id, data_dir=data_dir)

        procs = [
            mp.Process(
                target=self._worker,
                args=(data_dir, session_id, calls_per_worker),
            )
            for _ in range(num_workers)
        ]
        for p in procs:
            p.start()

        deadline = time.monotonic() + self._JOIN_TIMEOUT_S
        for p in procs:
            p.join(timeout=max(0.0, deadline - time.monotonic()))

        # Reap stragglers before asserting so no child outlives the test.
        hung = [p for p in procs if p.is_alive()]
        for p in hung:
            p.terminate()
            p.join()
        assert not hung, (
            f"{len(hung)} worker(s) still running after "
            f"{self._JOIN_TIMEOUT_S}s — possible lock deadlock"
        )

        # A crashed worker also yields a short count. Report it as a crash
        # rather than as a misleading "lost updates" failure.
        exit_codes = [p.exitcode for p in procs]
        assert exit_codes == [0] * num_workers, (
            f"worker exit codes {exit_codes} — a worker crashed"
        )

        s = SessionStats(session_id=session_id, data_dir=data_dir)
        on_disk = s._locked_read()
        assert on_disk["tool_count"] == expected, (
            f"Expected {expected}, got {on_disk['tool_count']} — lost updates detected"
        )
        assert on_disk["tool_names"]["Bash"] == expected


class TestHookTimingWiring:
    """Regression test: hook scripts record timing via time_hook context manager (#1494)."""

    @staticmethod
    def _timings_on_disk(session_id, data_dir):
        """Return the persisted hook_timings mapping for *session_id*."""
        reader = SessionStats(session_id=session_id, data_dir=data_dir)
        return reader._locked_read()["hook_timings"]

    def test_time_hook_records_timing(self, data_dir):
        """time_hook should record elapsed_ms to disk hook_timings."""
        from hook_runtime import time_hook

        session_id = "timing-test"
        SessionStats(session_id=session_id, data_dir=data_dir)

        with time_hook("PostToolUse", session_id=session_id, data_dir=data_dir):
            time.sleep(0.01)

        timings = self._timings_on_disk(session_id, data_dir)
        assert "PostToolUse" in timings
        assert len(timings["PostToolUse"]) >= 1
        assert timings["PostToolUse"][0] >= 5  # at least ~5ms

    def test_time_hook_records_even_on_inner_exception(self, data_dir):
        """Timing is recorded even when inner code raises."""
        from hook_runtime import time_hook

        session_id = "exc-test"
        SessionStats(session_id=session_id, data_dir=data_dir)

        with pytest.raises(ValueError):
            with time_hook("PostToolUse", session_id=session_id, data_dir=data_dir):
                raise ValueError("intentional error inside hook")

        timings = self._timings_on_disk(session_id, data_dir)
        assert "PostToolUse" in timings
        assert len(timings["PostToolUse"]) >= 1

    def test_time_hook_records_multiple(self, data_dir):
        """Multiple time_hook invocations accumulate timings."""
        from hook_runtime import time_hook

        session_id = "multi-timing"
        SessionStats(session_id=session_id, data_dir=data_dir)

        for _ in range(3):
            with time_hook("PreToolUse", session_id=session_id, data_dir=data_dir):
                pass

        timings = self._timings_on_disk(session_id, data_dir)
        assert len(timings["PreToolUse"]) == 3