From 45657034075dce177b03c63ba926731b1a85145c Mon Sep 17 00:00:00 2001 From: bussyjd Date: Sun, 29 Mar 2026 19:59:50 +0400 Subject: [PATCH] =?UTF-8?q?feat(inference):=20WIP=20inference=20lifecycle?= =?UTF-8?q?=20prototype=20=E2=80=94=20discover,=20validate,=20register,=20?= =?UTF-8?q?serve?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements issue #300 prototype with Python modules: - registry.py: SQLite-backed model registry with state machine - hardware.py: GPU/system auto-detection, optimal ngl calculation - discover.py: x/LocalLLaMA signal sourcing via x-cli - validate.py: llama-bench + ToolCall-15 evaluation pipeline - serve.py: systemd service management with hot-swap + rollback - api.py: FastAPI routes for all lifecycle endpoints - tests: unit tests for registry, hardware, discovery, validation 1,942 lines total. All new files under internal/inference/lifecycle/. No existing Go files modified. Refs: #300 --- internal/inference/lifecycle/__init__.py | 0 internal/inference/lifecycle/api.py | 272 ++++++++++++++ internal/inference/lifecycle/discover.py | 218 +++++++++++ internal/inference/lifecycle/hardware.py | 188 ++++++++++ internal/inference/lifecycle/registry.py | 253 +++++++++++++ internal/inference/lifecycle/requirements.txt | 2 + internal/inference/lifecycle/serve.py | 316 ++++++++++++++++ internal/inference/lifecycle/validate.py | 347 +++++++++++++++++ tests/test_inference_lifecycle.py | 348 ++++++++++++++++++ 9 files changed, 1944 insertions(+) create mode 100644 internal/inference/lifecycle/__init__.py create mode 100644 internal/inference/lifecycle/api.py create mode 100644 internal/inference/lifecycle/discover.py create mode 100644 internal/inference/lifecycle/hardware.py create mode 100644 internal/inference/lifecycle/registry.py create mode 100644 internal/inference/lifecycle/requirements.txt create mode 100644 internal/inference/lifecycle/serve.py create mode 100644 
internal/inference/lifecycle/validate.py create mode 100644 tests/test_inference_lifecycle.py diff --git a/internal/inference/lifecycle/__init__.py b/internal/inference/lifecycle/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/internal/inference/lifecycle/api.py b/internal/inference/lifecycle/api.py new file mode 100644 index 00000000..06aa63fc --- /dev/null +++ b/internal/inference/lifecycle/api.py @@ -0,0 +1,272 @@ +"""FastAPI routes for inference lifecycle management (Issue #300).""" + +import logging +import uuid +from dataclasses import asdict +from typing import Dict, List, Optional + +from fastapi import FastAPI, HTTPException, BackgroundTasks +from pydantic import BaseModel + +from .registry import Registry, ModelStatus +from .hardware import get_hardware_profile +from .discover import search_localllama, parse_candidates, filter_candidates +from .validate import validate_model, ValidationResult +from .serve import hot_swap, start_serving, stop_serving, get_serving_status + +logger = logging.getLogger(__name__) + +app = FastAPI( + title="Obol Inference Lifecycle API", + description="Model discovery, validation, and serving management", + version="0.1.0", +) + +# Module-level registry instance (initialized on startup) +_registry: Optional[Registry] = None +_validation_jobs: Dict[str, dict] = {} + + +def get_registry() -> Registry: + """Get or create the global registry instance.""" + global _registry + if _registry is None: + _registry = Registry() + return _registry + + +# --- Request/Response Models --- + +class DiscoverRequest(BaseModel): + max_results: int = 20 + max_size_gb: Optional[float] = None + + +class ValidateRequest(BaseModel): + url: str + eval_suite: str = "toolcall-15" + quant: Optional[str] = None + name: Optional[str] = None + + +class ServeRequest(BaseModel): + model_id: str + rollback_on_failure: bool = True + port: int = 8080 + + +class ModelResponse(BaseModel): + id: str + name: str + status: str + quant: 
Optional[str] = None + size_gb: Optional[float] = None + toolcall15_score: Optional[float] = None + tok_s_gen: Optional[float] = None + tok_s_prompt: Optional[float] = None + source_url: Optional[str] = None + signal_score: Optional[float] = None + + +# --- Lifecycle Events --- + +@app.on_event("startup") +async def startup(): + """Initialize registry on startup.""" + global _registry + _registry = Registry() + logger.info("Inference lifecycle API started") + + +@app.on_event("shutdown") +async def shutdown(): + """Clean up on shutdown.""" + if _registry: + _registry.close() + + +# --- API Endpoints --- + +@app.post("/api/v1/inference/discover") +async def discover_models(request: DiscoverRequest) -> Dict: + """Trigger model discovery via social signal scraping. + + Searches x/LocalLLaMA for new model announcements and GGUF releases. + Returns ranked candidates by signal score. + """ + try: + raw_json = search_localllama(max_results=request.max_results) + candidates = parse_candidates(raw_json) + + if request.max_size_gb: + candidates = filter_candidates(candidates, request.max_size_gb) + + # Register discovered candidates + registry = get_registry() + registered = [] + for c in candidates: + record = registry.add_model( + name=c.name, source_url=c.hf_url, + quant=c.quant, size_gb=c.size_gb_est, + signal_score=c.signal_score, + ) + registered.append({"id": record.id, "name": c.name, + "hf_url": c.hf_url, "quant": c.quant, + "size_gb_est": c.size_gb_est, + "signal_score": c.signal_score}) + + return {"candidates": registered, "total": len(registered)} + except Exception as e: + logger.error("Discovery failed: %s", e) + raise HTTPException(status_code=500, detail=str(e)) + + +def _run_validation_job(job_id: str, url: str, name: Optional[str], + quant: Optional[str]) -> None: + """Background validation job runner.""" + from .discover import DiscoveryCandidate + from datetime import datetime + + _validation_jobs[job_id]["status"] = "running" + try: + candidate = 
DiscoveryCandidate( + name=name or url.split("/")[-1], + hf_url=url, quant=quant, size_gb_est=None, + signal_score=0.0, source_tweet_id="manual", + discovered_at=datetime.utcnow().isoformat(), + ) + hw = get_hardware_profile() + registry = get_registry() + result = validate_model(candidate, hw, registry) + _validation_jobs[job_id]["status"] = "completed" + _validation_jobs[job_id]["result"] = { + "model_id": result.model_id, + "passed": result.passed, + "toolcall15_score": result.toolcall15_score, + "tok_s_gen": result.tok_s_gen, + "tok_s_prompt": result.tok_s_prompt, + "error": result.error, + } + except Exception as e: + _validation_jobs[job_id]["status"] = "failed" + _validation_jobs[job_id]["error"] = str(e) + logger.error("Validation job %s failed: %s", job_id, e) + + +@app.post("/api/v1/inference/validate") +async def validate_model_endpoint(request: ValidateRequest, + background_tasks: BackgroundTasks) -> Dict: + """Start model validation as a background job. + + Downloads the model, runs benchmarks, and evaluates tool-calling ability. + Returns a job_id to track progress. 
+ """ + job_id = str(uuid.uuid4())[:12] + _validation_jobs[job_id] = {"status": "queued", "url": request.url} + background_tasks.add_task( + _run_validation_job, job_id, request.url, request.name, request.quant + ) + return {"job_id": job_id, "status": "queued"} + + +@app.get("/api/v1/inference/validate/{job_id}") +async def get_validation_status(job_id: str) -> Dict: + """Check status of a validation job.""" + if job_id not in _validation_jobs: + raise HTTPException(status_code=404, detail=f"Job {job_id} not found") + return _validation_jobs[job_id] + + +@app.get("/api/v1/inference/models") +async def list_models(status: Optional[str] = None) -> Dict: + """List all models in the registry, optionally filtered by status.""" + registry = get_registry() + status_filter = ModelStatus(status) if status else None + models = registry.list_models(status_filter) + return { + "models": [ + ModelResponse( + id=m.id, name=m.name, status=m.status.value, + quant=m.quant, size_gb=m.size_gb, + toolcall15_score=m.toolcall15_score, + tok_s_gen=m.tok_s_gen, tok_s_prompt=m.tok_s_prompt, + source_url=m.source_url, signal_score=m.signal_score, + ).dict() + for m in models + ], + "total": len(models), + } + + +@app.put("/api/v1/inference/models/{model_id}/promote") +async def promote_model(model_id: str) -> Dict: + """Promote a model to serving status.""" + registry = get_registry() + try: + model = registry.promote(model_id) + return {"id": model.id, "name": model.name, "status": model.status.value} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@app.put("/api/v1/inference/models/{model_id}/retire") +async def retire_model(model_id: str) -> Dict: + """Retire a model from serving.""" + registry = get_registry() + try: + model = registry.retire(model_id) + return {"id": model.id, "name": model.name, "status": model.status.value} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@app.get("/api/v1/inference/status") 
+async def inference_status() -> Dict: + """Get current serving model and health status.""" + registry = get_registry() + status = get_serving_status(registry) + return { + "active": status.active, + "healthy": status.healthy, + "model_id": status.model_id, + "model_name": status.model_name, + "port": status.port, + "pid": status.pid, + } + + +@app.post("/api/v1/inference/serve") +async def serve_model(request: ServeRequest) -> Dict: + """Hot-swap to serve a specific model.""" + registry = get_registry() + try: + status = hot_swap( + request.model_id, registry, + rollback_on_failure=request.rollback_on_failure, + port=request.port, + ) + return { + "active": status.active, + "healthy": status.healthy, + "model_id": status.model_id, + "model_name": status.model_name, + "port": status.port, + } + except (ValueError, RuntimeError) as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/api/v1/inference/hardware") +async def hardware_profile() -> Dict: + """Get current hardware profile.""" + hw = get_hardware_profile() + return { + "gpu_name": hw.gpu_name, + "gpu_backend": hw.gpu_backend, + "vram_gb": hw.vram_gb, + "ram_gb": hw.ram_gb, + "disk_free_gb": hw.disk_free_gb, + "cpu_cores": hw.cpu_cores, + "os": hw.os_name, + "arch": hw.arch, + } diff --git a/internal/inference/lifecycle/discover.py b/internal/inference/lifecycle/discover.py new file mode 100644 index 00000000..366ac2d1 --- /dev/null +++ b/internal/inference/lifecycle/discover.py @@ -0,0 +1,218 @@ +"""Discovery module — finds new GGUF models via x-cli social signal scraping.""" + +import json +import logging +import os +import re +import subprocess +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + +# x-cli binary expected on PATH (includes ~/.local/bin) +X_CLI_BIN = "x-cli" +HERMES_ENV = Path.home() / ".hermes" / ".env" + +# Known quant patterns +QUANT_PATTERNS = 
re.compile( + r"(Q[2-8]_[KMS](?:_[SML])?|IQ[1-4]_[A-Z]+|F16|F32|BF16)", re.IGNORECASE +) +# HuggingFace URL pattern +HF_URL_PATTERN = re.compile( + r"(?:https?://)?huggingface\.co/([a-zA-Z0-9_-]+/[a-zA-Z0-9._-]+)" +) +# Model size from filename like 7B, 13B, 70B +SIZE_PATTERN = re.compile(r"(\d+(?:\.\d+)?)\s*[Bb](?:illion)?") + + +@dataclass +class DiscoveryCandidate: + """A model candidate discovered from social signals.""" + name: str + hf_url: Optional[str] + quant: Optional[str] + size_gb_est: Optional[float] + signal_score: float + source_tweet_id: str + discovered_at: str + + +def _load_hermes_env() -> Dict[str, str]: + """Load environment variables from ~/.hermes/.env for x-cli auth.""" + env = os.environ.copy() + env["PATH"] = f"{Path.home() / '.local' / 'bin'}:{env.get('PATH', '')}" + if HERMES_ENV.exists(): + try: + with open(HERMES_ENV) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + key, _, value = line.partition("=") + env[key.strip()] = value.strip().strip("\"'") + logger.debug("Loaded %s env vars from %s", len(env), HERMES_ENV) + except IOError as e: + logger.warning("Could not read hermes env: %s", e) + return env + + +def search_localllama(max_results: int = 20) -> str: + """Search x/LocalLLaMA community for GGUF/model/benchmark tweets. + + Uses x-cli to search recent posts from the LocalLLaMA community + for model announcements, GGUF releases, and benchmark discussions. + + Returns: + Raw JSON string from x-cli search results. 
+ """ + env = _load_hermes_env() + queries = [ + "LocalLLaMA GGUF new model", + "LocalLLaMA benchmark release quantized", + "GGUF huggingface release llama", + ] + all_results = [] + for query in queries: + try: + result = subprocess.run( + [X_CLI_BIN, "search", "--query", query, + "--max-results", str(max_results // len(queries)), + "--format", "json"], + capture_output=True, text=True, timeout=30, env=env + ) + if result.returncode == 0 and result.stdout.strip(): + try: + tweets = json.loads(result.stdout) + if isinstance(tweets, list): + all_results.extend(tweets) + elif isinstance(tweets, dict) and "data" in tweets: + all_results.extend(tweets["data"]) + except json.JSONDecodeError: + logger.warning("Failed to parse x-cli output for query: %s", query) + else: + logger.warning("x-cli search failed for '%s': %s", query, result.stderr[:200]) + except (subprocess.TimeoutExpired, FileNotFoundError) as e: + logger.error("x-cli execution error: %s", e) + + logger.info("Discovered %d raw tweets from x-cli search", len(all_results)) + return json.dumps(all_results) + + +def parse_candidates(tweets_json: str) -> List[DiscoveryCandidate]: + """Extract model candidates from tweet JSON data. + + Parses tweet text for HuggingFace URLs, quant types, model sizes, + and model names. Deduplicates by HF URL. + + Args: + tweets_json: JSON string of tweet objects. + + Returns: + List of DiscoveryCandidate objects. 
+ """ + try: + tweets = json.loads(tweets_json) + except json.JSONDecodeError: + logger.error("Invalid JSON in tweets data") + return [] + + if not isinstance(tweets, list): + tweets = [tweets] + + candidates = [] + seen_urls = set() + + for tweet in tweets: + text = tweet.get("text", "") or tweet.get("content", "") + tweet_id = str(tweet.get("id", tweet.get("tweet_id", "unknown"))) + + # Extract HF URL + hf_match = HF_URL_PATTERN.search(text) + hf_url = f"https://huggingface.co/{hf_match.group(1)}" if hf_match else None + + if hf_url and hf_url in seen_urls: + continue + if hf_url: + seen_urls.add(hf_url) + + # Extract quant + quant_match = QUANT_PATTERNS.search(text) + quant = quant_match.group(1).upper() if quant_match else None + + # Extract model size estimate + size_match = SIZE_PATTERN.search(text) + size_gb_est = None + if size_match: + param_b = float(size_match.group(1)) + # Rough estimate: Q4 ~ 0.5 GB/B params, Q8 ~ 1 GB/B + multiplier = 0.5 if quant and "4" in quant else 0.75 + size_gb_est = round(param_b * multiplier, 1) + + # Extract model name (first capitalized multi-word near "model" or from HF URL) + name = "Unknown Model" + if hf_match: + name = hf_match.group(1).split("/")[-1] + else: + name_match = re.search(r"([A-Z][a-zA-Z0-9]*(?:[-_][A-Za-z0-9]+){1,5})", text) + if name_match: + name = name_match.group(1) + + signal = score_signal(tweet) + + candidates.append(DiscoveryCandidate( + name=name, + hf_url=hf_url, + quant=quant, + size_gb_est=size_gb_est, + signal_score=signal, + source_tweet_id=tweet_id, + discovered_at=datetime.utcnow().isoformat(), + )) + + # Sort by signal score descending + candidates.sort(key=lambda c: c.signal_score, reverse=True) + logger.info("Parsed %d candidates from %d tweets", len(candidates), len(tweets)) + return candidates + + +def score_signal(tweet: dict) -> float: + """Compute a signal score for a tweet based on engagement metrics. 
+ + Formula: likes * 1.0 + bookmarks * 2.0 + retweets * 1.5 + + Args: + tweet: Tweet dict with engagement metric fields. + + Returns: + Weighted engagement score. + """ + likes = float(tweet.get("likes", 0) or tweet.get("like_count", 0) or 0) + bookmarks = float(tweet.get("bookmarks", 0) or tweet.get("bookmark_count", 0) or 0) + retweets = float(tweet.get("retweets", 0) or tweet.get("retweet_count", 0) or 0) + return likes * 1.0 + bookmarks * 2.0 + retweets * 1.5 + + +def filter_candidates(candidates: List[DiscoveryCandidate], + max_size_gb: float) -> List[DiscoveryCandidate]: + """Filter candidates that exceed the hardware's capacity. + + Removes models with estimated size larger than max_size_gb. + Models with unknown size are kept (benefit of the doubt). + + Args: + candidates: List of discovery candidates. + max_size_gb: Maximum model size in GB. + + Returns: + Filtered list of candidates. + """ + filtered = [ + c for c in candidates + if c.size_gb_est is None or c.size_gb_est <= max_size_gb + ] + removed = len(candidates) - len(filtered) + if removed: + logger.info("Filtered out %d candidates exceeding %.1f GB", removed, max_size_gb) + return filtered diff --git a/internal/inference/lifecycle/hardware.py b/internal/inference/lifecycle/hardware.py new file mode 100644 index 00000000..291baa50 --- /dev/null +++ b/internal/inference/lifecycle/hardware.py @@ -0,0 +1,188 @@ +"""Hardware profiler for inference lifecycle — detects GPU, RAM, disk, and computes optimal settings.""" + +import logging +import os +import platform +import re +import shutil +import subprocess +from dataclasses import dataclass +from typing import Optional, Tuple + +logger = logging.getLogger(__name__) + + +@dataclass +class HardwareProfile: + """Complete hardware profile for inference planning.""" + gpu_name: str + gpu_backend: str # cuda, rocm, metal, cpu + vram_gb: float + ram_gb: float + disk_free_gb: float + cpu_cores: int + os_name: str + arch: str + + +def detect_gpu() -> Tuple[str, 
str, float]: + """Detect GPU name, backend, and VRAM in GB. + + Tries nvidia-smi first, then rocm-smi, then falls back to CPU. + Returns (gpu_name, backend, vram_gb). + """ + # Try NVIDIA + if shutil.which("nvidia-smi"): + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"], + capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0: + line = result.stdout.strip().split("\n")[0] + parts = [p.strip() for p in line.split(",")] + gpu_name = parts[0] + vram_mb = float(parts[1]) + vram_gb = round(vram_mb / 1024, 1) + logger.info("Detected NVIDIA GPU: %s with %.1f GB VRAM", gpu_name, vram_gb) + return gpu_name, "cuda", vram_gb + except (subprocess.TimeoutExpired, IndexError, ValueError) as e: + logger.warning("nvidia-smi failed: %s", e) + + # Try ROCm + if shutil.which("rocm-smi"): + try: + result = subprocess.run( + ["rocm-smi", "--showmeminfo", "vram", "--csv"], + capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0: + # Parse ROCm CSV output for VRAM total + lines = result.stdout.strip().split("\n") + for line in lines[1:]: # skip header + if "total" in line.lower() or len(lines) == 2: + nums = re.findall(r"[\d.]+", line) + if nums: + vram_bytes = float(nums[-1]) + vram_gb = round(vram_bytes / (1024**3), 1) + break + else: + vram_gb = 0.0 + # Get GPU name + name_result = subprocess.run( + ["rocm-smi", "--showproductname"], + capture_output=True, text=True, timeout=10 + ) + gpu_name = "AMD GPU" + if name_result.returncode == 0: + for l in name_result.stdout.split("\n"): + if "card" in l.lower() or "gpu" in l.lower(): + gpu_name = l.strip().split(":")[-1].strip() or gpu_name + break + logger.info("Detected ROCm GPU: %s with %.1f GB VRAM", gpu_name, vram_gb) + return gpu_name, "rocm", vram_gb + except (subprocess.TimeoutExpired, ValueError) as e: + logger.warning("rocm-smi failed: %s", e) + + # macOS Metal + if platform.system() == "Darwin": + try: + result = 
subprocess.run( + ["system_profiler", "SPDisplaysDataType"], + capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0: + vram_match = re.search(r"VRAM.*?(\d+)\s*(MB|GB)", result.stdout, re.IGNORECASE) + name_match = re.search(r"Chipset Model:\s*(.+)", result.stdout) + gpu_name = name_match.group(1).strip() if name_match else "Apple GPU" + vram_gb = 0.0 + if vram_match: + val = float(vram_match.group(1)) + vram_gb = val if vram_match.group(2).upper() == "GB" else val / 1024 + logger.info("Detected Metal GPU: %s with %.1f GB VRAM", gpu_name, vram_gb) + return gpu_name, "metal", vram_gb + except (subprocess.TimeoutExpired, ValueError) as e: + logger.warning("system_profiler failed: %s", e) + + logger.info("No GPU detected, falling back to CPU") + return "CPU", "cpu", 0.0 + + +def detect_system() -> Tuple[float, float, int]: + """Detect system RAM (GB), disk free (GB), and CPU cores. + + Returns (ram_gb, disk_free_gb, cpu_cores). + """ + # RAM + ram_gb = 0.0 + try: + result = subprocess.run(["free", "-b"], capture_output=True, text=True, timeout=5) + if result.returncode == 0: + for line in result.stdout.split("\n"): + if line.startswith("Mem:"): + parts = line.split() + ram_gb = round(float(parts[1]) / (1024**3), 1) + break + except (subprocess.TimeoutExpired, FileNotFoundError): + # Fallback for non-Linux + try: + import psutil + ram_gb = round(psutil.virtual_memory().total / (1024**3), 1) + except ImportError: + ram_gb = 0.0 + + # Disk free + disk_free_gb = 0.0 + try: + stat = os.statvfs(os.path.expanduser("~")) + disk_free_gb = round(stat.f_bavail * stat.f_frsize / (1024**3), 1) + except OSError: + pass + + # CPU cores + cpu_cores = os.cpu_count() or 1 + + logger.info("System: %.1f GB RAM, %.1f GB disk free, %d CPU cores", + ram_gb, disk_free_gb, cpu_cores) + return ram_gb, disk_free_gb, cpu_cores + + +def compute_optimal_ngl(model_size_gb: float, vram_gb: float, + total_layers: int = 80) -> int: + """Estimate optimal number of GPU layers to 
offload. + + Heuristic: (vram_gb - 2.0) / model_size_gb * total_layers, capped at 999. + Reserves 2 GB VRAM for KV cache and OS overhead. + + Args: + model_size_gb: Model file size in GB (rough proxy for weight memory). + vram_gb: Available VRAM in GB. + total_layers: Estimated total layers in the model. + + Returns: + Number of layers to offload to GPU (0 if insufficient VRAM). + """ + if vram_gb <= 2.0 or model_size_gb <= 0: + return 0 + usable_vram = vram_gb - 2.0 + ngl = int((usable_vram / model_size_gb) * total_layers) + ngl = max(0, min(ngl, 999)) + logger.debug("compute_optimal_ngl: %.1f GB model, %.1f GB VRAM -> ngl=%d", + model_size_gb, vram_gb, ngl) + return ngl + + +def get_hardware_profile() -> HardwareProfile: + """Build and return a complete hardware profile.""" + gpu_name, gpu_backend, vram_gb = detect_gpu() + ram_gb, disk_free_gb, cpu_cores = detect_system() + return HardwareProfile( + gpu_name=gpu_name, + gpu_backend=gpu_backend, + vram_gb=vram_gb, + ram_gb=ram_gb, + disk_free_gb=disk_free_gb, + cpu_cores=cpu_cores, + os_name=platform.system(), + arch=platform.machine(), + ) diff --git a/internal/inference/lifecycle/registry.py b/internal/inference/lifecycle/registry.py new file mode 100644 index 00000000..df89e6bd --- /dev/null +++ b/internal/inference/lifecycle/registry.py @@ -0,0 +1,253 @@ +"""Model registry with SQLite backend for inference lifecycle management.""" + +import enum +import logging +import sqlite3 +import uuid +from dataclasses import dataclass, field, asdict +from datetime import datetime +from pathlib import Path +from typing import List, Optional + +logger = logging.getLogger(__name__) + + +class ModelStatus(enum.Enum): + """Lifecycle states for a model.""" + discovered = "discovered" + downloading = "downloading" + validating = "validating" + passed = "passed" + failed = "failed" + registered = "registered" + serving = "serving" + retired = "retired" + + +# Valid state transitions +VALID_TRANSITIONS = { + 
ModelStatus.discovered: {ModelStatus.downloading, ModelStatus.failed}, + ModelStatus.downloading: {ModelStatus.validating, ModelStatus.failed}, + ModelStatus.validating: {ModelStatus.passed, ModelStatus.failed}, + ModelStatus.passed: {ModelStatus.registered, ModelStatus.failed}, + ModelStatus.failed: {ModelStatus.discovered}, # allow retry + ModelStatus.registered: {ModelStatus.serving, ModelStatus.retired}, + ModelStatus.serving: {ModelStatus.retired}, + ModelStatus.retired: {ModelStatus.registered}, # allow re-register +} + + +@dataclass +class ModelRecord: + """Complete record for a tracked model.""" + id: str + name: str + gguf_path: Optional[str] = None + quant: Optional[str] = None + size_gb: Optional[float] = None + vram_required_gb: Optional[float] = None + toolcall15_score: Optional[float] = None + tok_s_gen: Optional[float] = None + tok_s_prompt: Optional[float] = None + source_url: Optional[str] = None + signal_score: Optional[float] = None + status: ModelStatus = ModelStatus.discovered + discovered_at: Optional[str] = None + validated_at: Optional[str] = None + registered_at: Optional[str] = None + serving_since: Optional[str] = None + + +class Registry: + """SQLite-backed model registry managing the full inference lifecycle.""" + + def __init__(self, db_path: str = "inference_lifecycle.db"): + self.db_path = db_path + self._conn: Optional[sqlite3.Connection] = None + self.init_db() + + def _get_conn(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = sqlite3.connect(self.db_path) + self._conn.row_factory = sqlite3.Row + return self._conn + + def init_db(self) -> None: + """Create tables if they don't exist.""" + conn = self._get_conn() + conn.executescript(""" + CREATE TABLE IF NOT EXISTS models ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + gguf_path TEXT, + quant TEXT, + size_gb REAL, + vram_required_gb REAL, + toolcall15_score REAL, + tok_s_gen REAL, + tok_s_prompt REAL, + source_url TEXT, + signal_score REAL, + status TEXT NOT 
NULL DEFAULT 'discovered', + discovered_at TEXT, + validated_at TEXT, + registered_at TEXT, + serving_since TEXT + ); + CREATE TABLE IF NOT EXISTS benchmark_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model_id TEXT NOT NULL, + eval_suite TEXT NOT NULL, + score REAL, + tok_s_gen REAL, + tok_s_prompt REAL, + timestamp TEXT NOT NULL, + FOREIGN KEY (model_id) REFERENCES models(id) + ); + CREATE INDEX IF NOT EXISTS idx_models_status ON models(status); + CREATE INDEX IF NOT EXISTS idx_bench_model ON benchmark_history(model_id); + """) + conn.commit() + logger.info("Registry database initialized at %s", self.db_path) + + def add_model(self, name: str, source_url: Optional[str] = None, + quant: Optional[str] = None, size_gb: Optional[float] = None, + signal_score: Optional[float] = None) -> ModelRecord: + """Add a newly discovered model to the registry.""" + model_id = str(uuid.uuid4())[:12] + now = datetime.utcnow().isoformat() + record = ModelRecord( + id=model_id, name=name, source_url=source_url, quant=quant, + size_gb=size_gb, signal_score=signal_score, + status=ModelStatus.discovered, discovered_at=now, + ) + conn = self._get_conn() + conn.execute( + """INSERT INTO models (id, name, gguf_path, quant, size_gb, vram_required_gb, + toolcall15_score, tok_s_gen, tok_s_prompt, source_url, signal_score, + status, discovered_at, validated_at, registered_at, serving_since) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + (record.id, record.name, record.gguf_path, record.quant, record.size_gb, + record.vram_required_gb, record.toolcall15_score, record.tok_s_gen, + record.tok_s_prompt, record.source_url, record.signal_score, + record.status.value, record.discovered_at, record.validated_at, + record.registered_at, record.serving_since) + ) + conn.commit() + logger.info("Added model %s (%s) to registry", model_id, name) + return record + + def update_status(self, model_id: str, new_status: ModelStatus) -> ModelRecord: + """Transition a model to a new status 
with validation.""" + current = self.get_model(model_id) + if current is None: + raise ValueError(f"Model {model_id} not found") + if new_status not in VALID_TRANSITIONS.get(current.status, set()): + raise ValueError( + f"Invalid transition: {current.status.value} -> {new_status.value}" + ) + conn = self._get_conn() + now = datetime.utcnow().isoformat() + updates = {"status": new_status.value} + if new_status == ModelStatus.validating: + updates["validated_at"] = now + elif new_status == ModelStatus.registered: + updates["registered_at"] = now + elif new_status == ModelStatus.serving: + updates["serving_since"] = now + + set_clause = ", ".join(f"{k} = ?" for k in updates) + conn.execute(f"UPDATE models SET {set_clause} WHERE id = ?", + list(updates.values()) + [model_id]) + conn.commit() + logger.info("Model %s: %s -> %s", model_id, current.status.value, new_status.value) + return self.get_model(model_id) + + def update_benchmark(self, model_id: str, eval_suite: str, score: float, + tok_s_gen: float, tok_s_prompt: float) -> None: + """Record benchmark results for a model.""" + conn = self._get_conn() + now = datetime.utcnow().isoformat() + conn.execute( + """INSERT INTO benchmark_history (model_id, eval_suite, score, + tok_s_gen, tok_s_prompt, timestamp) VALUES (?, ?, ?, ?, ?, ?)""", + (model_id, eval_suite, score, tok_s_gen, tok_s_prompt, now) + ) + conn.execute( + """UPDATE models SET toolcall15_score = ?, tok_s_gen = ?, + tok_s_prompt = ? 
WHERE id = ?""", + (score, tok_s_gen, tok_s_prompt, model_id) + ) + conn.commit() + logger.info("Benchmark recorded for %s: score=%.1f gen=%.1f prompt=%.1f", + model_id, score, tok_s_gen, tok_s_prompt) + + def get_model(self, model_id: str) -> Optional[ModelRecord]: + """Retrieve a single model by ID.""" + conn = self._get_conn() + row = conn.execute("SELECT * FROM models WHERE id = ?", (model_id,)).fetchone() + if row is None: + return None + return self._row_to_record(row) + + def list_models(self, status_filter: Optional[ModelStatus] = None) -> List[ModelRecord]: + """List models, optionally filtered by status.""" + conn = self._get_conn() + if status_filter: + rows = conn.execute("SELECT * FROM models WHERE status = ? ORDER BY discovered_at DESC", + (status_filter.value,)).fetchall() + else: + rows = conn.execute("SELECT * FROM models ORDER BY discovered_at DESC").fetchall() + return [self._row_to_record(r) for r in rows] + + def get_serving_model(self) -> Optional[ModelRecord]: + """Get the currently serving model, if any.""" + models = self.list_models(ModelStatus.serving) + return models[0] if models else None + + def promote(self, model_id: str) -> ModelRecord: + """Promote a model to serving, retiring any current serving model.""" + current_serving = self.get_serving_model() + if current_serving and current_serving.id != model_id: + self.retire(current_serving.id) + model = self.get_model(model_id) + if model is None: + raise ValueError(f"Model {model_id} not found") + if model.status == ModelStatus.serving: + return model + if model.status not in (ModelStatus.registered, ModelStatus.passed): + # Allow promoting from registered or passed + if model.status != ModelStatus.registered: + self.update_status(model_id, ModelStatus.registered) + return self.update_status(model_id, ModelStatus.serving) + + def retire(self, model_id: str) -> ModelRecord: + """Retire a model from serving.""" + model = self.get_model(model_id) + if model is None: + raise 
# NOTE(review): in the original patch this hunk also carried the tail of
# registry.py (Registry._row_to_record / Registry.close) and the two-line
# requirements.txt (fastapi>=0.100.0, uvicorn>=0.23.0); only serve.py is
# reconstructed here.
"""Serve manager — manages llama-server systemd service for model serving."""

