diff --git a/cmd/codebase-memory-mcp/cli_test.go b/cmd/codebase-memory-mcp/cli_test.go index 61be593..2993e1d 100644 --- a/cmd/codebase-memory-mcp/cli_test.go +++ b/cmd/codebase-memory-mcp/cli_test.go @@ -51,8 +51,17 @@ func testCmd(t *testing.T, args ...string) *exec.Cmd { } // testEnvWithHome returns env vars with HOME (and USERPROFILE on Windows) set. +// CLAUDE_CONFIG_DIR is stripped so that skills resolve under HOME/.claude. func testEnvWithHome(home string, extra ...string) []string { - env := append(os.Environ(), "HOME="+home) + base := os.Environ() + env := make([]string, 0, len(base)+2) + for _, e := range base { + if strings.HasPrefix(e, "CLAUDE_CONFIG_DIR=") { + continue // strip: tests expect paths under HOME/.claude + } + env = append(env, e) + } + env = append(env, "HOME="+home) if runtime.GOOS == "windows" { env = append(env, "USERPROFILE="+home) // On Windows, DLL lookup uses PATH. Tests that replace PATH with an @@ -259,6 +268,153 @@ func TestCLI_InstallForceOverwrites(t *testing.T) { } } +func TestCLI_InstallProject(t *testing.T) { + home := t.TempDir() + projDir := t.TempDir() + + cmd := testCmd(t, "install", "--project", projDir) + cmd.Env = testEnvWithHome(home, "PATH="+t.TempDir(), "SHELL=/bin/zsh") + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("install --project failed: %v\n%s", err, out) + } + output := string(out) + + // Should write project-local .mcp.json + mcpJSON := filepath.Join(projDir, ".mcp.json") + data, err := os.ReadFile(mcpJSON) + if err != nil { + t.Fatalf("expected .mcp.json at %s: %v", mcpJSON, err) + } + if !strings.Contains(string(data), "codebase-memory-mcp") { + t.Fatalf("expected codebase-memory-mcp in .mcp.json, got: %s", data) + } + + // Should still install skills globally + skillFile := filepath.Join(home, ".claude", "skills", "codebase-memory-exploring", "SKILL.md") + if _, err := os.Stat(skillFile); err != nil { + t.Fatal("skills should be installed globally even with --project") + } + + // Should NOT contain editor registration output (Cursor, VS Code, etc.) + if strings.Contains(output, "[Cursor]") || strings.Contains(output, "[VS Code") { + t.Fatal("--project should skip global editor registrations") + } +} + +func TestCLI_InstallProjectDryRun(t *testing.T) { + home := t.TempDir() + projDir := t.TempDir() + + cmd := testCmd(t, "install", "--project", projDir, "--dry-run") + cmd.Env = testEnvWithHome(home, "PATH="+t.TempDir()) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("install --project --dry-run failed: %v\n%s", err, out) + } + output := string(out) + if !strings.Contains(output, "dry-run") { + t.Fatal("expected dry-run in output") + } + + // Should NOT create .mcp.json + mcpJSON := filepath.Join(projDir, ".mcp.json") + if _, err := os.Stat(mcpJSON); !os.IsNotExist(err) { + t.Fatal("dry-run should not create .mcp.json") + } +} + +func TestCLI_UninstallProject(t *testing.T) { + home := t.TempDir() + projDir := t.TempDir() + + // First install + cmd := testCmd(t, "install", "--project", projDir) + cmd.Env = testEnvWithHome(home, "PATH="+t.TempDir(), "SHELL=/bin/zsh") + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("install --project failed: %v\n%s", err, out) + } + + // Verify .mcp.json exists + mcpJSON := filepath.Join(projDir, ".mcp.json") + if _, err := os.Stat(mcpJSON); err != nil { + t.Fatal("expected .mcp.json after install") + } + + // Uninstall + cmd = testCmd(t, "uninstall", "--project", projDir) + cmd.Env = testEnvWithHome(home, "PATH="+t.TempDir()) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("uninstall --project failed: %v\n%s", err, out) + } + + // .mcp.json should be removed (it only had our entry) + if _, err := os.Stat(mcpJSON); !os.IsNotExist(err) { + t.Fatal(".mcp.json should be removed after uninstall --project (no other servers)") + } +} + +func TestCLI_UninstallProjectPreservesOtherServers(t *testing.T) { + projDir := t.TempDir() + home := t.TempDir() + + // Write .mcp.json with our entry + another server + mcpJSON := filepath.Join(projDir, ".mcp.json") + initial := `{ + "mcpServers": { + "codebase-memory-mcp": {"command": "/usr/bin/cmm"}, + "other-server": {"command": "/usr/bin/other"} + } +}` + if err := os.WriteFile(mcpJSON, []byte(initial), 0o600); err != nil { + t.Fatal(err) + } + + cmd := testCmd(t, "uninstall", "--project", projDir) + cmd.Env = testEnvWithHome(home, "PATH="+t.TempDir()) + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("uninstall --project failed: %v\n%s", err, out) + } + + // .mcp.json should still exist with the other server + data, err := os.ReadFile(mcpJSON) + if err != nil { + t.Fatal("expected .mcp.json to still exist") + } + if strings.Contains(string(data), "codebase-memory-mcp") { + t.Fatal("our entry should be removed") + } + if !strings.Contains(string(data), "other-server") { + t.Fatal("other server entry should be preserved") + } +} + +func TestCLI_InstallCLAUDE_CONFIG_DIR(t *testing.T) { + home := t.TempDir() + customClaudeDir := filepath.Join(t.TempDir(), "custom-claude") + + cmd := testCmd(t, "install") + cmd.Env = testEnvWithHome(home, "PATH="+t.TempDir(), "SHELL=/bin/zsh", + "CLAUDE_CONFIG_DIR="+customClaudeDir) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("install with CLAUDE_CONFIG_DIR failed: %v\n%s", err, out) + } + + // Skills should be under the custom dir, not ~/.claude + skillFile := filepath.Join(customClaudeDir, "skills", "codebase-memory-exploring", "SKILL.md") + if _, err := os.Stat(skillFile); err != nil { + t.Fatalf("skills should be under CLAUDE_CONFIG_DIR (%s): %v", customClaudeDir, err) + } + + // Should NOT be under default ~/.claude + defaultSkill := filepath.Join(home, ".claude", "skills", "codebase-memory-exploring", "SKILL.md") + if _, err := os.Stat(defaultSkill); !os.IsNotExist(err) { + t.Fatal("skills should NOT be under default ~/.claude when CLAUDE_CONFIG_DIR is set") + } +} + func TestCLI_InstallPATHAppend(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("shell RC PATH append is Unix-specific") diff --git a/cmd/codebase-memory-mcp/install.go b/cmd/codebase-memory-mcp/install.go index 4379e0c..1392e78 100644 --- a/cmd/codebase-memory-mcp/install.go +++ b/cmd/codebase-memory-mcp/install.go @@ -14,21 +14,59 @@ import ( // installConfig holds settings for the install/uninstall commands. type installConfig struct { - dryRun bool - force bool + dryRun bool + force bool + project string // absolute path for project-local install; empty = global +} + +// claudeConfigDir returns the Claude Code configuration directory. +// Respects the CLAUDE_CONFIG_DIR environment variable; defaults to ~/.claude. +func claudeConfigDir() string { + if envDir := os.Getenv("CLAUDE_CONFIG_DIR"); envDir != "" { + return envDir + } + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return filepath.Join(home, ".claude") } func runInstall(args []string) int { cfg := installConfig{} - for _, a := range args { - switch a { + for i := 0; i < len(args); i++ { + switch args[i] { case "--dry-run": cfg.dryRun = true case "--force": cfg.force = true + case "--project": + // Optional path argument — use next arg if it doesn't start with -- + if i+1 < len(args) && !strings.HasPrefix(args[i+1], "--") { + cfg.project = args[i+1] + i++ + } else { + // No path given — use cwd + cwd, err := os.Getwd() + if err != nil { + fmt.Fprintf(os.Stderr, "error: cannot determine cwd: %v\n", err) + return 1 + } + cfg.project = cwd + } } } + // Resolve relative project path to absolute + if cfg.project != "" && !filepath.IsAbs(cfg.project) { + abs, err := filepath.Abs(cfg.project) + if err != nil { + fmt.Fprintf(os.Stderr, "error: cannot resolve project path: %v\n", err) + return 1 + } + cfg.project = abs + } + binaryPath, err := detectBinaryPath() if err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) @@ -44,40 +82,47 @@ func runInstall(args []string) int { // Skills (always installed — no CLI dependency) installSkills(cfg) - // Claude Code MCP registration - if claudePath := findCLI("claude"); claudePath != "" { - fmt.Printf("[Claude Code] detected (%s)\n", claudePath) - registerClaudeCodeMCP(binaryPath, claudePath, cfg) - } else { - fmt.Println("[Claude Code] not found — skipping MCP registration") + // Project-local .mcp.json (only when --project is specified) + if cfg.project != "" { + writeProjectMCPJSON(cfg.project, binaryPath, cfg) } - fmt.Println() + if cfg.project == "" { + // Claude Code MCP registration + if claudePath := findCLI("claude"); claudePath != "" { + fmt.Printf("[Claude Code] detected (%s)\n", claudePath) + registerClaudeCodeMCP(binaryPath, claudePath, cfg) + } else { + fmt.Println("[Claude Code] not found — skipping MCP registration") + } - // Codex CLI - if codexPath := findCLI("codex"); codexPath != "" { - fmt.Printf("[Codex CLI] detected (%s)\n", codexPath) - installCodex(binaryPath, codexPath, cfg) - } else { - fmt.Println("[Codex CLI] not found — skipping") - } + fmt.Println() - fmt.Println() + // Codex CLI + if codexPath := findCLI("codex"); codexPath != "" { + fmt.Printf("[Codex CLI] detected (%s)\n", codexPath) + installCodex(binaryPath, codexPath, cfg) + } else { + fmt.Println("[Codex CLI] not found — skipping") + } - // Cursor - installEditorMCP(binaryPath, cursorConfigPath(), "Cursor", cfg) + fmt.Println() - // Windsurf - installEditorMCP(binaryPath, windsurfConfigPath(), "Windsurf", cfg) + // Cursor + installEditorMCP(binaryPath, cursorConfigPath(), "Cursor", cfg) - // Gemini CLI (same mcpServers format as Cursor/Windsurf) - installEditorMCP(binaryPath, geminiConfigPath(), "Gemini CLI", cfg) + // Windsurf + installEditorMCP(binaryPath, windsurfConfigPath(), "Windsurf", cfg) - // VS Code Copilot (uses "servers" key with "type" field) - installVSCodeMCP(binaryPath, vscodeConfigPath(), cfg) + // Gemini CLI (same mcpServers format as Cursor/Windsurf) + installEditorMCP(binaryPath, geminiConfigPath(), "Gemini CLI", cfg) - // Zed (uses "context_servers" key with "source" field) - installZedMCP(binaryPath, zedConfigPath(), cfg) + // VS Code Copilot (uses "servers" key with "type" field) + installVSCodeMCP(binaryPath, vscodeConfigPath(), cfg) + + // Zed (uses "context_servers" key with "source" field) + installZedMCP(binaryPath, zedConfigPath(), cfg) + } fmt.Println("\nDone. Restart your editor/CLI to activate.") return 0 @@ -85,14 +130,44 @@ func runInstall(args []string) int { func runUninstall(args []string) int { cfg := installConfig{} - for _, a := range args { - if a == "--dry-run" { + for i := 0; i < len(args); i++ { + switch args[i] { + case "--dry-run": cfg.dryRun = true + case "--project": + if i+1 < len(args) && !strings.HasPrefix(args[i+1], "--") { + cfg.project = args[i+1] + i++ + } else { + cwd, err := os.Getwd() + if err != nil { + fmt.Fprintf(os.Stderr, "error: cannot determine cwd: %v\n", err) + return 1 + } + cfg.project = cwd + } + } + } + + // Resolve relative project path to absolute + if cfg.project != "" && !filepath.IsAbs(cfg.project) { + abs, err := filepath.Abs(cfg.project) + if err != nil { + fmt.Fprintf(os.Stderr, "error: cannot resolve project path: %v\n", err) + return 1 } + cfg.project = abs } fmt.Printf("\ncodebase-memory-mcp %s — uninstall\n\n", version) + // Project-local: only remove the .mcp.json entry + if cfg.project != "" { + removeProjectMCPJSON(cfg.project, cfg) + fmt.Println("\nDone.") + return 0 + } + // Remove Claude Code skills removeClaudeSkills(cfg) @@ -219,18 +294,18 @@ func detectShellRC() string { } } -// installSkills writes the 4 skill files to ~/.claude/skills/ and removes old monolithic skill. +// installSkills writes the 4 skill files to the Claude config skills dir and removes old monolithic skill. func installSkills(cfg installConfig) { - home, err := os.UserHomeDir() - if err != nil { - fmt.Printf(" ⚠ Cannot determine home directory: %v\n", err) + configDir := claudeConfigDir() + if configDir == "" { + fmt.Printf(" ⚠ Cannot determine Claude config directory\n") return } fmt.Println("[Skills]") // Remove old monolithic skill if it exists - oldSkillDir := filepath.Join(home, ".claude", "skills", "codebase-memory-mcp") + oldSkillDir := filepath.Join(configDir, "skills", "codebase-memory-mcp") if info, err := os.Stat(oldSkillDir); err == nil && info.IsDir() { if cfg.dryRun { fmt.Printf(" [dry-run] Would remove old skill: %s\n", oldSkillDir) @@ -243,7 +318,7 @@ func installSkills(cfg installConfig) { // Write 4 skill files for name, content := range skillFiles { - skillDir := filepath.Join(home, ".claude", "skills", name) + skillDir := filepath.Join(configDir, "skills", name) skillFile := filepath.Join(skillDir, "SKILL.md") if !cfg.force { @@ -359,14 +434,14 @@ func upsertCodexMCP(configFile, mcpSection, binaryPath string) error { // removeClaudeSkills removes all 4 skill directories. func removeClaudeSkills(cfg installConfig) { - home, err := os.UserHomeDir() - if err != nil { + configDir := claudeConfigDir() + if configDir == "" { return } fmt.Println("[Skills]") for name := range skillFiles { - skillDir := filepath.Join(home, ".claude", "skills", name) + skillDir := filepath.Join(configDir, "skills", name) if _, err := os.Stat(skillDir); os.IsNotExist(err) { continue } @@ -382,6 +457,105 @@ func removeClaudeSkills(cfg installConfig) { } } +// writeProjectMCPJSON writes a project-local .mcp.json for Claude Code session-level registration. +func writeProjectMCPJSON(projectDir, binaryPath string, cfg installConfig) { + configPath := filepath.Join(projectDir, ".mcp.json") + fmt.Printf("[Claude Code] Project-local MCP config: %s\n", configPath) + + if cfg.dryRun { + fmt.Printf(" [dry-run] Would write project-local .mcp.json at %s\n", configPath) + return + } + + // Read existing or start fresh + root := make(map[string]any) + if data, err := os.ReadFile(configPath); err == nil { + _ = json.Unmarshal(data, &root) + } + + servers, ok := root["mcpServers"].(map[string]any) + if !ok { + servers = make(map[string]any) + } + servers[mcpServerKey] = map[string]any{ + "command": binaryPath, + "args": []string{}, + } + root["mcpServers"] = servers + + out, err := json.MarshalIndent(root, "", " ") + if err != nil { + fmt.Printf(" ⚠ marshal JSON: %v\n", err) + return + } + if err := os.WriteFile(configPath, append(out, '\n'), 0o600); err != nil { + fmt.Printf(" ⚠ write %s: %v\n", configPath, err) + return + } + fmt.Printf(" ✓ Project-local .mcp.json written\n") +} + +// removeProjectMCPJSON removes the codebase-memory-mcp entry from a project-local .mcp.json. +func removeProjectMCPJSON(projectDir string, cfg installConfig) { + configPath := filepath.Join(projectDir, ".mcp.json") + fmt.Printf("[Claude Code] Remove project-local MCP entry: %s\n", configPath) + + if cfg.dryRun { + fmt.Printf(" [dry-run] Would remove %s entry from %s\n", mcpServerKey, configPath) + return + } + + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + fmt.Printf(" ○ No .mcp.json found at %s\n", configPath) + } else { + fmt.Printf(" ⚠ read %s: %v\n", configPath, err) + } + return + } + + var root map[string]any + if err := json.Unmarshal(data, &root); err != nil { + fmt.Printf(" ⚠ invalid JSON in %s: %v\n", configPath, err) + return + } + + servers, ok := root["mcpServers"].(map[string]any) + if !ok { + fmt.Printf(" ○ No mcpServers section in %s\n", configPath) + return + } + if _, exists := servers[mcpServerKey]; !exists { + fmt.Printf(" ○ %s not registered in %s\n", mcpServerKey, configPath) + return + } + + delete(servers, mcpServerKey) + + // If mcpServers is now empty, remove the file entirely + if len(servers) == 0 { + if err := os.Remove(configPath); err != nil { + fmt.Printf(" ⚠ remove %s: %v\n", configPath, err) + } else { + fmt.Printf(" ✓ Removed %s (no other servers)\n", configPath) + } + return + } + + root["mcpServers"] = servers + out, err := json.MarshalIndent(root, "", " ") + if err != nil { + fmt.Printf(" ⚠ marshal JSON: %v\n", err) + return + } + if err := os.WriteFile(configPath, append(out, '\n'), 0o600); err != nil { + fmt.Printf(" ⚠ write %s: %v\n", configPath, err) + return + } + fmt.Printf(" ✓ Removed %s entry from %s\n", mcpServerKey, configPath) +} + // deregisterMCP removes the MCP server registration from a CLI. func deregisterMCP(cliPath, cliName string, cfg installConfig) { if cfg.dryRun { diff --git a/internal/metrics/savings.go b/internal/metrics/savings.go new file mode 100644 index 0000000..01a5891 --- /dev/null +++ b/internal/metrics/savings.go @@ -0,0 +1,164 @@ +package metrics + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "log/slog" + "math" + "os" + "path/filepath" + "sync" + "time" +) + +// TokenMetadata holds per-call token savings estimation. +// Attached as _meta in tool responses when metrics are enabled. +type TokenMetadata struct { + TokensSaved int `json:"tokens_saved"` + BaselineTokens int `json:"baseline_tokens"` + ResponseTokens int `json:"response_tokens"` + CostAvoided float64 `json:"cost_avoided"` + ReductionRatio float64 `json:"reduction_ratio"` +} + +// EstimateTokens approximates token count from byte length using +// the heuristic of 1 token ≈ 4 bytes (accurate for ASCII-heavy source code). +func EstimateTokens(s string) int { + return len(s) / 4 +} + +// CalculateSavings computes token savings for a single tool call. +// baselineBytes: byte count of all files the user would have read manually. +// responseBytes: byte count of the actual tool response. +// pricePerToken: USD cost per output token (e.g. 0.000015 for Claude Sonnet). +func CalculateSavings(baselineBytes, responseBytes int, pricePerToken float64) TokenMetadata { + baseline := baselineBytes / 4 + response := responseBytes / 4 + saved := baseline - response + if saved < 0 { + saved = 0 + } + ratio := 0.0 + if baseline > 0 { + ratio = math.Round(float64(response)/float64(baseline)*1000) / 1000 + } + return TokenMetadata{ + TokensSaved: saved, + BaselineTokens: baseline, + ResponseTokens: response, + CostAvoided: math.Round(float64(saved)*pricePerToken*1e6) / 1e6, + ReductionRatio: ratio, + } +} + +// savingsRecord is the on-disk format for cumulative savings. +type savingsRecord struct { + InstallID string `json:"install_id"` + TotalTokensSaved int64 `json:"total_tokens_saved"` + TotalCostAvoided float64 `json:"total_cost_avoided"` + LastUpdated string `json:"last_updated"` +} + +// Tracker accumulates token savings across calls and persists totals. +// Use Snapshot() to read totals safely — do not access fields directly. +type Tracker struct { + mu sync.Mutex + path string + installID string + totalTokensSaved int64 + totalCostAvoided float64 +} + +// NewTracker loads or creates savings.json at path. Generates a random +// InstallID on first run. Never returns an error (fail-open). +func NewTracker(path string) *Tracker { + t := &Tracker{path: path} + + data, err := os.ReadFile(path) + if err != nil { + if !os.IsNotExist(err) { + slog.Warn("metrics: failed to read savings file", "path", path, "err", err) + } + t.installID = newInstallID() + return t + } + + var rec savingsRecord + if err := json.Unmarshal(data, &rec); err != nil { + slog.Warn("metrics: malformed savings file, starting fresh", "path", path, "err", err) + t.installID = newInstallID() + return t + } + + t.installID = rec.InstallID + t.totalTokensSaved = rec.TotalTokensSaved + t.totalCostAvoided = rec.TotalCostAvoided + return t +} + +// Record atomically increments TotalTokensSaved and TotalCostAvoided +// by the values in meta, then persists to file. +func (t *Tracker) Record(meta TokenMetadata) { + t.mu.Lock() + defer t.mu.Unlock() + + t.totalTokensSaved += int64(meta.TokensSaved) + t.totalCostAvoided += meta.CostAvoided + + rec := savingsRecord{ + InstallID: t.installID, + TotalTokensSaved: t.totalTokensSaved, + TotalCostAvoided: t.totalCostAvoided, + LastUpdated: time.Now().UTC().Format(time.RFC3339), + } + + data, err := json.MarshalIndent(rec, "", " ") + if err != nil { + slog.Warn("metrics: failed to marshal savings", "err", err) + return + } + + dir := filepath.Dir(t.path) + tmp, err := os.CreateTemp(dir, "savings-*.json.tmp") + if err != nil { + slog.Warn("metrics: failed to create temp file", "dir", dir, "err", err) + return + } + tmpName := tmp.Name() + + if _, err := tmp.Write(data); err != nil { + tmp.Close() + os.Remove(tmpName) + slog.Warn("metrics: failed to write temp file", "err", err) + return + } + if err := tmp.Close(); err != nil { + os.Remove(tmpName) + slog.Warn("metrics: failed to close temp file", "err", err) + return + } + + if err := os.Rename(tmpName, t.path); err != nil { + os.Remove(tmpName) + slog.Warn("metrics: failed to rename temp file", "src", tmpName, "dst", t.path, "err", err) + return + } + + slog.Debug("metrics: savings persisted", "path", t.path, "total_tokens_saved", t.totalTokensSaved) +} + +// Snapshot returns current cumulative totals under lock. +func (t *Tracker) Snapshot() (totalTokensSaved int64, totalCostAvoided float64) { + t.mu.Lock() + defer t.mu.Unlock() + return t.totalTokensSaved, t.totalCostAvoided +} + +func newInstallID() string { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return "unknown" + } + return hex.EncodeToString(b) +} diff --git a/internal/metrics/savings_test.go b/internal/metrics/savings_test.go new file mode 100644 index 0000000..0896314 --- /dev/null +++ b/internal/metrics/savings_test.go @@ -0,0 +1,166 @@ +package metrics + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestEstimateTokens(t *testing.T) { + if got := EstimateTokens("hello"); got != 1 { + t.Fatalf("EstimateTokens(\"hello\") = %d, want 1", got) + } + if got := EstimateTokens(""); got != 0 { + t.Fatalf("EstimateTokens(\"\") = %d, want 0", got) + } +} + +func TestCalculateSavings(t *testing.T) { + tests := []struct { + name string + baselineBytes int + responseBytes int + pricePerToken float64 + wantTokensSaved int + wantReductionRatio float64 + }{ + { + name: "NormalCase", + baselineBytes: 4000, + responseBytes: 400, + pricePerToken: 0.000015, + wantTokensSaved: 900, + wantReductionRatio: 0.1, + }, + { + name: "NoSavings", + baselineBytes: 100, + responseBytes: 200, + pricePerToken: 0.000015, + wantTokensSaved: 0, + wantReductionRatio: 2.0, + }, + { + name: "ZeroBaseline", + baselineBytes: 0, + responseBytes: 400, + pricePerToken: 0.000015, + wantTokensSaved: 0, + wantReductionRatio: 0.0, + }, + { + name: "ZeroBoth", + baselineBytes: 0, + responseBytes: 0, + pricePerToken: 0.000015, + wantTokensSaved: 0, + wantReductionRatio: 0.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CalculateSavings(tt.baselineBytes, tt.responseBytes, tt.pricePerToken) + if got.TokensSaved != tt.wantTokensSaved { + t.Errorf("TokensSaved = %d, want %d", got.TokensSaved, tt.wantTokensSaved) + } + if got.ReductionRatio != tt.wantReductionRatio { + t.Errorf("ReductionRatio = %v, want %v", got.ReductionRatio, tt.wantReductionRatio) + } + }) + } +} + +func TestCalculateSavings_NormalCaseFields(t *testing.T) { + got := CalculateSavings(4000, 400, 0.000015) + if got.BaselineTokens != 1000 { + t.Errorf("BaselineTokens = %d, want 1000", got.BaselineTokens) + } + if got.ResponseTokens != 100 { + t.Errorf("ResponseTokens = %d, want 100", got.ResponseTokens) + } + if got.CostAvoided != 0.0135 { + t.Errorf("CostAvoided = %v, want 0.0135", got.CostAvoided) + } +} + +func TestTracker_RecordAndPersist(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "savings.json") + + tr := NewTracker(path) + meta := TokenMetadata{TokensSaved: 500, CostAvoided: 0.0075} + tr.Record(meta) + tr.Record(meta) + + totalTokens, totalCost := tr.Snapshot() + if totalTokens != 1000 { + t.Errorf("TotalTokensSaved = %d, want 1000", totalTokens) + } + if totalCost != 0.015 { + t.Errorf("TotalCostAvoided = %v, want 0.015", totalCost) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read savings file: %v", err) + } + var rec savingsRecord + if err := json.Unmarshal(data, &rec); err != nil { + t.Fatalf("savings file is not valid JSON: %v", err) + } + if rec.TotalTokensSaved != 1000 { + t.Errorf("file TotalTokensSaved = %d, want 1000", rec.TotalTokensSaved) + } + if rec.InstallID == "" { + t.Error("file InstallID should not be empty") + } + if rec.LastUpdated == "" { + t.Error("file LastUpdated should not be empty") + } +} + +func TestTracker_LoadExisting(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "savings.json") + + existing := savingsRecord{ + InstallID: "abc123", + TotalTokensSaved: 5000, + TotalCostAvoided: 0.075, + LastUpdated: "2026-01-01T00:00:00Z", + } + data, _ := json.Marshal(existing) + if err := os.WriteFile(path, data, 0o600); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + tr := NewTracker(path) + if tr.installID != "abc123" { + t.Errorf("InstallID = %q, want %q", tr.installID, "abc123") + } + + tr.Record(TokenMetadata{TokensSaved: 200}) + totalTokens, _ := tr.Snapshot() + if totalTokens != 5200 { + t.Errorf("TotalTokensSaved = %d, want 5200", totalTokens) + } +} + +func TestTracker_MissingDir(t *testing.T) { + path := filepath.Join(t.TempDir(), "nonexistent-subdir", "savings.json") + + tr := NewTracker(path) + totalTokens, totalCost := tr.Snapshot() + if totalTokens != 0 || totalCost != 0 { + t.Errorf("expected zero snapshot for missing dir tracker, got (%d, %v)", totalTokens, totalCost) + } + + // Record should not panic even if dir doesn't exist + tr.Record(TokenMetadata{TokensSaved: 100}) + totalTokens, _ = tr.Snapshot() + if totalTokens != 100 { + t.Errorf("TotalTokensSaved after Record = %d, want 100", totalTokens) + } +} diff --git a/internal/store/config.go b/internal/store/config.go index 76523ae..9bff771 100644 --- a/internal/store/config.go +++ b/internal/store/config.go @@ -89,6 +89,19 @@ func (c *ConfigStore) GetInt(key string, defaultVal int) int { return n } +// GetFloat64 returns a float64 config value. +func (c *ConfigStore) GetFloat64(key string, defaultVal float64) float64 { + raw := c.Get(key, "") + if raw == "" { + return defaultVal + } + f, err := strconv.ParseFloat(raw, 64) + if err != nil { + return defaultVal + } + return f +} + // Set stores a key-value pair (upsert). func (c *ConfigStore) Set(key, value string) error { _, err := c.db.ExecContext(context.Background(), @@ -143,4 +156,24 @@ const ( // Accepts human-readable sizes: "2G", "512M", "4096M". // Default: empty (no limit). Applied on server startup. ConfigMemLimit = "mem_limit" + + // ConfigMetricsEnabled controls whether token savings estimation is computed + // and included in tool responses. Default: true. + // Set to false with: codebase-memory-mcp config set metrics_enabled false + ConfigMetricsEnabled = "metrics_enabled" + + // ConfigPricingModel selects the token pricing model for cost estimation. + // Supported values: "claude-sonnet" (default), "claude-opus", "gpt-4o", "custom". + // When "custom", ConfigCustomPricePerToken is used. + ConfigPricingModel = "pricing_model" + + // ConfigCustomPricePerToken is the USD cost per output token used when + // ConfigPricingModel is "custom". Example: 0.000015 for $15/M tokens. + ConfigCustomPricePerToken = "custom_price_per_token" + + // ConfigMetricsPath overrides the default savings.json path for per-project installs. + // When set, token savings accumulate in this file instead of the global cache. + // Path may be absolute or relative (resolved from server working directory). + // Default: empty (uses ~/.cache/codebase-memory-mcp/savings.json). + ConfigMetricsPath = "metrics_path" ) diff --git a/internal/store/config_test.go b/internal/store/config_test.go new file mode 100644 index 0000000..6cb7ad3 --- /dev/null +++ b/internal/store/config_test.go @@ -0,0 +1,34 @@ +package store + +import ( + "testing" +) + +func TestGetFloat64(t *testing.T) { + cfg, err := OpenConfigInDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + defer cfg.Close() + + // Default when key is missing. + if got := cfg.GetFloat64("missing_key", 3.14); got != 3.14 { + t.Errorf("GetFloat64 missing key = %v, want 3.14", got) + } + + // Valid float value. + if err := cfg.Set("price", "0.000015"); err != nil { + t.Fatal(err) + } + if got := cfg.GetFloat64("price", 0); got != 0.000015 { + t.Errorf("GetFloat64 valid = %v, want 0.000015", got) + } + + // Invalid value falls back to default. + if err := cfg.Set("bad", "not-a-number"); err != nil { + t.Fatal(err) + } + if got := cfg.GetFloat64("bad", 99.9); got != 99.9 { + t.Errorf("GetFloat64 invalid = %v, want 99.9", got) + } +} diff --git a/internal/store/store.go b/internal/store/store.go index 3ac0244..b46f2e9 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -51,7 +51,15 @@ type Edge struct { } // cacheDir returns the default cache directory for databases. +// Respects CODEBASE_MEMORY_DB_DIR environment variable; defaults to ~/.cache/codebase-memory-mcp. func cacheDir() (string, error) { + // Allow override for per-project installations + if envDir := os.Getenv("CODEBASE_MEMORY_DB_DIR"); envDir != "" { + if err := os.MkdirAll(envDir, 0o750); err != nil { + return "", fmt.Errorf("mkdir cache (from env): %w", err) + } + return envDir, nil + } home, err := os.UserHomeDir() if err != nil { return "", fmt.Errorf("home dir: %w", err) diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 865448a..a4e4689 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -3,9 +3,52 @@ package store import ( "context" "fmt" + "os" + "path/filepath" "testing" ) +func TestCacheDir_EnvOverride(t *testing.T) { + customDir := filepath.Join(t.TempDir(), "custom-db-dir") + + // Set env, call cacheDir, restore + old := os.Getenv("CODEBASE_MEMORY_DB_DIR") + t.Setenv("CODEBASE_MEMORY_DB_DIR", customDir) + defer os.Setenv("CODEBASE_MEMORY_DB_DIR", old) + + dir, err := cacheDir() + if err != nil { + t.Fatalf("cacheDir() with env override: %v", err) + } + if dir != customDir { + t.Errorf("cacheDir() = %q, want %q", dir, customDir) + } + + // Directory should have been created + info, err := os.Stat(customDir) + if err != nil { + t.Fatalf("expected dir to exist: %v", err) + } + if !info.IsDir() { + t.Fatal("expected a directory") + } +} + +func TestCacheDir_Default(t *testing.T) { + // Unset env to test default path + t.Setenv("CODEBASE_MEMORY_DB_DIR", "") + + dir, err := cacheDir() + if err != nil { + t.Fatalf("cacheDir() default: %v", err) + } + home, _ := os.UserHomeDir() + expected := filepath.Join(home, ".cache", "codebase-memory-mcp") + if dir != expected { + t.Errorf("cacheDir() = %q, want %q", dir, expected) + } +} + func TestOpenMemory(t *testing.T) { s, err := OpenMemory() if err != nil { diff --git a/internal/tools/search.go b/internal/tools/search.go index 980a4f4..6e44c3e 100644 --- a/internal/tools/search.go +++ b/internal/tools/search.go @@ -2,8 +2,12 @@ package tools import ( "context" + "encoding/json" "fmt" + "os" + "path/filepath" + "github.com/DeusData/codebase-memory-mcp/internal/metrics" "github.com/DeusData/codebase-memory-mcp/internal/store" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -100,7 +104,45 @@ func (s *Server) handleSearchGraph(_ context.Context, req *mcp.CallToolRequest) } s.addIndexStatus(responseData) - result := jsonResult(responseData) + result := s.searchResultWithMeta(responseData, output.Results, st, projName) s.addUpdateNotice(result) return result, nil } + +// searchResultWithMeta computes token savings for search_graph and returns a wrapped result. +// Baseline = sum of unique source file sizes referenced in results. +// Falls back to jsonResult when metrics are disabled or project root is unavailable. +func (s *Server) searchResultWithMeta(responseData map[string]any, results []*store.SearchResult, st *store.Store, projName string) *mcp.CallToolResult { + if s.config != nil && !s.config.GetBool(store.ConfigMetricsEnabled, true) { + return jsonResult(responseData) + } + proj, _ := st.GetProject(projName) + if proj == nil { + return jsonResult(responseData) + } + baselineBytes := uniqueFileBytes(results, proj.RootPath) + price := priceForConfig(s.config) + responseJSON, _ := json.Marshal(responseData) + meta := metrics.CalculateSavings(baselineBytes, len(responseJSON), price) + return resultWithMeta(responseData, meta, s.metricsTracker) +} + +// uniqueFileBytes sums the sizes of unique source files referenced in search results. +func uniqueFileBytes(results []*store.SearchResult, rootPath string) int { + seen := make(map[string]struct{}, len(results)) + total := 0 + for _, r := range results { + if r.Node.FilePath == "" { + continue + } + if _, ok := seen[r.Node.FilePath]; ok { + continue + } + seen[r.Node.FilePath] = struct{}{} + absPath := filepath.Join(rootPath, r.Node.FilePath) + if fi, err := os.Stat(absPath); err == nil { + total += int(fi.Size()) + } + } + return total +} diff --git a/internal/tools/search_test.go b/internal/tools/search_test.go new file mode 100644 index 0000000..489dea8 --- /dev/null +++ b/internal/tools/search_test.go @@ -0,0 +1,83 @@ +package tools + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + + "github.com/DeusData/codebase-memory-mcp/internal/metrics" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func callSearchGraph(t *testing.T, srv *Server, namePattern string) map[string]any { + t.Helper() + args := map[string]any{"name_pattern": namePattern} + rawArgs, _ := json.Marshal(args) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "search_graph", + Arguments: rawArgs, + }, + } + + result, err := srv.handleSearchGraph(context.TODO(), req) + if err != nil { + t.Fatalf("handleSearchGraph error: %v", err) + } + if len(result.Content) == 0 { + t.Fatal("empty result content") + } + tc, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatalf("expected TextContent, got %T", result.Content[0]) + } + var data map[string]any + if err := json.Unmarshal([]byte(tc.Text), &data); err != nil { + t.Fatalf("unmarshal result: %v (text: %s)", err, tc.Text) + } + return data +} + +func TestSearchGraph_MetaField(t *testing.T) { + srv := testSnippetServer(t) + // Attach a metricsTracker backed by a temp file. + savingsPath := filepath.Join(t.TempDir(), "savings.json") + srv.metricsTracker = metrics.NewTracker(savingsPath) + + // Search for "Handle" which matches HandleRequest in the fixture. + data := callSearchGraph(t, srv, "Handle") + + // Results should be present. + results, ok := data["results"].([]any) + if !ok || len(results) == 0 { + t.Fatal("expected at least one result from search_graph") + } + + // _meta must exist. + metaRaw, ok := data["_meta"] + if !ok { + t.Fatal("expected _meta field in response") + } + meta, ok := metaRaw.(map[string]any) + if !ok { + t.Fatalf("expected _meta to be a map, got %T", metaRaw) + } + + tokensSaved, _ := meta["tokens_saved"].(float64) + baselineTokens, _ := meta["baseline_tokens"].(float64) + responseTokens, _ := meta["response_tokens"].(float64) + + if tokensSaved < 0 { + t.Errorf("tokens_saved should be >= 0, got %v", tokensSaved) + } + // baseline_tokens may be 0 if the fixture file is not accessible on disk + // (test uses a temp dir, not the real source tree), so only assert >= 0. + if baselineTokens < 0 { + t.Errorf("baseline_tokens should be >= 0, got %v", baselineTokens) + } + if responseTokens <= 0 { + t.Errorf("response_tokens should be > 0, got %v", responseTokens) + } +} diff --git a/internal/tools/snippet.go b/internal/tools/snippet.go index 6da31d5..2bfe9fa 100644 --- a/internal/tools/snippet.go +++ b/internal/tools/snippet.go @@ -3,12 +3,14 @@ package tools import ( "bufio" "context" + "encoding/json" "fmt" "log/slog" "os" "path/filepath" "strings" + "github.com/DeusData/codebase-memory-mcp/internal/metrics" "github.com/DeusData/codebase-memory-mcp/internal/store" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -154,6 +156,15 @@ func (s *Server) buildSnippetResponse(match *snippetMatch, includeNeighbors bool responseData["alternatives"] = alternatives } + // Token savings: baseline = full source file size, response = marshaled JSON. + if s.config == nil || s.config.GetBool(store.ConfigMetricsEnabled, true) { + if fi, statErr := os.Stat(absPath); statErr == nil { + price := priceForConfig(s.config) + responseJSON, _ := json.Marshal(responseData) + meta := metrics.CalculateSavings(int(fi.Size()), len(responseJSON), price) + return resultWithMeta(responseData, meta, s.metricsTracker), nil + } + } return jsonResult(responseData), nil } diff --git a/internal/tools/snippet_test.go b/internal/tools/snippet_test.go index 9252e7c..dae1c9e 100644 --- a/internal/tools/snippet_test.go +++ b/internal/tools/snippet_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "testing" + "github.com/DeusData/codebase-memory-mcp/internal/metrics" "github.com/DeusData/codebase-memory-mcp/internal/store" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -488,3 +489,45 @@ func TestSnippet_IncludeNeighbors_Enabled(t *testing.T) { t.Error("expected Run in callee_names") } } + +func TestGetCodeSnippet_MetaField(t *testing.T) { + srv := testSnippetServer(t) + // Attach a metricsTracker backed by a temp file. + savingsPath := filepath.Join(t.TempDir(), "savings.json") + srv.metricsTracker = metrics.NewTracker(savingsPath) + + data := callSnippet(t, srv, "test-project.cmd.server.main.HandleRequest") + + // Source should still be present. + if data["source"] == nil || data["source"] == "" { + t.Fatal("expected non-empty source") + } + + // _meta must exist. + metaRaw, ok := data["_meta"] + if !ok { + t.Fatal("expected _meta field in response") + } + meta, ok := metaRaw.(map[string]any) + if !ok { + t.Fatalf("expected _meta to be a map, got %T", metaRaw) + } + + tokensSaved, _ := meta["tokens_saved"].(float64) + baselineTokens, _ := meta["baseline_tokens"].(float64) + responseTokens, _ := meta["response_tokens"].(float64) + reductionRatio, _ := meta["reduction_ratio"].(float64) + + if tokensSaved < 0 { + t.Errorf("tokens_saved should be >= 0, got %v", tokensSaved) + } + if baselineTokens <= 0 { + t.Errorf("baseline_tokens should be > 0, got %v", baselineTokens) + } + if responseTokens <= 0 { + t.Errorf("response_tokens should be > 0, got %v", responseTokens) + } + if reductionRatio <= 0 { + t.Errorf("reduction_ratio should be > 0, got %v", reductionRatio) + } +} diff --git a/internal/tools/tools.go b/internal/tools/tools.go index 5f1fcfd..d339150 100644 --- a/internal/tools/tools.go +++ b/internal/tools/tools.go @@ -18,6 +18,7 @@ import ( "time" "github.com/DeusData/codebase-memory-mcp/internal/discover" + "github.com/DeusData/codebase-memory-mcp/internal/metrics" "github.com/DeusData/codebase-memory-mcp/internal/pipeline" "github.com/DeusData/codebase-memory-mcp/internal/store" "github.com/DeusData/codebase-memory-mcp/internal/watcher" @@ -44,6 +45,8 @@ type Server struct { indexMu sync.Mutex handlers map[string]mcp.ToolHandler + metricsTracker *metrics.Tracker // nil when metrics_enabled=false + // Session-aware fields (set once via sync.Once, then immutable) sessionOnce sync.Once sessionRoot string // absolute path from client @@ -72,6 +75,9 @@ func NewServer(r *store.StoreRouter, opts ...ServerOption) *Server { opt(srv) } + // Initialize metrics tracker if enabled (default true). + srv.initMetricsTracker() + srv.mcp = mcp.NewServer( &mcp.Implementation{ Name: "codebase-memory-mcp", @@ -788,6 +794,72 @@ func (s *Server) registerProjectTools() { }, s.handleIndexStatus) } +// initMetricsTracker initializes the metrics tracker if metrics_enabled=true (default). +func (s *Server) initMetricsTracker() { + enabled := true + if s.config != nil { + enabled = s.config.GetBool(store.ConfigMetricsEnabled, true) + } + if !enabled { + return + } + + // Check for a config-driven override (set by per-project install) + savingsPath := "" + if s.config != nil { + savingsPath = s.config.Get(store.ConfigMetricsPath, "") + } + + // Fall back to default global path + if savingsPath == "" { + home, err := os.UserHomeDir() + if err != nil { + slog.Warn("metrics: cannot determine home dir, metrics disabled", "err", err) + return + } + savingsPath = filepath.Join(home, ".cache", "codebase-memory-mcp", "savings.json") + } + + s.metricsTracker = metrics.NewTracker(savingsPath) +} + +// resultWithMeta wraps data with a _meta field containing token savings estimation. +// If tracker is non-nil, also records the savings cumulatively. +// +// Currently instrumented: search_graph, get_code_snippet. +// NOT instrumented (intentionally): trace_call_path, query_graph, search_code, +// get_architecture, get_graph_schema, detect_changes, index_repository, +// list_projects, index_status, manage_adr, delete_project. +// Rationale: validating the _meta pattern on high-value tools first; tools like +// index_repository and list_projects have no file-read equivalent so savings = 0. +// Use this wrapper to add more tools incrementally. +func resultWithMeta(data map[string]any, meta metrics.TokenMetadata, tracker *metrics.Tracker) *mcp.CallToolResult { + data["_meta"] = meta + if tracker != nil { + tracker.Record(meta) + } + return jsonResult(data) +} + +// priceForConfig returns the USD price-per-token for the configured pricing model. +// Defaults to Claude Sonnet pricing ($15/M output tokens). +func priceForConfig(cfg *store.ConfigStore) float64 { + if cfg == nil { + return 0.000015 // claude-sonnet default + } + model := cfg.Get(store.ConfigPricingModel, "claude-sonnet") + switch model { + case "claude-opus": + return 0.000075 // $75/M output tokens + case "gpt-4o": + return 0.000010 // $10/M output tokens + case "custom": + return cfg.GetFloat64(store.ConfigCustomPricePerToken, 0.000015) + default: // "claude-sonnet" + return 0.000015 + } +} + // --- Helpers --- // jsonResult marshals data to JSON and returns as tool result. diff --git a/internal/tools/tools_test.go b/internal/tools/tools_test.go index b9e1a3a..8533867 100644 --- a/internal/tools/tools_test.go +++ b/internal/tools/tools_test.go @@ -1,8 +1,13 @@ package tools import ( + "os" + "path/filepath" "runtime" "testing" + + "github.com/DeusData/codebase-memory-mcp/internal/metrics" + "github.com/DeusData/codebase-memory-mcp/internal/store" ) func TestParseFileURI(t *testing.T) { @@ -51,6 +56,120 @@ func TestParseFileURI(t *testing.T) { } } +func TestPriceForConfig(t *testing.T) { + tests := []struct { + name string + cfg func(t *testing.T) *store.ConfigStore + want float64 + }{ + {"nil config", func(t *testing.T) *store.ConfigStore { return nil }, 0.000015}, + {"default (claude-sonnet)", func(t *testing.T) *store.ConfigStore { + c, err := store.OpenConfigInDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { c.Close() }) + return c + }, 0.000015}, + {"claude-opus", func(t *testing.T) *store.ConfigStore { + c, err := store.OpenConfigInDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { c.Close() }) + c.Set(store.ConfigPricingModel, "claude-opus") + return c + }, 0.000075}, + {"gpt-4o", func(t *testing.T) *store.ConfigStore { + c, err := store.OpenConfigInDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { c.Close() }) + c.Set(store.ConfigPricingModel, "gpt-4o") + return c + }, 0.000010}, + {"custom", func(t *testing.T) *store.ConfigStore { + c, err := store.OpenConfigInDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { c.Close() }) + c.Set(store.ConfigPricingModel, "custom") + c.Set(store.ConfigCustomPricePerToken, "0.00005") + return c + }, 0.00005}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := priceForConfig(tt.cfg(t)) + if got != tt.want { + t.Errorf("priceForConfig() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestInitMetricsTracker(t *testing.T) { + t.Run("default enabled", func(t *testing.T) { + router, err := store.NewRouterWithDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(router.CloseAll) + srv := NewServer(router) + if srv.metricsTracker == nil { + t.Error("expected metricsTracker to be non-nil by default") + } + }) + + t.Run("disabled via config", func(t *testing.T) { + router, err := store.NewRouterWithDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(router.CloseAll) + cfg, err := store.OpenConfigInDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { cfg.Close() }) + cfg.Set(store.ConfigMetricsEnabled, "false") + + srv := NewServer(router, WithConfig(cfg)) + if srv.metricsTracker != nil { + t.Error("expected metricsTracker to be nil when disabled") + } + }) + + t.Run("custom metrics_path via config", func(t *testing.T) { + router, err := store.NewRouterWithDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(router.CloseAll) + cfg, err := store.OpenConfigInDir(t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { cfg.Close() }) + + customPath := filepath.Join(t.TempDir(), "custom-savings.json") + cfg.Set(store.ConfigMetricsPath, customPath) + + srv := NewServer(router, WithConfig(cfg)) + if srv.metricsTracker == nil { + t.Fatal("expected metricsTracker to be non-nil with custom path") + } + + // Record something and verify it wrote to the custom path + srv.metricsTracker.Record(metrics.TokenMetadata{TokensSaved: 100, CostAvoided: 0.001}) + if _, err := os.Stat(customPath); err != nil { + t.Errorf("expected savings file at custom path %s: %v", customPath, err) + } + }) +} + // windowsPath converts forward slashes to backslashes for Windows comparison. func windowsPath(p string) string { result := make([]byte, len(p))