diff --git a/.env.template b/.env.template index b5047c9..d738a91 100644 --- a/.env.template +++ b/.env.template @@ -27,3 +27,21 @@ GEMINI_API_KEY= # Optional Uvicorn bind settings used by start.sh / make run-* HOST=0.0.0.0 PORT=5000 + +# --------------------------------------------------------------------------- +# Continuous graph updates (webhook / poll-watcher) +# --------------------------------------------------------------------------- + +# Shared secret used for GitHub HMAC verification or GitLab's +# X-Gitlab-Token verification. Leave empty to require +# Authorization: Bearer on /api/webhook instead. +WEBHOOK_SECRET= + +# Name of the branch to track for automatic incremental updates. +# Only push events targeting this branch trigger a graph update. +TRACKED_BRANCH=main + +# Seconds between automatic poll-watcher checks (0 = disable poll-watcher). +# The poll-watcher runs as a background task and checks every tracked +# repository for new commits on TRACKED_BRANCH. +POLL_INTERVAL=60 diff --git a/README.md b/README.md index 533b32e..48dbcaa 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,9 @@ cp .env.template .env | `MODEL_NAME` | LiteLLM model used by `/api/chat` | No | `gemini/gemini-flash-lite-latest` | | `HOST` | Optional Uvicorn bind host for `start.sh`/`make run-*` | No | `0.0.0.0` or `127.0.0.1` depending on command | | `PORT` | Optional Uvicorn bind port for `start.sh`/`make run-*` | No | `5000` | +| `WEBHOOK_SECRET` | Shared secret for GitHub HMAC or GitLab `X-Gitlab-Token` verification on `/api/webhook` | No | empty | +| `TRACKED_BRANCH` | Branch watched by the webhook and poll-watcher | No | `main` | +| `POLL_INTERVAL` | Seconds between background poll checks (`0` disables polling) | No | `60` | The chat endpoint also needs the provider credential expected by your chosen `MODEL_NAME`. The default model is Gemini, so set `GEMINI_API_KEY` unless you switch to a different LiteLLM provider/model. @@ -97,6 +100,8 @@ The chat endpoint also needs the provider credential expected by your chosen `MO - If `SECRET_TOKEN` is unset, the current implementation accepts requests without an `Authorization` header. - Setting `CODE_GRAPH_PUBLIC=1` makes the read-only endpoints public even when `SECRET_TOKEN` is configured. +Continuous graph updates can be triggered either by posting a GitHub/GitLab push payload to `/api/webhook` or by enabling the background poll-watcher with `POLL_INTERVAL > 0`. When `WEBHOOK_SECRET` is unset, `/api/webhook` falls back to the same bearer-token auth used by the other mutating endpoints. + ### 3. Install dependencies ```bash @@ -241,6 +246,7 @@ A C analyzer exists in the source tree, but it is commented out and is not curre | POST | `/api/analyze_folder` | Analyze a local source folder | | POST | `/api/analyze_repo` | Clone and analyze a git repository | | POST | `/api/switch_commit` | Switch the indexed repository to a specific commit | +| POST | `/api/webhook` | Receive a GitHub/GitLab push event and apply an incremental graph update | ## License diff --git a/api/analyzers/analyzer.py b/api/analyzers/analyzer.py index 64d4900..a02ff15 100644 --- a/api/analyzers/analyzer.py +++ b/api/analyzers/analyzer.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from pathlib import Path from typing import Optional @@ -7,6 +8,14 @@ from abc import ABC, abstractmethod from multilspy import SyncLanguageServer +from ..graph import Graph + + +@dataclass(frozen=True) +class ResolvedEntityRef: + id: int + + class AbstractAnalyzer(ABC): def __init__(self, language: Language) -> None: self.language = language @@ -56,8 +65,69 @@ def resolve(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: P try: locations = lsp.request_definition(str(file_path), node.start_point.row, node.start_point.column) return [(files[Path(self.resolve_path(location['absolutePath'], path))], files[Path(self.resolve_path(location['absolutePath'], path))].tree.root_node.descendant_for_point_range(Point(location['range']['start']['line'], location['range']['start']['character']), Point(location['range']['end']['line'], location['range']['end']['character']))) for location in locations if location and Path(self.resolve_path(location['absolutePath'], path)) in files] - except Exception as e: + except Exception: return [] + + def resolve_entities( + self, + files: dict[Path, File], + lsp: SyncLanguageServer, + file_path: Path, + path: Path, + node: Node, + graph: Graph, + parent_types: list[str], + graph_labels: list[str], + reject_parent_types: Optional[set[str]] = None, + ) -> list[Entity | ResolvedEntityRef]: + try: + locations = lsp.request_definition( + str(file_path), node.start_point.row, node.start_point.column + ) + except Exception: + return [] + + resolved_entities: list[Entity | ResolvedEntityRef] = [] + for location in locations: + if not location or 'absolutePath' not in location: + continue + + resolved_path = Path(self.resolve_path(location['absolutePath'], path)) + if resolved_path in files: + file = files[resolved_path] + resolved_node = file.tree.root_node.descendant_for_point_range( + Point( + location['range']['start']['line'], + location['range']['start']['character'], + ), + Point( + location['range']['end']['line'], + location['range']['end']['character'], + ), + ) + entity_node = self.find_parent(resolved_node, parent_types) + if entity_node is None: + continue + if reject_parent_types and entity_node.type in reject_parent_types: + continue + + entity = file.entities.get(entity_node) + if entity is not None: + resolved_entities.append(entity) + continue + + if graph is None: + continue + + graph_entity = graph.get_entity_at_position( + str(resolved_path), + location['range']['start']['line'], + graph_labels, + ) + if graph_entity is not None: + resolved_entities.append(ResolvedEntityRef(graph_entity.id)) + + return resolved_entities @abstractmethod def add_dependencies(self, path: Path, files: list[Path]): @@ -133,7 +203,7 @@ def add_symbols(self, entity: Entity) -> None: pass @abstractmethod - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph: Graph, key: str, symbol: Node) -> list[Entity | ResolvedEntityRef]: """ Resolve a symbol to an entity. @@ -148,4 +218,3 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_ """ pass - diff --git a/api/analyzers/csharp/analyzer.py b/api/analyzers/csharp/analyzer.py index 74c3906..61e8c8a 100644 --- a/api/analyzers/csharp/analyzer.py +++ b/api/analyzers/csharp/analyzer.py @@ -105,34 +105,41 @@ def is_dependency(self, file_path: str) -> bool: def resolve_path(self, file_path: str, path: Path) -> str: return file_path - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration'], + ['Class', 'Interface', 'Enum', 'Struct'], + ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: if node.type == 'invocation_expression': func_node = node.child_by_field_name('function') if func_node and func_node.type == 'member_access_expression': func_node = func_node.child_by_field_name('name') if func_node: node = func_node - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - method_dec = self.find_parent(resolved_node, ['method_declaration', 'constructor_declaration', 'class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration']) - if method_dec and method_dec.type in ['class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration']: - continue - if method_dec in file.entities: - res.append(file.entities[method_dec]) - return res + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['method_declaration', 'constructor_declaration', 'class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration'], + ['Method', 'Constructor'], + {'class_declaration', 'interface_declaration', 'enum_declaration', 'struct_declaration'}, + ) - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: if key in ["implement_interface", "base_class", "extend_interface", "parameters", "return_type"]: - return self.resolve_type(files, lsp, file_path, path, symbol) + return self.resolve_type(files, lsp, file_path, path, graph, symbol) elif key in ["call"]: - return self.resolve_method(files, lsp, file_path, path, symbol) + return self.resolve_method(files, lsp, file_path, path, graph, symbol) else: raise ValueError(f"Unknown key {key}") diff --git a/api/analyzers/java/analyzer.py b/api/analyzers/java/analyzer.py index 5269d69..1ce80f8 100644 --- a/api/analyzers/java/analyzer.py +++ b/api/analyzers/java/analyzer.py @@ -1,7 +1,8 @@ import os from pathlib import Path import subprocess -from ...entities import * +from ...entities.entity import Entity +from ...entities.file import File from typing import Optional from ..analyzer import AbstractAnalyzer @@ -102,28 +103,35 @@ def resolve_path(self, file_path: str, path: Path) -> str: return f"{path}/temp_deps/{args[1]}/{targs}/{args[-1]}" return file_path - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_declaration', 'interface_declaration', 'enum_declaration']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['class_declaration', 'interface_declaration', 'enum_declaration'], + ['Class', 'Interface', 'Enum'], + ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] - for file, resolved_node in self.resolve(files, lsp, file_path, path, node.child_by_field_name('name')): - method_dec = self.find_parent(resolved_node, ['method_declaration', 'constructor_declaration', 'class_declaration', 'interface_declaration', 'enum_declaration']) - if method_dec and method_dec.type in ['class_declaration', 'interface_declaration', 'enum_declaration']: - continue - if method_dec in file.entities: - res.append(file.entities[method_dec]) - return res + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: + return self.resolve_entities( + files, + lsp, + file_path, + path, + node.child_by_field_name('name'), + graph, + ['method_declaration', 'constructor_declaration', 'class_declaration', 'interface_declaration', 'enum_declaration'], + ['Method', 'Constructor'], + {'class_declaration', 'interface_declaration', 'enum_declaration'}, + ) - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: if key in ["implement_interface", "base_class", "extend_interface", "parameters", "return_type"]: - return self.resolve_type(files, lsp, file_path, path, symbol) + return self.resolve_type(files, lsp, file_path, path, graph, symbol) elif key in ["call"]: - return self.resolve_method(files, lsp, file_path, path, symbol) + return self.resolve_method(files, lsp, file_path, path, graph, symbol) else: raise ValueError(f"Unknown key {key}") diff --git a/api/analyzers/python/analyzer.py b/api/analyzers/python/analyzer.py index 7a99120..a63d0b4 100644 --- a/api/analyzers/python/analyzer.py +++ b/api/analyzers/python/analyzer.py @@ -4,7 +4,8 @@ from pathlib import Path import tomllib -from ...entities import * +from ...entities.entity import Entity +from ...entities.file import File from typing import Optional from ..analyzer import AbstractAnalyzer @@ -91,34 +92,40 @@ def is_dependency(self, file_path: str) -> bool: def resolve_path(self, file_path: str, path: Path) -> str: return file_path - def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path, node: Node) -> list[Entity]: - res = [] + def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path, graph, node: Node) -> list[Entity]: if node.type == 'attribute': node = node.child_by_field_name('attribute') - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - type_dec = self.find_parent(resolved_node, ['class_definition']) - if type_dec in file.entities: - res.append(file.entities[type_dec]) - return res + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['class_definition'], + ['Class'], + ) - def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]: - res = [] + def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, node: Node) -> list[Entity]: if node.type == 'call': node = node.child_by_field_name('function') if node.type == 'attribute': node = node.child_by_field_name('attribute') - for file, resolved_node in self.resolve(files, lsp, file_path, path, node): - method_dec = self.find_parent(resolved_node, ['function_definition', 'class_definition']) - if not method_dec: - continue - if method_dec in file.entities: - res.append(file.entities[method_dec]) - return res + return self.resolve_entities( + files, + lsp, + file_path, + path, + node, + graph, + ['function_definition', 'class_definition'], + ['Function', 'Class'], + ) - def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]: + def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, graph, key: str, symbol: Node) -> list[Entity]: if key in ["base_class", "parameters", "return_type"]: - return self.resolve_type(files, lsp, file_path, path, symbol) + return self.resolve_type(files, lsp, file_path, path, graph, symbol) elif key in ["call"]: - return self.resolve_method(files, lsp, file_path, path, symbol) + return self.resolve_method(files, lsp, file_path, path, graph, symbol) else: raise ValueError(f"Unknown key {key}") diff --git a/api/analyzers/source_analyzer.py b/api/analyzers/source_analyzer.py index 4186f35..73e2cc1 100644 --- a/api/analyzers/source_analyzer.py +++ b/api/analyzers/source_analyzer.py @@ -149,7 +149,7 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: file = self.files[file_path] logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}') for _, entity in file.entities.items(): - entity.resolved_symbol(lambda key, symbol, fp=file_path: analyzers[fp.suffix].resolve_symbol(self.files, lsps[fp.suffix], fp, path, key, symbol)) + entity.resolved_symbol(lambda key, symbol, fp=file_path: analyzers[fp.suffix].resolve_symbol(self.files, lsps[fp.suffix], fp, path, graph, key, symbol)) for key, symbols in entity.symbols.items(): for symbol in symbols: if len(symbol.resolved_symbol) == 0: @@ -220,4 +220,3 @@ def analyze_local_repository(self, path: str, ignore: Optional[list[str]] = None graph.set_graph_commit(current_commit.short_id) return graph - diff --git a/api/git_utils/__init__.py b/api/git_utils/__init__.py index 4fd3af9..ab12dce 100644 --- a/api/git_utils/__init__.py +++ b/api/git_utils/__init__.py @@ -1 +1,28 @@ -from .git_utils import * +from . import git_utils as git_utils +from .git_utils import ( + GitRepoName as GitRepoName, + build_commit_graph as build_commit_graph, + classify_changes as classify_changes, + is_ignored as is_ignored, + switch_commit as switch_commit, +) +from .git_graph import GitGraph as GitGraph +from .incremental_update import ( + fetch_remote as fetch_remote, + get_remote_head as get_remote_head, + incremental_update as incremental_update, + repo_local_path as repo_local_path, +) + +__all__ = [ + "GitRepoName", + "GitGraph", + "build_commit_graph", + "classify_changes", + "fetch_remote", + "get_remote_head", + "incremental_update", + "is_ignored", + "repo_local_path", + "switch_commit", +] diff --git a/api/git_utils/incremental_update.py b/api/git_utils/incremental_update.py new file mode 100644 index 0000000..e651403 --- /dev/null +++ b/api/git_utils/incremental_update.py @@ -0,0 +1,303 @@ +"""Incremental graph update engine. + +Given a before/after commit SHA pair, computes the file-level diff, +applies additions/deletions/modifications to the FalkorDB code graph, +and bookmarks the new commit SHA in Redis so the system can resume +correctly after restarts or failures. +""" + +from contextlib import contextmanager +import logging +import os +import subprocess +from pathlib import Path +from typing import Optional + +from pygit2.enums import CheckoutStrategy +from pygit2.repository import Repository + +from ..analyzers.source_analyzer import SourceAnalyzer +from ..graph import Graph +from ..info import get_redis_connection, set_repo_commit +from .git_utils import classify_changes + +logger = logging.getLogger(__name__) +REPO_UPDATE_LOCK_TIMEOUT = int(os.getenv("REPO_UPDATE_LOCK_TIMEOUT", "300")) +REPO_UPDATE_LOCK_WAIT = int(os.getenv("REPO_UPDATE_LOCK_WAIT", "30")) + + +def repo_local_path(repo_name: str) -> Path: + """Return the local filesystem path for a cloned repository. + + Respects the ``REPOSITORIES_DIR`` environment variable; falls back to + ``/repositories/`` which matches the convention used by + :func:`api.project._clone_source`. + """ + base = os.getenv("REPOSITORIES_DIR", str(Path.cwd() / "repositories")) + return Path(base) / repo_name + + +def fetch_remote(repo_path: Path) -> None: + """Fetch latest changes from the remote *origin*. + + Args: + repo_path: Absolute path to the local git clone. + + Raises: + subprocess.CalledProcessError: If the git fetch command fails. + """ + logger.info("Fetching remote changes for %s", repo_path) + subprocess.run( + ["git", "fetch", "origin"], + cwd=str(repo_path), + check=True, + capture_output=True, + text=True, + ) + + +def get_remote_head(repo_path: Path, branch: str) -> Optional[str]: + """Return the full SHA of the remote tracking branch HEAD. + + Args: + repo_path: Absolute path to the local git clone. + branch: Branch name (e.g. ``"main"``). + + Returns: + The 40-character commit SHA, or ``None`` if the branch does not exist + on the remote or the command fails. + """ + try: + result = subprocess.run( + ["git", "rev-parse", f"origin/{branch}"], + cwd=str(repo_path), + capture_output=True, + text=True, + check=True, + ) + return result.stdout.strip() or None + except subprocess.CalledProcessError: + logger.warning("Could not resolve origin/%s in %s", branch, repo_path) + return None + + +@contextmanager +def repo_update_lock(repo_name: str): + """Acquire a repo-scoped distributed lock for graph mutations.""" + redis_connection = get_redis_connection() + lock = redis_connection.lock( + f"code-graph:repo-update:{repo_name}", + timeout=REPO_UPDATE_LOCK_TIMEOUT, + blocking_timeout=REPO_UPDATE_LOCK_WAIT, + thread_local=False, + ) + + logger.debug("Acquiring repo update lock for '%s'", repo_name) + if not lock.acquire(blocking=True): + raise TimeoutError(f"Timed out waiting for update lock for '{repo_name}'") + + try: + yield + finally: + if lock.owned(): + lock.release() + logger.debug("Released repo update lock for '%s'", repo_name) + + +def _resolve_commit(repo: Repository, sha: str): + return repo.revparse_single(sha) + + +def _is_ancestor(repo: Repository, ancestor_sha: str, descendant_sha: str) -> bool: + ancestor = _resolve_commit(repo, ancestor_sha) + descendant = _resolve_commit(repo, descendant_sha) + return ancestor.id == descendant.id or repo.merge_base(ancestor.id, descendant.id) == ancestor.id + + +def can_incrementally_update( + repo_path: Path, + from_sha: str, + to_sha: str, + before_sha: Optional[str] = None, +) -> bool: + """Return True when the stored bookmark can be safely advanced incrementally.""" + try: + repo = Repository(str(repo_path)) + if before_sha is not None and not _is_ancestor(repo, from_sha, before_sha): + return False + + anchor_sha = before_sha or from_sha + return _is_ancestor(repo, anchor_sha, to_sha) + except Exception as exc: + logger.warning( + "Cannot validate incremental update range for '%s' -> '%s' (before=%s): %s", + from_sha, + to_sha, + before_sha, + exc, + ) + return False + + +def _dedupe_paths(paths: list[Path]) -> list[Path]: + seen: set[Path] = set() + deduped: list[Path] = [] + for path in paths: + if path in seen: + continue + seen.add(path) + deduped.append(path) + return deduped + + +def _collect_transitive_dependents(g: Graph, changed_files: list[Path]) -> list[Path]: + seen = set(changed_files) + dependents: list[Path] = [] + frontier = _dedupe_paths(changed_files) + + while frontier: + direct_dependents = g.get_direct_dependent_files(frontier) + next_frontier: list[Path] = [] + for dependent in direct_dependents: + if dependent in seen: + continue + seen.add(dependent) + dependents.append(dependent) + next_frontier.append(dependent) + frontier = next_frontier + + return dependents + + +def incremental_update( + repo_name: str, + from_sha: str, + to_sha: str, + ignore: Optional[list[str]] = None, +) -> dict: + """Incrementally update the code graph from ``from_sha`` to ``to_sha``. + + Deleted files are removed from the graph. Modified files are removed + and then re-analysed. Added files are analysed and inserted. The + commit bookmark stored in Redis is updated to the short ID of ``to_sha`` + on success, matching the convention used by the rest of the system. + + This function is idempotent: if ``from_sha == to_sha`` it returns + immediately without touching the graph or the bookmark. + + Args: + repo_name: Graph name in FalkorDB (and repository directory name). + from_sha: Commit SHA the graph is currently at (old state). + Accepts both abbreviated and full 40-char SHAs. + to_sha: Target commit SHA to advance the graph to (new state). + Accepts both abbreviated and full 40-char SHAs. + ignore: Optional list of path prefixes to skip during analysis. + + Returns: + A :class:`dict` with keys: + + * ``files_added`` – number of newly added source files processed. + * ``files_modified`` – number of modified source files re-processed. + * ``files_deleted`` – number of deleted source files removed. + * ``commit`` – the short SHA bookmark now stored in Redis. + + Raises: + ValueError: If the local repository clone cannot be found, or if + either SHA cannot be resolved. + """ + if ignore is None: + ignore = [] + + if from_sha == to_sha: + logger.info( + "incremental_update: from_sha == to_sha (%s); nothing to do", from_sha + ) + return { + "files_added": 0, + "files_modified": 0, + "files_deleted": 0, + "commit": to_sha, + } + + repo_path = repo_local_path(repo_name) + if not repo_path.exists(): + raise ValueError(f"Local repository not found at {repo_path}") + + logger.info( + "Incremental update for '%s': %s -> %s", repo_name, from_sha, to_sha + ) + + repo = Repository(str(repo_path)) + + # Resolve commits – accepts both abbreviated and full SHAs + try: + from_commit = repo.revparse_single(from_sha) + except Exception as exc: + raise ValueError(f"Cannot resolve from_sha '{from_sha}': {exc}") from exc + try: + to_commit = repo.revparse_single(to_sha) + except Exception as exc: + raise ValueError(f"Cannot resolve to_sha '{to_sha}': {exc}") from exc + + # Compute the file-level diff between the two commits + analyzer = SourceAnalyzer() + supported_types = analyzer.supported_types() + diff = repo.diff(from_commit, to_commit) + added, deleted, modified = classify_changes(diff, repo, supported_types, ignore) + + logger.info( + "Diff for '%s': %d added, %d modified, %d deleted", + repo_name, + len(added), + len(modified), + len(deleted), + ) + + files_to_remove = _dedupe_paths(deleted + modified) + + with repo_update_lock(repo_name): + try: + # Checkout target commit so files on disk reflect to_sha + repo.checkout_tree(to_commit.tree, strategy=CheckoutStrategy.FORCE) + repo.set_head_detached(to_commit.id) + + # Apply graph changes + g = Graph(repo_name) + dependent_files = _collect_transitive_dependents(g, files_to_remove) + + if dependent_files: + logger.info( + "Reprocessing %d dependent file(s) for '%s'", + len(dependent_files), + repo_name, + ) + + if files_to_remove: + logger.info("Removing %d file(s) from graph", len(files_to_remove)) + g.delete_files(files_to_remove) + + deleted_files = set(deleted) + files_to_add = [ + file_path + for file_path in _dedupe_paths(added + modified + dependent_files) + if file_path not in deleted_files + ] + if files_to_add: + logger.info("Inserting/updating %d file(s) in graph", len(files_to_add)) + analyzer.analyze_files(files_to_add, repo_path, g) + + # Persist the new commit bookmark using the short ID for consistency + # with the rest of the system (build_commit_graph, analyze_sources …) + new_commit_short = to_commit.short_id + set_repo_commit(repo_name, new_commit_short) + logger.info("Graph for '%s' updated to commit %s", repo_name, new_commit_short) + except Exception: + logger.exception("Incremental update failed for '%s'", repo_name) + raise + + return { + "files_added": len(added), + "files_modified": len(modified), + "files_deleted": len(deleted), + "commit": new_commit_short, + } diff --git a/api/graph.py b/api/graph.py index 085dfde..83f959d 100644 --- a/api/graph.py +++ b/api/graph.py @@ -1,6 +1,6 @@ import os import time -from .entities import * +from .entities import File, encode_edge, encode_node from typing import Optional from falkordb import FalkorDB, Path, Node, QueryResult from falkordb.asyncio import FalkorDB as AsyncFalkorDB @@ -32,6 +32,20 @@ def get_repos() -> list[str]: graphs = [g for g in graphs if not (g.endswith('_git') or g.endswith('_schema'))] return graphs + +def delete_graph_if_exists(name: str) -> bool: + """Delete *name* when it already exists in FalkorDB.""" + db = FalkorDB(host=os.getenv('FALKORDB_HOST', 'localhost'), + port=os.getenv('FALKORDB_PORT', 6379), + username=os.getenv('FALKORDB_USERNAME', None), + password=os.getenv('FALKORDB_PASSWORD', None)) + + if name not in db.list_graphs(): + return False + + db.select_graph(name).delete() + return True + class Graph(): """ Represents a connection to a graph database using FalkorDB. @@ -171,7 +185,7 @@ def _query(self, q: str, params: Optional[dict] = None) -> QueryResult: return result_set - def get_sub_graph(self, l: int) -> dict: + def get_sub_graph(self, limit: int) -> dict: q = """MATCH (src) OPTIONAL MATCH (src)-[e]->(dest) @@ -180,7 +194,7 @@ def get_sub_graph(self, l: int) -> dict: sub_graph = {'nodes': [], 'edges': [] } - result_set = self._query(q, {'limit': l}).result_set + result_set = self._query(q, {'limit': limit}).result_set for row in result_set: src = row[0] e = row[1] @@ -466,6 +480,44 @@ def get_file(self, path: str, name: str, ext: str) -> Optional[File]: return file + def get_entity_at_position(self, path: str, line: int, labels: Optional[list[str]] = None) -> Optional[Node]: + """Return the smallest entity spanning *line* within *path*.""" + label_filter = ":" + ":".join(labels) if labels else "" + q = f"""MATCH (e{label_filter}) + WHERE e.path = $path + AND e.src_start <= $line + AND e.src_end >= $line + RETURN e + ORDER BY (e.src_end - e.src_start) ASC + LIMIT 1""" + + res = self._query(q, {'path': path, 'line': line}).result_set + if len(res) == 0: + return None + + return res[0][0] + + def get_direct_dependent_files(self, files: list[Path]) -> list[Path]: + """Return files that directly depend on entities defined in *files*.""" + if len(files) == 0: + return [] + + q = """UNWIND $files AS file + MATCH (changed_file:File {path: file['path'], name: file['name'], ext: file['ext']}) + MATCH (changed_file)-[:DEFINES*]->(changed_entity) + MATCH (dependent_entity)-[:CALLS|EXTENDS|IMPLEMENTS|RETURNS|PARAMETERS]->(changed_entity) + MATCH (dependent_file:File)-[:DEFINES*]->(dependent_entity) + RETURN DISTINCT dependent_file.path, dependent_file.name, dependent_file.ext""" + + params = { + 'files': [ + {'path': str(file_path), 'name': file_path.name, 'ext': file_path.suffix} + for file_path in files + ] + } + result_set = self._query(q, params).result_set + return [Path(row[0]) for row in result_set] + # set file code coverage # if file coverage is 100% set every defined function coverage to 100% aswell def set_file_coverage(self, path: str, name: str, ext: str, coverage: float) -> None: @@ -478,7 +530,7 @@ def set_file_coverage(self, path: str, name: str, ext: str, coverage: float) -> params = {'path': path, 'name': name, 'ext': ext, 'coverage': coverage} - res = self._query(q, params) + self._query(q, params) def connect_entities(self, relation: str, src_id: int, dest_id: int, properties: dict = {}) -> None: """ @@ -768,4 +820,3 @@ async def stats(self) -> dict: async def close(self) -> None: await self.db.aclose() - diff --git a/api/index.py b/api/index.py index 38dfb61..c566f6d 100644 --- a/api/index.py +++ b/api/index.py @@ -1,19 +1,30 @@ """ Main API module for CodeGraph. """ +import hashlib +import hmac import os import asyncio +import contextlib import logging from pathlib import Path from dotenv import load_dotenv -from fastapi import Depends, FastAPI, Header, HTTPException, Query +from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request from fastapi.responses import FileResponse, JSONResponse from pydantic import BaseModel from api.analyzers.source_analyzer import SourceAnalyzer from api.git_utils import git_utils from api.git_utils.git_graph import AsyncGitGraph -from api.graph import Graph, AsyncGraphQuery, async_get_repos -from api.info import async_get_repo_info +from api.git_utils.incremental_update import ( + can_incrementally_update, + fetch_remote, + get_remote_head, + incremental_update, + repo_local_path, + repo_update_lock, +) +from api.graph import Graph, AsyncGraphQuery, async_get_repos, delete_graph_if_exists +from api.info import async_get_repo_info, get_repo_commit from api.llm import ask from api.project import Project @@ -98,7 +109,304 @@ class SwitchCommitRequest(BaseModel): str(Path(__file__).resolve().parent.parent)) ).resolve() -app = FastAPI() +# --------------------------------------------------------------------------- +# Webhook / poll-watcher configuration +# --------------------------------------------------------------------------- + +# HMAC-SHA256 secret shared with GitHub/GitLab. Leave unset to skip +# signature validation (not recommended for production). +WEBHOOK_SECRET: str = os.getenv("WEBHOOK_SECRET", "") + +# Branch whose pushes trigger incremental graph updates. +TRACKED_BRANCH: str = os.getenv("TRACKED_BRANCH", "main") + +# Seconds between automatic poll checks (0 = disabled). +POLL_INTERVAL: int = int(os.getenv("POLL_INTERVAL", "60")) + +# --------------------------------------------------------------------------- +# Webhook helpers +# --------------------------------------------------------------------------- + +def _urls_match(stored_url: str, incoming_url: str) -> bool: + """Return True when two repository URLs refer to the same repo. + + Normalises both URLs by stripping a trailing ``.git`` suffix and + converting to lower-case so that, for example, + ``https://github.com/Org/Repo`` and + ``https://github.com/org/repo.git`` are treated as identical. + """ + def _normalise(u: str) -> str: + return u.rstrip("/").removesuffix(".git").lower() + + return _normalise(stored_url) == _normalise(incoming_url) + + +async def _find_repo_by_url(url: str) -> str | None: + """Return the graph name for a repository that matches *url*, or ``None``.""" + repos = await async_get_repos() + for repo_name in repos: + info = await async_get_repo_info(repo_name) + if info and _urls_match(info.get("repo_url", ""), url): + return repo_name + return None + + +def _webhook_auth_mode() -> str: + if WEBHOOK_SECRET: + return "shared-secret" + if SECRET_TOKEN: + return "token" + return "disabled" + + +def _log_webhook_auth_mode() -> None: + mode = _webhook_auth_mode() + if mode == "shared-secret": + logger.info( + "Webhook auth mode: shared secret (GitHub HMAC or GitLab X-Gitlab-Token)" + ) + elif mode == "token": + logger.info("Webhook auth mode: Authorization bearer token fallback") + else: + logger.warning( + "Webhook auth is not configured; /api/webhook will reject requests until " + "WEBHOOK_SECRET or SECRET_TOKEN is set" + ) + + +def _authenticate_webhook_request(request: Request, body: bytes) -> None: + """Authenticate a webhook request using the configured webhook auth mode.""" + if WEBHOOK_SECRET: + github_signature = request.headers.get("X-Hub-Signature-256") + gitlab_token = request.headers.get("X-Gitlab-Token") + gitlab_event = request.headers.get("X-Gitlab-Event") + gitlab_signature = request.headers.get("X-Gitlab-Signature") + + if github_signature: + mac = hmac.new(WEBHOOK_SECRET.encode(), body, hashlib.sha256) + expected_signature = "sha256=" + mac.hexdigest() + if not hmac.compare_digest(github_signature, expected_signature): + raise HTTPException(status_code=401, detail="Invalid GitHub webhook signature") + return + + if gitlab_token or gitlab_event or gitlab_signature: + if not gitlab_token: + raise HTTPException( + status_code=401, + detail="GitLab webhooks must include X-Gitlab-Token", + ) + if not hmac.compare_digest(gitlab_token, WEBHOOK_SECRET): + raise HTTPException(status_code=401, detail="Invalid GitLab webhook token") + return + + raise HTTPException( + status_code=401, + detail="Missing supported webhook authentication header", + ) + + if not SECRET_TOKEN: + logger.error( + "Webhook auth misconfigured: set WEBHOOK_SECRET or SECRET_TOKEN before " + "accepting webhook updates" + ) + raise HTTPException( + status_code=503, + detail="Webhook authentication is not configured", + ) + + token_required(request.headers.get("Authorization")) + + +def _extract_repo_url(payload: dict) -> str: + repository = payload.get("repository", {}) + project = payload.get("project", {}) + return ( + repository.get("clone_url") + or repository.get("git_http_url") + or project.get("git_http_url") + or "" + ) + + +def _full_reindex_repository( + repo_name: str, + repo_path: Path, + repo_url: str = "", + ignore: list[str] | None = None, + reason: str = "", +) -> dict: + if ignore is None: + ignore = [] + + logger.warning( + "Falling back to a full reindex for '%s'%s", + repo_name, + f": {reason}" if reason else "", + ) + + with repo_update_lock(repo_name): + delete_graph_if_exists(repo_name) + delete_graph_if_exists(git_utils.GitRepoName(repo_name)) + + if repo_path.exists(): + proj = Project.from_local_repository(repo_path) + elif repo_url: + proj = Project.from_git_repository(repo_url) + else: + raise ValueError( + f"Cannot reindex '{repo_name}': local clone is missing and no repo URL is available" + ) + + proj.analyze_sources(ignore) + proj.process_git_history(ignore) + + return { + "mode": "full_reindex", + "files_added": 0, + "files_modified": 0, + "files_deleted": 0, + "commit": get_repo_commit(repo_name), + } + + +def _sync_repo_graph( + repo_name: str, + repo_path: Path, + target_sha: str, + *, + before_sha: str | None = None, + repo_url: str = "", + ignore: list[str] | None = None, +) -> dict: + if ignore is None: + ignore = [] + + if not repo_path.exists(): + return _full_reindex_repository( + repo_name, + repo_path, + repo_url, + ignore, + "local clone missing", + ) + + stored_sha = get_repo_commit(repo_name) + if not stored_sha: + return _full_reindex_repository( + repo_name, + repo_path, + repo_url, + ignore, + "missing stored commit bookmark", + ) + + if not can_incrementally_update(repo_path, stored_sha, target_sha, before_sha): + return _full_reindex_repository( + repo_name, + repo_path, + repo_url, + ignore, + ( + f"stored bookmark {stored_sha} does not align with " + f"before={before_sha or ''} and target={target_sha}" + ), + ) + + return incremental_update(repo_name, stored_sha, target_sha, ignore) + +# --------------------------------------------------------------------------- +# Background poll-watcher helpers (synchronous, run in thread-pool executor) +# --------------------------------------------------------------------------- + +def _poll_repo(repo_name: str) -> None: + """Fetch remote and apply incremental updates for *repo_name* if behind. + + This function is intentionally synchronous so it can be safely offloaded + to ``asyncio``'s default ``ThreadPoolExecutor``. + """ + path = repo_local_path(repo_name) + if not path.exists(): + logger.debug("Poll: local clone not found for '%s', skipping", repo_name) + return + + try: + fetch_remote(path) + except Exception as exc: + logger.warning("Poll: git fetch failed for '%s': %s", repo_name, exc) + return + + remote_head = get_remote_head(path, TRACKED_BRANCH) + if not remote_head: + return + + current_sha = get_repo_commit(repo_name) + if current_sha: + # Handle comparison between short (7-char) and full (40-char) SHAs: a short + # stored SHA is a valid prefix of a full remote SHA for the same commit. + # We only apply prefix matching when the stored SHA is shorter. + if len(current_sha) < len(remote_head): + up_to_date = remote_head.startswith(current_sha) + elif len(current_sha) > len(remote_head): + up_to_date = current_sha.startswith(remote_head) + else: + up_to_date = current_sha == remote_head + if up_to_date: + logger.debug("Poll: '%s' is up-to-date at %s", repo_name, current_sha) + return + else: + logger.warning("Poll: '%s' has no stored bookmark; forcing a full reindex", repo_name) + + logger.info( + "Poll: new commits detected for '%s' (%s -> %s), updating …", + repo_name, current_sha, remote_head, + ) + try: + result = _sync_repo_graph(repo_name, path, remote_head) + logger.info("Poll: '%s' updated — %s", repo_name, result) + except Exception as exc: + logger.exception( + "Poll: incremental update failed for '%s': %s", repo_name, exc + ) + + +async def _poll_all_repos() -> None: + """Check every indexed repository for new commits on the tracked branch.""" + repos = await async_get_repos() + loop = asyncio.get_running_loop() + for repo_name in repos: + await loop.run_in_executor(None, _poll_repo, repo_name) + + +async def _poll_loop() -> None: + """Continuously poll all repositories at the configured interval.""" + logger.info( + "Poll-watcher started (interval=%ds, branch='%s')", + POLL_INTERVAL, TRACKED_BRANCH, + ) + while True: + try: + await _poll_all_repos() + except Exception as exc: + logger.exception("Poll loop error: %s", exc) + await asyncio.sleep(POLL_INTERVAL) + +# --------------------------------------------------------------------------- +# Application lifespan (starts/stops the background poll task) +# --------------------------------------------------------------------------- + +@contextlib.asynccontextmanager +async def _lifespan(application: FastAPI): + _log_webhook_auth_mode() + poll_task = None + if POLL_INTERVAL > 0: + poll_task = asyncio.create_task(_poll_loop()) + yield + if poll_task is not None: + poll_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await poll_task + +app = FastAPI(lifespan=_lifespan) # --------------------------------------------------------------------------- # API routes @@ -290,6 +598,88 @@ async def list_commits(data: RepoRequest, _=Depends(public_or_auth)): await git_graph.close() return {"status": "success", "commits": commits} + +@app.post('/api/webhook') +async def webhook(request: Request): + """Receive a GitHub/GitLab push event and trigger an incremental graph update. + + When ``WEBHOOK_SECRET`` is set the endpoint validates GitHub's + ``X-Hub-Signature-256`` HMAC signature or GitLab's ``X-Gitlab-Token``. + Without ``WEBHOOK_SECRET`` the endpoint falls back to the standard bearer + token auth used by the other mutating routes. + + Only pushes to the branch configured via ``TRACKED_BRANCH`` (default + ``main``) trigger an update; pushes to other branches are acknowledged + with a ``200 ignored`` response so that GitHub does not retry them. + + The repository is identified by matching the payload's repository URL + (`repository.clone_url`, `repository.git_http_url`, or `project.git_http_url`) + against the URLs stored for already-indexed repositories. + """ + body = await request.body() + _authenticate_webhook_request(request, body) + + try: + payload = await request.json() + except Exception: + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + ref = payload.get("ref", "") + before = payload.get("before", "") + after = payload.get("after", "") + repo_url = _extract_repo_url(payload) + + # Only process pushes to the configured tracked branch + expected_ref = f"refs/heads/{TRACKED_BRANCH}" + if ref != expected_ref: + logger.debug("Webhook: ignoring push to '%s' (tracking '%s')", ref, expected_ref) + return {"status": "ignored", "reason": f"Branch not tracked: {ref}"} + + if not before or not after or not repo_url: + raise HTTPException( + status_code=400, + detail=( + "Payload missing required fields: ref, before, after, and a repository URL " + "(repository.clone_url, repository.git_http_url, or project.git_http_url)" + ), + ) + + # Resolve the repository name from the stored index + repo_name = await _find_repo_by_url(repo_url) + if repo_name is None: + logger.warning("Webhook: received push for unknown repo '%s'", repo_url) + return JSONResponse( + {"status": "error", "detail": "Repository not indexed"}, + status_code=404, + ) + + logger.info( + "Webhook: updating '%s' from %s to %s", repo_name, before[:8], after[:8] + ) + + def _update() -> dict: + path = repo_local_path(repo_name) + if path.exists(): + fetch_remote(path) + return _sync_repo_graph( + repo_name, + path, + after, + before_sha=before, + repo_url=repo_url, + ) + + loop = asyncio.get_running_loop() + try: + result = await loop.run_in_executor(None, _update) + except Exception as exc: + logger.exception( + "Webhook: incremental update failed for '%s': %s", repo_name, exc + ) + return JSONResponse({"status": "error", "detail": str(exc)}, status_code=500) + + return {"status": "success", **result} + # --------------------------------------------------------------------------- # SPA static file serving (must come after API routes) # --------------------------------------------------------------------------- diff --git a/tests/test_incremental_update.py b/tests/test_incremental_update.py new file mode 100644 index 0000000..c0ffe0d --- /dev/null +++ b/tests/test_incremental_update.py @@ -0,0 +1,151 @@ +from contextlib import contextmanager +import importlib + +from api.analyzers.python.analyzer import PythonAnalyzer + + +class _DummyLSP: + def __init__(self, locations): + self._locations = locations + + def request_definition(self, *_args, **_kwargs): + return self._locations + + +class _DummyGraphNode: + def __init__(self, node_id: int): + self.id = node_id + + +class _DummyGraphLookup: + def __init__(self, node_id: int): + self.node_id = node_id + self.calls = [] + + def get_entity_at_position(self, path, line, labels): + self.calls.append((path, line, labels)) + return _DummyGraphNode(self.node_id) + + +def test_python_resolve_symbol_uses_graph_fallback(tmp_path): + """Cross-file resolution falls back to graph lookups for unchanged files.""" + analyzer = PythonAnalyzer() + caller = tmp_path / "caller.py" + target = tmp_path / "target.py" + caller.write_text("foo()\n") + target.write_text("def foo():\n pass\n") + + tree = analyzer.parser.parse(caller.read_bytes()) + call_node = analyzer._captures("(call) @call", tree.root_node)["call"][0] + graph = _DummyGraphLookup(42) + lsp = _DummyLSP( + [ + { + "absolutePath": str(target), + "range": { + "start": {"line": 0, "character": 0}, + "end": {"line": 1, "character": 0}, + }, + } + ] + ) + + resolved = analyzer.resolve_symbol({}, lsp, caller, tmp_path, graph, "call", call_node) + + assert [entity.id for entity in resolved] == [42] + assert graph.calls == [(str(target), 0, ["Function", "Class"])] + + +def test_incremental_update_reprocesses_dependents_under_repo_lock(monkeypatch, tmp_path): + """Incremental updates expand transitive dependents and hold the repo lock.""" + incremental_update_module = importlib.import_module("api.git_utils.incremental_update") + repo_path = tmp_path / "repo" + repo_path.mkdir() + operations = [] + + class _FakeCommit: + def __init__(self, sha): + self.id = sha + self.short_id = sha[:7] + self.tree = object() + + class _FakeRepo: + def revparse_single(self, sha): + return _FakeCommit(sha) + + def diff(self, _from_commit, _to_commit): + return object() + + def checkout_tree(self, _tree, strategy=None): + operations.append(("checkout", strategy)) + + def set_head_detached(self, commit_id): + operations.append(("detach", commit_id)) + + class _FakeAnalyzer: + def supported_types(self): + return [".py"] + + def analyze_files(self, files, path, graph): + operations.append(("analyze", [file.name for file in files], path, graph)) + + class _FakeGraph: + def __init__(self, name): + self.name = name + + def get_direct_dependent_files(self, files): + names = tuple(file.name for file in files) + operations.append(("dependents", names)) + if names == ("deleted.py", "modified.py"): + return [repo_path / "caller.py"] + if names == ("caller.py",): + return [repo_path / "transitive.py"] + return [] + + def delete_files(self, files): + operations.append(("delete", [file.name for file in files])) + + @contextmanager + def _fake_repo_lock(repo_name): + operations.append(("lock-enter", repo_name)) + try: + yield + finally: + operations.append(("lock-exit", repo_name)) + + monkeypatch.setattr(incremental_update_module, "repo_local_path", lambda _name: repo_path) + monkeypatch.setattr(incremental_update_module, "Repository", lambda _path: _FakeRepo()) + monkeypatch.setattr(incremental_update_module, "SourceAnalyzer", _FakeAnalyzer) + monkeypatch.setattr(incremental_update_module, "Graph", _FakeGraph) + monkeypatch.setattr( + incremental_update_module, + "classify_changes", + lambda _diff, _repo, _supported, _ignore: ( + [repo_path / "added.py"], + [repo_path / "deleted.py"], + [repo_path / "modified.py"], + ), + ) + monkeypatch.setattr( + incremental_update_module, + "set_repo_commit", + lambda repo_name, commit: operations.append(("bookmark", repo_name, commit)), + ) + monkeypatch.setattr(incremental_update_module, "repo_update_lock", _fake_repo_lock) + + result = incremental_update_module.incremental_update("repo", "aaaa111", "bbbb222") + + assert result == { + "files_added": 1, + "files_modified": 1, + "files_deleted": 1, + "commit": "bbbb222", + } + assert ("delete", ["deleted.py", "modified.py"]) in operations + analyze_call = next(op for op in operations if op[0] == "analyze") + assert analyze_call[1] == ["added.py", "modified.py", "caller.py", "transitive.py"] + assert analyze_call[2] == repo_path + assert operations[0] == ("lock-enter", "repo") + assert operations[-1] == ("lock-exit", "repo") + assert operations.index(("lock-enter", "repo")) < operations.index(("checkout", incremental_update_module.CheckoutStrategy.FORCE)) + assert operations.index(("bookmark", "repo", "bbbb222")) < operations.index(("lock-exit", "repo")) diff --git a/tests/test_webhook.py b/tests/test_webhook.py new file mode 100644 index 0000000..7ab0092 --- /dev/null +++ b/tests/test_webhook.py @@ -0,0 +1,449 @@ +"""Unit tests for the webhook endpoint and incremental update helpers. + +These tests use ``monkeypatch`` to mock out external collaborators (FalkorDB, +Redis, git) so they run without a live database or network connection. +""" + +import hashlib +import hmac +import importlib +import json + +import pytest +from starlette.testclient import TestClient + +import api.index + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# A full Git SHA-1 hash is 40 hexadecimal characters. +_FULL_SHA_BEFORE = "aaaa1111" * 5 # 40-char SHA simulating the "before" commit +_FULL_SHA_AFTER = "bbbb2222" * 5 # 40-char SHA simulating the "after" commit + + +class _FakePath: + """Minimal Path-like object for use with monkeypatch.""" + + def __init__(self, *, exists: bool): + self._exists = exists + + def exists(self) -> bool: + return self._exists + + +def _make_push_payload( + ref: str = "refs/heads/main", + before: str = _FULL_SHA_BEFORE, + after: str = _FULL_SHA_AFTER, + clone_url: str = "https://github.com/example/myrepo.git", +) -> dict: + return { + "ref": ref, + "before": before, + "after": after, + "repository": {"clone_url": clone_url}, + } + + +def _sign(body: bytes, secret: str) -> str: + mac = hmac.new(secret.encode(), body, hashlib.sha256) + return "sha256=" + mac.hexdigest() + + +# --------------------------------------------------------------------------- +# _urls_match +# --------------------------------------------------------------------------- + +def test_urls_match_identical(): + assert api.index._urls_match( + "https://github.com/org/repo.git", + "https://github.com/org/repo.git", + ) + + +def test_urls_match_git_suffix(): + assert api.index._urls_match( + "https://github.com/org/repo", + "https://github.com/org/repo.git", + ) + + +def test_urls_match_case_insensitive(): + assert api.index._urls_match( + "https://github.com/Org/Repo.git", + "https://github.com/org/repo.git", + ) + + +def test_urls_match_trailing_slash(): + assert api.index._urls_match( + "https://github.com/org/repo/", + "https://github.com/org/repo.git", + ) + + +def test_urls_no_match_different_repo(): + assert not api.index._urls_match( + "https://github.com/org/repo-a.git", + "https://github.com/org/repo-b.git", + ) + + +# --------------------------------------------------------------------------- +# Webhook endpoint – bearer token fallback mode +# --------------------------------------------------------------------------- + +@pytest.fixture() +def client_token_auth(monkeypatch): + """Test client with bearer-token webhook auth and no poll-watcher.""" + monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "") + monkeypatch.setattr(api.index, "SECRET_TOKEN", "apitoken") + monkeypatch.setattr(api.index, "POLL_INTERVAL", 0) + return TestClient(api.index.app, raise_server_exceptions=False) + + +@pytest.fixture() +def client_misconfigured(monkeypatch): + """Test client with webhook auth disabled entirely.""" + monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "") + monkeypatch.setattr(api.index, "SECRET_TOKEN", None) + monkeypatch.setattr(api.index, "POLL_INTERVAL", 0) + return TestClient(api.index.app, raise_server_exceptions=False) + + +def test_webhook_ignored_wrong_branch(client_token_auth, monkeypatch): + """Pushes to non-tracked branches return 200 with status='ignored'.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload(ref="refs/heads/feature/x") + resp = client_token_auth.post( + "/api/webhook", + json=payload, + headers={"Authorization": "Bearer apitoken"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ignored" + + +def test_webhook_unknown_repo(client_token_auth, monkeypatch): + """Webhook for a repo URL that is not indexed returns 404.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + + # No repos indexed → _find_repo_by_url returns None + async def _fake_get_repos(): + return [] + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + + payload = _make_push_payload() + resp = client_token_auth.post( + "/api/webhook", + json=payload, + headers={"Authorization": "Bearer apitoken"}, + ) + assert resp.status_code == 404 + + +def test_webhook_success(client_token_auth, monkeypatch): + """Valid push to tracked branch triggers incremental_update and returns stats.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + + async def _fake_get_repos(): + return ["myrepo"] + + async def _fake_get_repo_info(repo_name): + return {"repo_url": "https://github.com/example/myrepo.git"} + + update_calls = [] + + def _fake_sync(repo_name, path, to_sha, before_sha=None, repo_url="", ignore=None): + update_calls.append((repo_name, before_sha, to_sha, repo_url)) + return { + "files_added": 1, + "files_modified": 0, + "files_deleted": 0, + "commit": to_sha[:7], + } + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) + monkeypatch.setattr(api.index, "_sync_repo_graph", _fake_sync) + # Skip git fetch (no real clone) + monkeypatch.setattr(api.index, "fetch_remote", lambda path: None) + monkeypatch.setattr(api.index, "repo_local_path", lambda name: _FakePath(exists=False)) + + payload = _make_push_payload() + resp = client_token_auth.post( + "/api/webhook", + json=payload, + headers={"Authorization": "Bearer apitoken"}, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "success" + assert data["files_added"] == 1 + assert len(update_calls) == 1 + assert update_calls[0] == ( + "myrepo", + _FULL_SHA_BEFORE, + _FULL_SHA_AFTER, + "https://github.com/example/myrepo.git", + ) + + +def test_webhook_requires_bearer_token_when_secret_missing(client_token_auth, monkeypatch): + """Bearer token auth protects the webhook when WEBHOOK_SECRET is unset.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + resp = client_token_auth.post("/api/webhook", json=payload) + assert resp.status_code == 401 + + +def test_webhook_rejected_when_no_auth_is_configured(client_misconfigured, monkeypatch): + """The webhook returns 503 when neither WEBHOOK_SECRET nor SECRET_TOKEN is set.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + resp = client_misconfigured.post("/api/webhook", json=payload) + assert resp.status_code == 503 + + +# --------------------------------------------------------------------------- +# Webhook endpoint – HMAC-SHA256 signature validation +# --------------------------------------------------------------------------- + +@pytest.fixture() +def client_secured(monkeypatch): + """Test client with WEBHOOK_SECRET='mysecret' and poll disabled.""" + monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "mysecret") + monkeypatch.setattr(api.index, "POLL_INTERVAL", 0) + return TestClient(api.index.app, raise_server_exceptions=False) + + +def test_webhook_missing_signature_rejected(client_secured, monkeypatch): + """Requests without X-Hub-Signature-256 header are rejected with 401.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + resp = client_secured.post("/api/webhook", json=payload) + assert resp.status_code == 401 + + +def test_webhook_wrong_signature_rejected(client_secured, monkeypatch): + """Requests with an incorrect signature are rejected with 401.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + body = json.dumps(payload).encode() + bad_sig = _sign(body, "wrongsecret") + resp = client_secured.post( + "/api/webhook", + content=body, + headers={"Content-Type": "application/json", "X-Hub-Signature-256": bad_sig}, + ) + assert resp.status_code == 401 + + +def test_webhook_valid_signature_accepted(client_secured, monkeypatch): + """Requests with a correct HMAC-SHA256 signature are accepted.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + monkeypatch.setattr(api.index, "WEBHOOK_SECRET", "mysecret") + + async def _fake_get_repos(): + return ["myrepo"] + + async def _fake_get_repo_info(repo_name): + return {"repo_url": "https://github.com/example/myrepo.git"} + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) + monkeypatch.setattr(api.index, "_sync_repo_graph", lambda *a, **kw: { + "files_added": 0, "files_modified": 0, "files_deleted": 0, "commit": "abc1234", + }) + monkeypatch.setattr(api.index, "fetch_remote", lambda path: None) + monkeypatch.setattr(api.index, "repo_local_path", lambda name: _FakePath(exists=False)) + + payload = _make_push_payload() + body = json.dumps(payload).encode() + sig = _sign(body, "mysecret") + + resp = client_secured.post( + "/api/webhook", + content=body, + headers={"Content-Type": "application/json", "X-Hub-Signature-256": sig}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + +def test_gitlab_webhook_token_accepted(client_secured, monkeypatch): + """GitLab webhooks authenticate via X-Gitlab-Token and git_http_url payloads.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + + async def _fake_get_repos(): + return ["myrepo"] + + async def _fake_get_repo_info(repo_name): + return {"repo_url": "https://gitlab.com/example/myrepo.git"} + + monkeypatch.setattr(api.index, "async_get_repos", _fake_get_repos) + monkeypatch.setattr(api.index, "async_get_repo_info", _fake_get_repo_info) + monkeypatch.setattr(api.index, "_sync_repo_graph", lambda *a, **kw: { + "files_added": 0, "files_modified": 0, "files_deleted": 0, "commit": "abc1234", + }) + monkeypatch.setattr(api.index, "fetch_remote", lambda path: None) + monkeypatch.setattr(api.index, "repo_local_path", lambda name: _FakePath(exists=False)) + + payload = { + "ref": "refs/heads/main", + "before": _FULL_SHA_BEFORE, + "after": _FULL_SHA_AFTER, + "repository": {"git_http_url": "https://gitlab.com/example/myrepo.git"}, + } + resp = client_secured.post( + "/api/webhook", + json=payload, + headers={"X-Gitlab-Token": "mysecret", "X-Gitlab-Event": "Push Hook"}, + ) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + +def test_gitlab_webhook_missing_token_rejected(client_secured, monkeypatch): + """GitLab requests without X-Gitlab-Token are rejected.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + payload = _make_push_payload() + resp = client_secured.post( + "/api/webhook", + json=payload, + headers={"X-Gitlab-Event": "Push Hook"}, + ) + assert resp.status_code == 401 + + +def test_webhook_invalid_json(client_token_auth, monkeypatch): + """Non-JSON bodies are rejected with 400.""" + monkeypatch.setattr(api.index, "TRACKED_BRANCH", "main") + resp = client_token_auth.post( + "/api/webhook", + content=b"not-json", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer apitoken", + }, + ) + assert resp.status_code == 400 + + +def test_sync_repo_graph_uses_stored_bookmark(monkeypatch, tmp_path): + """Incremental sync uses the stored bookmark instead of payload.before.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + calls = [] + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "stored123") + monkeypatch.setattr(api.index, "can_incrementally_update", lambda *args, **kwargs: True) + monkeypatch.setattr( + api.index, + "incremental_update", + lambda repo_name, from_sha, to_sha, ignore=None: calls.append( + (repo_name, from_sha, to_sha, ignore) + ) or { + "files_added": 0, + "files_modified": 0, + "files_deleted": 0, + "commit": to_sha[:7], + }, + ) + + api.index._sync_repo_graph( + "myrepo", + repo_path, + _FULL_SHA_AFTER, + before_sha=_FULL_SHA_BEFORE, + ) + + assert calls == [("myrepo", "stored123", _FULL_SHA_AFTER, [])] + + +def test_sync_repo_graph_full_reindexes_without_bookmark(monkeypatch, tmp_path): + """Missing bookmarks fall back to a full reindex instead of partial diffing.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: None) + monkeypatch.setattr( + api.index, + "_full_reindex_repository", + lambda *args, **kwargs: {"mode": "full_reindex", "commit": "abc1234"}, + ) + + result = api.index._sync_repo_graph("myrepo", repo_path, _FULL_SHA_AFTER) + + assert result["mode"] == "full_reindex" + + +def test_sync_repo_graph_full_reindexes_on_history_gap(monkeypatch, tmp_path): + """History gaps or force-pushes fall back to a full reindex.""" + repo_path = tmp_path / "repo" + repo_path.mkdir() + + monkeypatch.setattr(api.index, "get_repo_commit", lambda name: "stored123") + monkeypatch.setattr(api.index, "can_incrementally_update", lambda *args, **kwargs: False) + monkeypatch.setattr( + api.index, + "_full_reindex_repository", + lambda *args, **kwargs: {"mode": "full_reindex", "commit": "abc1234"}, + ) + + result = api.index._sync_repo_graph( + "myrepo", + repo_path, + _FULL_SHA_AFTER, + before_sha=_FULL_SHA_BEFORE, + ) + + assert result["mode"] == "full_reindex" + + +# --------------------------------------------------------------------------- +# incremental_update – unit tests (no live DB/git) +# --------------------------------------------------------------------------- + +def test_incremental_update_idempotent(monkeypatch, tmp_path): + """Calling incremental_update with the same SHA twice is a no-op.""" + incremental_update_module = importlib.import_module("api.git_utils.incremental_update") + _iu = incremental_update_module.incremental_update + + # Patch set_repo_commit to detect unexpected writes + writes = [] + monkeypatch.setattr( + incremental_update_module, + "set_repo_commit", + lambda *a: writes.append(a), + ) + + sha = "abc1234" + result = _iu("some-repo", sha, sha) + + assert result["files_added"] == 0 + assert result["files_modified"] == 0 + assert result["files_deleted"] == 0 + assert result["commit"] == sha + assert writes == [], "set_repo_commit must not be called for no-op update" + + +def test_incremental_update_missing_repo(monkeypatch, tmp_path): + """incremental_update raises ValueError when local clone does not exist.""" + incremental_update_module = importlib.import_module("api.git_utils.incremental_update") + _iu = incremental_update_module.incremental_update + + monkeypatch.setattr( + incremental_update_module, + "repo_local_path", + lambda name: tmp_path / "nonexistent", + ) + + with pytest.raises(ValueError, match="Local repository not found"): + _iu("some-repo", "aaa1111", "bbb2222")