import logging
import os
import subprocess
import time
import urllib.error
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from .registry import Registry
from .hardware import compute_optimal_ngl, get_hardware_profile

logger = logging.getLogger(__name__)

# Binary produced by the CUDA build of the llama.cpp fork on the target host.
LLAMA_SERVER_BINARY = (
    Path.home() / "Development" / "llama-cpp-turboquant-cuda"
    / "build" / "bin" / "llama-server"
)
SERVICE_NAME = "obol-inference"
DEFAULT_PORT = 8080
DEFAULT_CONTEXT = 8192


@dataclass
class ServingStatus:
    """Current serving status as reported by systemd plus the registry."""

    active: bool                     # unit reported "active" by systemctl
    model_id: Optional[str]          # registry id of the serving model, if known
    model_name: Optional[str]        # human-readable model name, if known
    port: int                        # port llama-server listens on
    healthy: bool                    # /health endpoint responded successfully
    uptime_seconds: Optional[float]  # TODO(review): never populated — always None
    pid: Optional[int]               # MainPID of the systemd unit when active


def generate_systemd_unit(model_path: str, ngl: int, port: int = DEFAULT_PORT,
                          threads: int = 8, context_size: int = DEFAULT_CONTEXT) -> str:
    """Generate a systemd unit file for llama-server.

    Args:
        model_path: Absolute path to the GGUF model file.
        ngl: Number of GPU layers to offload.
        port: Port to listen on.
        threads: Number of CPU threads.
        context_size: Context window size.

    Returns:
        Complete systemd unit file content as string.
    """
    return f"""[Unit]
Description=Obol Inference Server (llama-server)
After=network.target
Wants=network-online.target

[Service]
Type=simple
User={_get_current_user()}
Environment=CUDA_VISIBLE_DEVICES=0
ExecStart={LLAMA_SERVER_BINARY} \\
    -m {model_path} \\
    -ngl {ngl} \\
    --port {port} \\
    -t {threads} \\
    -c {context_size} \\
    --host 0.0.0.0 \\
    --metrics \\
    --log-disable
Restart=on-failure
RestartSec=5
LimitNOFILE=65536
StandardOutput=journal
StandardError=journal

[Install]
WantedBy=multi-user.target
"""


def _get_current_user() -> str:
    """Return the invoking user's name ("obol" when $USER is unset)."""
    return os.environ.get("USER", "obol")


def install_service(unit_content: str, service_name: str = SERVICE_NAME) -> None:
    """Install a systemd service unit file.

    Writes the unit file to /etc/systemd/system/ (via ``sudo tee``), runs
    daemon-reload, and enables the service.

    Args:
        unit_content: Complete systemd unit file content.
        service_name: Name of the systemd service.

    Raises:
        RuntimeError: If any installation step fails.
    """
    unit_path = f"/etc/systemd/system/{service_name}.service"
    logger.info("Installing systemd service: %s", unit_path)

    # Each step is (argv, error-message prefix, stdin). The tee step is how
    # an unprivileged process writes under /etc/systemd/system.
    steps = [
        (["sudo", "tee", unit_path], "Failed to write unit file", unit_content),
        (["sudo", "systemctl", "daemon-reload"], "daemon-reload failed", None),
        (["sudo", "systemctl", "enable", service_name], "Failed to enable service", None),
    ]
    for argv, err_prefix, stdin in steps:
        proc = subprocess.run(argv, input=stdin, capture_output=True, text=True)
        if proc.returncode != 0:
            raise RuntimeError(f"{err_prefix}: {proc.stderr}")

    logger.info("Service %s installed and enabled", service_name)


def _probe_health(port: int, timeout: float = 5.0) -> bool:
    """One-shot probe of llama-server's /health endpoint; True on a 2xx reply.

    Uses stdlib urllib instead of shelling out to curl, so a missing curl
    binary cannot masquerade as an unhealthy server (the original only
    caught subprocess.TimeoutExpired).
    """
    try:
        with urllib.request.urlopen(f"http://localhost:{port}/health",
                                    timeout=timeout) as resp:
            return 200 <= resp.status < 300
    except (urllib.error.URLError, OSError):
        return False


def _health_check(port: int = DEFAULT_PORT, timeout: int = 30) -> bool:
    """Check if llama-server is healthy, polling once per second.

    Args:
        port: Port to check.
        timeout: Maximum seconds to wait for health.

    Returns:
        True if server responds healthy within timeout.
    """
    for _ in range(timeout):
        if _probe_health(port):
            return True
        time.sleep(1)
    return False


def start_serving(model_id: str, registry: Registry,
                  port: int = DEFAULT_PORT) -> ServingStatus:
    """Start serving a model via systemd.

    Stops any currently serving model, installs the new service,
    starts it, verifies health, and updates the registry.

    Args:
        model_id: ID of the model to serve.
        registry: Model registry instance.
        port: Port to serve on.

    Returns:
        ServingStatus of the newly started service.

    Raises:
        ValueError: If the model is unknown or has no GGUF file.
        RuntimeError: If serving fails to start or fails the health check.
    """
    model = registry.get_model(model_id)
    if model is None:
        raise ValueError(f"Model {model_id} not found")
    if not model.gguf_path:
        raise ValueError(f"Model {model_id} has no GGUF path")

    # Tear down whatever is currently serving before swapping in the new unit.
    stop_serving(registry)

    # Size GPU offload to detected hardware; 4 GB is a conservative default
    # when the model's on-disk size is unknown.
    hw = get_hardware_profile()
    ngl = compute_optimal_ngl(model.size_gb or 4.0, hw.vram_gb)
    unit = generate_systemd_unit(
        model_path=model.gguf_path, ngl=ngl,
        port=port, threads=hw.cpu_cores, context_size=DEFAULT_CONTEXT,
    )
    install_service(unit)

    proc = subprocess.run(["sudo", "systemctl", "start", SERVICE_NAME],
                          capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to start service: {proc.stderr}")

    # A unit that starts but never becomes healthy is stopped again so we
    # don't leave a half-broken server running.
    if not _health_check(port):
        subprocess.run(["sudo", "systemctl", "stop", SERVICE_NAME],
                       capture_output=True, text=True)
        raise RuntimeError("Service started but failed health check")

    registry.promote(model_id)
    logger.info("Model %s (%s) now serving on port %d", model_id, model.name, port)
    return get_serving_status(registry)


def stop_serving(registry: Registry) -> None:
    """Stop the current serving model (best-effort) and retire it.

    Args:
        registry: Model registry instance.
    """
    current = registry.get_serving_model()
    proc = subprocess.run(["sudo", "systemctl", "stop", SERVICE_NAME],
                          capture_output=True, text=True)
    if proc.returncode != 0:
        # Best-effort: the unit may simply not exist yet on a fresh host,
        # but the failure should at least be visible in the logs.
        logger.warning("systemctl stop %s returned %d: %s",
                       SERVICE_NAME, proc.returncode, proc.stderr.strip())
    if current:
        try:
            registry.retire(current.id)
        except ValueError:
            pass  # Already retired
    logger.info("Stopped inference service")


def hot_swap(new_model_id: str, registry: Registry,
             rollback_on_failure: bool = True,
             port: int = DEFAULT_PORT) -> ServingStatus:
    """Hot-swap to a new model with rollback support.

    Stops the current model, starts the new one, and rolls back
    if the new model fails its health check.

    Args:
        new_model_id: ID of the model to swap to.
        registry: Model registry instance.
        rollback_on_failure: Whether to rollback on failure.
        port: Port to serve on.

    Returns:
        ServingStatus after swap.

    Raises:
        RuntimeError: If swap and rollback both fail.
    """
    old_model = registry.get_serving_model()
    old_model_id = old_model.id if old_model else None

    try:
        return start_serving(new_model_id, registry, port)
    except RuntimeError as e:
        logger.error("Hot swap to %s failed: %s", new_model_id, e)
        if rollback_on_failure and old_model_id:
            logger.info("Rolling back to %s", old_model_id)
            try:
                return start_serving(old_model_id, registry, port)
            except RuntimeError as rollback_err:
                # Chain the rollback failure so neither traceback is lost.
                raise RuntimeError(
                    f"Hot swap failed AND rollback failed: {e} / {rollback_err}"
                ) from rollback_err
        raise


def get_serving_status(registry: Optional[Registry] = None) -> ServingStatus:
    """Get current serving status from systemd.

    Args:
        registry: Optional registry to look up model details.

    Returns:
        ServingStatus with current state.
    """
    proc = subprocess.run(["systemctl", "is-active", SERVICE_NAME],
                          capture_output=True, text=True)
    active = proc.stdout.strip() == "active"

    pid = None
    if active:
        pid_proc = subprocess.run(
            ["systemctl", "show", SERVICE_NAME, "--property=MainPID", "--value"],
            capture_output=True, text=True,
        )
        try:
            pid = int(pid_proc.stdout.strip())
        except ValueError:
            pass  # Unparseable MainPID; leave as None

    model_id = None
    model_name = None
    if registry:
        serving = registry.get_serving_model()
        if serving:
            model_id = serving.id
            model_name = serving.name

    # Quick single probe rather than the 30-second polling loop.
    healthy = _probe_health(DEFAULT_PORT, timeout=3.0) if active else False

    return ServingStatus(
        active=active,
        model_id=model_id,
        model_name=model_name,
        port=DEFAULT_PORT,
        healthy=healthy,
        uptime_seconds=None,  # TODO(review): derive from systemd timestamps
        pid=pid,
    )
"""Validation runner — downloads models, runs benchmarks, and evaluates tool-calling ability."""

import json
import logging
import re
import subprocess
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Optional

from .registry import ModelStatus, Registry
from .hardware import HardwareProfile, compute_optimal_ngl
from .discover import DiscoveryCandidate

logger = logging.getLogger(__name__)

# Paths to the CUDA llama.cpp build and eval assets on the target host.
_BUILD_BIN = Path.home() / "Development" / "llama-cpp-turboquant-cuda" / "build" / "bin"
LLAMA_BENCH_BIN = _BUILD_BIN / "llama-bench"
LLAMA_SERVER_BIN = _BUILD_BIN / "llama-server"
MODEL_CACHE_DIR = Path.home() / ".cache" / "obol" / "models"
TOOLCALL15_DIR = Path.home() / "Development" / "toolcall-15"


@dataclass
class ValidationResult:
    """Results from the model validation pipeline."""

    model_id: str                  # registry id the result belongs to
    tok_s_gen: float = 0.0         # generation throughput (tokens/sec)
    tok_s_prompt: float = 0.0      # prompt-processing throughput (tokens/sec)
    toolcall15_score: float = 0.0  # ToolCall-15 score (0-15)
    toolcall15_details: Dict[str, bool] = field(default_factory=dict)
    passed: bool = False           # met the pass thresholds
    # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    error: Optional[str] = None    # set when the pipeline raised


def _resolve_gguf_url(repo_url: str) -> str:
    """Resolve a HuggingFace repo URL to a direct GGUF download URL.

    Lists the repo via the HF API and prefers a Q4_K_M quant, otherwise
    the first GGUF file listed.

    Raises:
        RuntimeError: If the repo cannot be listed or contains no GGUF files.
    """
    logger.info("Looking for GGUF files at %s", repo_url)
    api_url = repo_url.replace("huggingface.co", "huggingface.co/api/models")
    try:
        result = subprocess.run(["curl", "-sL", api_url],
                                capture_output=True, text=True, timeout=15)
        if result.returncode != 0:
            # BUG FIX: the original fell through here with the bare repo URL,
            # so wget would download the repo's HTML page instead of a model.
            raise RuntimeError(f"HF API request failed with code {result.returncode}")
        data = json.loads(result.stdout)
        names = [s.get("rfilename", "") for s in data.get("siblings", [])]
        gguf_files = [n for n in names if n.endswith(".gguf")]
        if not gguf_files:
            raise RuntimeError(f"No GGUF files found in {repo_url}")
        preferred = [f for f in gguf_files if "Q4_K_M" in f]
        chosen = preferred[0] if preferred else gguf_files[0]
        logger.info("Selected GGUF file: %s", chosen)
        return f"{repo_url}/resolve/main/{chosen}"
    except (json.JSONDecodeError, subprocess.TimeoutExpired) as e:
        raise RuntimeError(f"Failed to resolve GGUF files: {e}") from e


def download_model(hf_url: str, dest_dir: Optional[Path] = None) -> Path:
    """Download a model from HuggingFace using wget.

    Bare repo URLs are resolved to a concrete GGUF file first; direct
    ``*.gguf`` URLs are used as-is. Cached files are reused.

    Args:
        hf_url: HuggingFace repo URL or direct GGUF file URL.
        dest_dir: Download destination directory (defaults to the model cache).

    Returns:
        Path to downloaded model file.

    Raises:
        RuntimeError: If resolution or download fails.
    """
    dest_dir = dest_dir or MODEL_CACHE_DIR
    dest_dir.mkdir(parents=True, exist_ok=True)

    if not hf_url.endswith(".gguf"):
        hf_url = _resolve_gguf_url(hf_url)

    filename = hf_url.split("/")[-1]
    dest_path = dest_dir / filename

    if dest_path.exists():
        logger.info("Model already cached: %s", dest_path)
        return dest_path

    logger.info("Downloading %s to %s", hf_url, dest_path)
    result = subprocess.run(
        ["wget", "-q", "--show-progress", "-O", str(dest_path), hf_url],
        timeout=3600,  # 1 hour max
    )
    if result.returncode != 0:
        dest_path.unlink(missing_ok=True)  # don't leave a truncated file behind
        raise RuntimeError(f"Download failed with code {result.returncode}")

    logger.info("Download complete: %s (%.1f GB)", dest_path,
                dest_path.stat().st_size / (1024 ** 3))
    return dest_path


def run_llama_bench(model_path: Path, ngl: int, threads: int = 8) -> Dict[str, float]:
    """Run llama-bench and parse throughput results.

    Args:
        model_path: Path to GGUF model file.
        ngl: Number of GPU layers.
        threads: CPU threads to use.

    Returns:
        Dict with 'tok_s_gen' and 'tok_s_prompt' values.

    Raises:
        RuntimeError: If benchmark fails to run or parse.
    """
    if not LLAMA_BENCH_BIN.exists():
        raise RuntimeError(f"llama-bench not found at {LLAMA_BENCH_BIN}")

    logger.info("Running llama-bench: model=%s ngl=%d threads=%d",
                model_path, ngl, threads)
    result = subprocess.run(
        [str(LLAMA_BENCH_BIN),
         "-m", str(model_path),
         "-ngl", str(ngl),
         "-t", str(threads),
         "-p", "512", "-n", "128",
         "-o", "json"],
        capture_output=True, text=True, timeout=600,
    )
    if result.returncode != 0:
        raise RuntimeError(f"llama-bench failed: {result.stderr[:500]}")

    try:
        data = json.loads(result.stdout)
        entries = data if isinstance(data, list) else [data]
        results = {"tok_s_gen": 0.0, "tok_s_prompt": 0.0}
        for entry in entries:
            test_name = str(entry.get("test", ""))
            if entry.get("type") == "tg" or "tg" in test_name:
                results["tok_s_gen"] = float(entry.get("avg_ts", 0))
            elif entry.get("type") == "pp" or "pp" in test_name:
                results["tok_s_prompt"] = float(entry.get("avg_ts", 0))
        logger.info("Bench results: gen=%.1f tok/s, prompt=%.1f tok/s",
                    results["tok_s_gen"], results["tok_s_prompt"])
        return results
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
        # Fallback: some builds print a text table; scrape the tg/pp rows.
        gen_match = re.search(r"tg\s.*?([\d.]+)\s*±", result.stdout)
        pp_match = re.search(r"pp\s.*?([\d.]+)\s*±", result.stdout)
        results = {
            "tok_s_gen": float(gen_match.group(1)) if gen_match else 0.0,
            "tok_s_prompt": float(pp_match.group(1)) if pp_match else 0.0,
        }
        if results["tok_s_gen"] == 0 and results["tok_s_prompt"] == 0:
            raise RuntimeError(f"Failed to parse llama-bench output: {e}") from e
        return results


def _wait_for_health(port: int, timeout: int = 60) -> None:
    """Poll the temp server's /health endpoint until it answers or time out.

    Raises:
        RuntimeError: If the server never becomes healthy.
    """
    for attempt in range(timeout):
        try:
            health = subprocess.run(
                ["curl", "-sf", f"http://localhost:{port}/health"],
                capture_output=True, text=True, timeout=5,
            )
            if health.returncode == 0:
                logger.info("llama-server healthy after %d seconds", attempt)
                return
        except subprocess.TimeoutExpired:
            pass
        time.sleep(1)
    raise RuntimeError(f"llama-server failed to become healthy within {timeout}s")


def run_toolcall15(model_path: Path, ngl: int, port: int = 9090) -> Dict:
    """Run ToolCall-15 evaluation suite against a model.

    Starts a temporary llama-server, triggers the ToolCall-15 eval suite
    (assumed served at localhost:3001 — TODO confirm deployment), and
    parses results from its SSE stream.

    Args:
        model_path: Path to GGUF model file.
        ngl: Number of GPU layers.
        port: Port for temporary llama-server.

    Returns:
        Dict with 'score' (int out of 15) and 'details' (per-scenario results).

    Raises:
        RuntimeError: If the server binary is missing or never becomes healthy.
    """
    if not LLAMA_SERVER_BIN.exists():
        raise RuntimeError(f"llama-server not found at {LLAMA_SERVER_BIN}")

    server_proc = None
    try:
        logger.info("Starting temp llama-server on port %d", port)
        server_proc = subprocess.Popen(
            [str(LLAMA_SERVER_BIN),
             "-m", str(model_path),
             "-ngl", str(ngl),
             "--port", str(port),
             "-c", "8192",
             "--host", "0.0.0.0"],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )

        _wait_for_health(port)

        logger.info("Starting ToolCall-15 evaluation")
        eval_result = subprocess.run(
            ["curl", "-sN", f"http://localhost:3001/api/eval/stream?port={port}"],
            capture_output=True, text=True, timeout=600,
        )

        # DRY FIX: reuse the shared SSE parser instead of a duplicated
        # inline copy of the same loop.
        parsed = parse_sse_results(eval_result.stdout)
        logger.info("ToolCall-15 score: %d/15", parsed["score"])
        return parsed

    finally:
        # Always tear the temp server down, escalating SIGTERM -> SIGKILL.
        if server_proc:
            logger.info("Stopping temp llama-server (pid=%d)", server_proc.pid)
            server_proc.terminate()
            try:
                server_proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                server_proc.kill()
                server_proc.wait()


def parse_sse_results(sse_text: str) -> Dict:
    """Parse ToolCall-15 SSE output into score and details.

    Only ``scenario_result`` events contribute; malformed JSON lines and
    other event types are skipped.

    Args:
        sse_text: Raw SSE text from the eval stream.

    Returns:
        Dict with 'score' (count of passed scenarios) and 'details'
        (scenario name -> passed flag).
    """
    score = 0
    details: Dict[str, bool] = {}
    for raw_line in sse_text.split("\n"):
        line = raw_line.strip()
        if not line.startswith("data:"):
            continue
        try:
            event = json.loads(line[5:].strip())
        except json.JSONDecodeError:
            continue
        if event.get("type") != "scenario_result":
            continue
        passed = event.get("passed", False)
        details[event.get("scenario", "unknown")] = passed
        if passed:
            score += 1
    return {"score": score, "details": details}


def validate_model(candidate: DiscoveryCandidate,
                   hardware: HardwareProfile,
                   registry: Registry) -> ValidationResult:
    """Full validation pipeline: download, benchmark, evaluate, update registry.

    The candidate is registered as ``discovered`` first; any failure flips
    it to ``failed`` and is captured on the returned result, not raised.

    Args:
        candidate: Discovery candidate to validate.
        hardware: Current hardware profile.
        registry: Model registry instance.

    Returns:
        ValidationResult with all metrics (check ``.error`` for failures).
    """
    record = registry.add_model(
        name=candidate.name,
        source_url=candidate.hf_url,
        quant=candidate.quant,
        size_gb=candidate.size_gb_est,
        signal_score=candidate.signal_score,
    )
    model_id = record.id
    result = ValidationResult(model_id=model_id)

    try:
        registry.update_status(model_id, ModelStatus.downloading)
        if not candidate.hf_url:
            raise RuntimeError("No HuggingFace URL for candidate")
        model_path = download_model(candidate.hf_url)

        # NOTE(review): reaches into Registry internals (_get_conn); a public
        # Registry.set_gguf_path() would be cleaner.
        conn = registry._get_conn()
        conn.execute("UPDATE models SET gguf_path = ? WHERE id = ?",
                     (str(model_path), model_id))
        conn.commit()

        registry.update_status(model_id, ModelStatus.validating)
        # 4 GB is a conservative default when the candidate size is unknown.
        ngl = compute_optimal_ngl(candidate.size_gb_est or 4.0, hardware.vram_gb)

        bench = run_llama_bench(model_path, ngl, threads=hardware.cpu_cores)
        result.tok_s_gen = bench["tok_s_gen"]
        result.tok_s_prompt = bench["tok_s_prompt"]

        tc15 = run_toolcall15(model_path, ngl)
        result.toolcall15_score = tc15["score"]
        result.toolcall15_details = tc15["details"]

        registry.update_benchmark(
            model_id, "toolcall-15", tc15["score"],
            bench["tok_s_gen"], bench["tok_s_prompt"],
        )

        # Pass/fail threshold: >=10/15 scenarios and >=10 tok/s generation.
        result.passed = tc15["score"] >= 10 and bench["tok_s_gen"] >= 10.0
        new_status = ModelStatus.passed if result.passed else ModelStatus.failed
        registry.update_status(model_id, new_status)
        logger.info("Validation %s for %s: score=%d gen=%.1f",
                    "PASSED" if result.passed else "FAILED",
                    candidate.name, tc15["score"], bench["tok_s_gen"])

    except Exception as e:
        logger.error("Validation failed for %s: %s", model_id, e)
        result.error = str(e)
        try:
            registry.update_status(model_id, ModelStatus.failed)
        except ValueError:
            pass  # Already in failed state

    return result
"""Unit tests for inference lifecycle modules."""

import json
import os
import sys
import tempfile

import pytest

# Make the repository root importable when running pytest from anywhere.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from internal.inference.lifecycle.registry import Registry, ModelStatus, ModelRecord
from internal.inference.lifecycle.hardware import compute_optimal_ngl, HardwareProfile
from internal.inference.lifecycle.discover import (
    score_signal, filter_candidates, parse_candidates, DiscoveryCandidate
)
from internal.inference.lifecycle.serve import generate_systemd_unit
from internal.inference.lifecycle.validate import parse_sse_results


class TestRegistryStateMachine:
    """Test model lifecycle state transitions in the registry."""

    def setup_method(self):
        """Create a temporary database for each test."""
        self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
        self.tmp.close()
        self.registry = Registry(db_path=self.tmp.name)

    def teardown_method(self):
        """Clean up temporary database."""
        self.registry.close()
        os.unlink(self.tmp.name)

    def _advance(self, model_id, *statuses):
        """Walk a model through successive lifecycle states."""
        for status in statuses:
            self.registry.update_status(model_id, status)

    def test_registry_state_machine(self):
        """Full lifecycle: discovered -> downloading -> validating -> passed -> registered -> serving."""
        model = self.registry.add_model("test-model-7b", source_url="https://hf.co/test")
        assert model.status == ModelStatus.discovered
        assert model.id is not None
        assert model.name == "test-model-7b"

        model = self.registry.update_status(model.id, ModelStatus.downloading)
        assert model.status == ModelStatus.downloading

        model = self.registry.update_status(model.id, ModelStatus.validating)
        assert model.status == ModelStatus.validating
        assert model.validated_at is not None

        model = self.registry.update_status(model.id, ModelStatus.passed)
        assert model.status == ModelStatus.passed

        model = self.registry.update_status(model.id, ModelStatus.registered)
        assert model.status == ModelStatus.registered
        assert model.registered_at is not None

        model = self.registry.update_status(model.id, ModelStatus.serving)
        assert model.status == ModelStatus.serving
        assert model.serving_since is not None

    def test_invalid_transition_raises(self):
        """Invalid state transitions (discovered -> serving) raise ValueError."""
        model = self.registry.add_model("bad-transition")
        with pytest.raises(ValueError, match="Invalid transition"):
            self.registry.update_status(model.id, ModelStatus.serving)

    def test_registry_promote(self):
        """Promoting a model retires the previously serving model."""
        m1 = self.registry.add_model("model-a")
        self._advance(m1.id, ModelStatus.downloading, ModelStatus.validating,
                      ModelStatus.passed, ModelStatus.registered, ModelStatus.serving)

        m2 = self.registry.add_model("model-b")
        self._advance(m2.id, ModelStatus.downloading, ModelStatus.validating,
                      ModelStatus.passed, ModelStatus.registered)

        result = self.registry.promote(m2.id)
        assert result.status == ModelStatus.serving

        # The first model must have been retired in the same operation.
        m1_updated = self.registry.get_model(m1.id)
        assert m1_updated.status == ModelStatus.retired

        # Exactly one serving model remains.
        serving = self.registry.get_serving_model()
        assert serving.id == m2.id

    def test_list_models_with_filter(self):
        """Listing models honors the optional status filter."""
        self.registry.add_model("alpha")
        m2 = self.registry.add_model("beta")
        self.registry.update_status(m2.id, ModelStatus.downloading)

        discovered = self.registry.list_models(ModelStatus.discovered)
        assert len(discovered) == 1
        assert discovered[0].name == "alpha"

        downloading = self.registry.list_models(ModelStatus.downloading)
        assert len(downloading) == 1
        assert downloading[0].name == "beta"

        all_models = self.registry.list_models()
        assert len(all_models) == 2

    def test_update_benchmark(self):
        """Benchmark results are persisted onto the model record."""
        model = self.registry.add_model("bench-model")
        self.registry.update_benchmark(model.id, "toolcall-15", 12.0, 45.5, 120.3)

        updated = self.registry.get_model(model.id)
        assert updated.toolcall15_score == 12.0
        assert updated.tok_s_gen == 45.5
        assert updated.tok_s_prompt == 120.3


class TestHardwareNGL:
    """Test GPU layer computation heuristics."""

    def test_compute_optimal_ngl_standard(self):
        """NGL for a typical 24 GB GPU / 7 GB model / 80-layer network."""
        ngl = compute_optimal_ngl(7.0, 24.0, total_layers=80)
        # Mirrors the implementation's formula (VRAM minus 2 GB headroom,
        # scaled by layer count) so the assertion pins the contract.
        expected = int((24.0 - 2.0) / 7.0 * 80)
        assert ngl == min(expected, 999)
        assert ngl > 0

    def test_compute_optimal_ngl_small_vram(self):
        """Insufficient VRAM yields zero offloaded layers."""
        assert compute_optimal_ngl(7.0, 2.0) == 0

    def test_compute_optimal_ngl_zero_vram(self):
        """No VRAM (CPU only) yields zero offloaded layers."""
        assert compute_optimal_ngl(7.0, 0.0) == 0

    def test_compute_optimal_ngl_large_model(self):
        """A model larger than VRAM gets a partial offload."""
        ngl = compute_optimal_ngl(70.0, 8.0, total_layers=80)
        assert ngl > 0
        assert ngl < 80  # can't fit all layers

    def test_compute_optimal_ngl_cap_at_999(self):
        """NGL is capped at 999 (llama.cpp's 'everything' sentinel)."""
        assert compute_optimal_ngl(0.5, 48.0, total_layers=200) == 999

    def test_compute_optimal_ngl_zero_model_size(self):
        """A zero-sized model yields zero layers rather than dividing by zero."""
        assert compute_optimal_ngl(0.0, 24.0) == 0


class TestSignalScoring:
    """Test social signal scoring."""

    def test_score_signal_basic(self):
        """Likes x1.0 + bookmarks x2.0 + retweets x1.5."""
        tweet = {"likes": 100, "bookmarks": 50, "retweets": 30}
        assert score_signal(tweet) == 100 * 1.0 + 50 * 2.0 + 30 * 1.5

    def test_score_signal_alternative_keys(self):
        """The *_count key spellings are accepted too."""
        tweet = {"like_count": 200, "bookmark_count": 10, "retweet_count": 40}
        assert score_signal(tweet) == 200 * 1.0 + 10 * 2.0 + 40 * 1.5

    def test_score_signal_missing_fields(self):
        """Missing engagement fields default to zero."""
        assert score_signal({"likes": 50}) == 50.0

    def test_score_signal_empty(self):
        """An empty tweet scores zero."""
        assert score_signal({}) == 0.0


class TestCandidateFiltering:
    """Test candidate filtering by hardware constraints."""

    def _make_candidate(self, name: str, size_gb=None):
        """Build a minimal DiscoveryCandidate for filtering tests."""
        return DiscoveryCandidate(
            name=name, hf_url=f"https://hf.co/{name}",
            quant="Q4_K_M", size_gb_est=size_gb,
            signal_score=100.0, source_tweet_id="123",
            discovered_at="2025-01-01T00:00:00",
        )

    def test_filter_by_size(self):
        """Models above the size cap are dropped."""
        candidates = [
            self._make_candidate("small", 3.0),
            self._make_candidate("medium", 7.0),
            self._make_candidate("large", 40.0),
        ]
        filtered = filter_candidates(candidates, max_size_gb=10.0)
        assert len(filtered) == 2
        names = [c.name for c in filtered]
        assert "small" in names
        assert "medium" in names
        assert "large" not in names

    def test_filter_keeps_unknown_size(self):
        """Unknown-size candidates survive; oversized known ones are dropped."""
        candidates = [
            self._make_candidate("known", 5.0),    # 5 GB > 4 GB cap -> dropped
            self._make_candidate("unknown", None), # size unknown -> kept
        ]
        filtered = filter_candidates(candidates, max_size_gb=4.0)
        assert len(filtered) == 1
        names = [c.name for c in filtered]
        assert "unknown" in names

    def test_filter_all_pass(self):
        """Nothing is dropped when every candidate fits."""
        candidates = [self._make_candidate("a", 2.0), self._make_candidate("b", 3.0)]
        filtered = filter_candidates(candidates, max_size_gb=50.0)
        assert len(filtered) == 2


class TestSystemdUnitGeneration:
    """Test systemd unit file generation."""

    def test_generate_unit_content(self):
        """Generated unit carries all sections and the requested flags."""
        unit = generate_systemd_unit(
            model_path="/models/test.gguf",
            ngl=33, port=8080, threads=16, context_size=4096
        )
        assert "[Unit]" in unit
        assert "[Service]" in unit
        assert "[Install]" in unit
        assert "/models/test.gguf" in unit
        assert "-ngl 33" in unit
        assert "--port 8080" in unit
        assert "-t 16" in unit
        assert "-c 4096" in unit
        assert "Restart=on-failure" in unit
        assert "WantedBy=multi-user.target" in unit

    def test_generate_unit_binary_path(self):
        """Generated unit references the CUDA-build llama-server binary."""
        unit = generate_systemd_unit("/m.gguf", ngl=0, port=9090, threads=4, context_size=2048)
        assert "llama-server" in unit
        assert "llama-cpp-turboquant-cuda" in unit


class TestValidationResultParsing:
    """Test ToolCall-15 SSE output parsing."""

    def test_parse_sse_all_pass(self):
        """All 15 scenarios passing yields a perfect score."""
        sse = "\n".join([
            f'data: {{"type": "scenario_result", "scenario": "scenario_{i}", "passed": true}}'
            for i in range(15)
        ])
        result = parse_sse_results(sse)
        assert result["score"] == 15
        assert len(result["details"]) == 15
        assert all(result["details"].values())

    def test_parse_sse_mixed_results(self):
        """Score counts only the passing scenarios."""
        lines = []
        for i in range(15):
            passed = i < 10  # 10 pass, 5 fail
            lines.append(
                f'data: {{"type": "scenario_result", "scenario": "sc_{i}", "passed": {str(passed).lower()}}}'
            )
        result = parse_sse_results("\n".join(lines))
        assert result["score"] == 10
        assert sum(1 for v in result["details"].values() if v) == 10
        assert sum(1 for v in result["details"].values() if not v) == 5

    def test_parse_sse_with_noise(self):
        """Non-result events (status/progress) are ignored."""
        sse = """data: {"type": "status", "message": "starting"}
data: {"type": "scenario_result", "scenario": "weather", "passed": true}
data: {"type": "progress", "percent": 50}
data: {"type": "scenario_result", "scenario": "calendar", "passed": false}
data: {"type": "status", "message": "done"}
"""
        result = parse_sse_results(sse)
        assert result["score"] == 1
        assert result["details"]["weather"] is True
        assert result["details"]["calendar"] is False

    def test_parse_sse_empty(self):
        """Empty SSE output parses to a zero score."""
        result = parse_sse_results("")
        assert result["score"] == 0
        assert result["details"] == {}

    def test_parse_sse_malformed_json(self):
        """Malformed JSON lines are skipped, not fatal."""
        sse = """data: {"type": "scenario_result", "scenario": "ok", "passed": true}
data: {invalid json here
data: {"type": "scenario_result", "scenario": "also_ok", "passed": true}
"""
        result = parse_sse_results(sse)
        assert result["score"] == 2


class TestParseCandidates:
    """Test tweet parsing for model candidates."""

    def test_parse_hf_url(self):
        """HuggingFace URLs and quant names are extracted from tweet text."""
        tweets = json.dumps([{
            "id": "1",
            "text": "New model released! Check out https://huggingface.co/TheBloke/Llama-2-7B-GGUF Q4_K_M quantization",
            "likes": 50, "bookmarks": 10, "retweets": 5,
        }])
        candidates = parse_candidates(tweets)
        assert len(candidates) == 1
        assert candidates[0].hf_url == "https://huggingface.co/TheBloke/Llama-2-7B-GGUF"
        assert candidates[0].quant == "Q4_K_M"

    def test_parse_deduplicates(self):
        """Tweets naming the same HF repo collapse to one candidate."""
        tweets = json.dumps([
            {"id": "1", "text": "https://huggingface.co/user/model-GGUF great!", "likes": 10},
            {"id": "2", "text": "https://huggingface.co/user/model-GGUF amazing!", "likes": 20},
        ])
        candidates = parse_candidates(tweets)
        assert len(candidates) == 1


if __name__ == "__main__":
    pytest.main([__file__, "-v"])