-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpersistent_cache.py
More file actions
175 lines (157 loc) · 6.79 KB
/
Copy pathpersistent_cache.py
File metadata and controls
175 lines (157 loc) · 6.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""SQLite-backed cache for completed get_docs results across MCP restarts."""
from __future__ import annotations
import logging
import sqlite3
import threading
from pathlib import Path
from typing import NamedTuple
from pydantic import ValidationError
from mcp_server_python_docs.cache.codec import decode as decode_cache_payload
from mcp_server_python_docs.cache.codec import encode as encode_cache_payload
from mcp_server_python_docs.models import GetDocsResult
# Module-level logger; the cache reports its failures through warnings on it.
logger = logging.getLogger(__name__)
# Stored in the NOT NULL ``anchor`` column when a request carries no anchor
# (``anchor is None``).
_NO_ANCHOR_KEY = "\x00mcp-python-docs:no-anchor\x00"
# Codec name used for newly written cache entries unless the caller overrides
# ``default_codec`` when constructing the cache.
DEFAULT_RETRIEVED_DOCS_CACHE_CODEC = "zstd"
class CacheStats(NamedTuple):
    """Immutable snapshot of the cache's hit, miss, and write counters.

    Every field defaults to zero, so ``CacheStats()`` is an all-zero snapshot.
    """

    # Number of lookups served from a stored entry.
    hits: int = 0
    # Number of lookups that produced no usable entry.
    misses: int = 0
    # Number of entries stored successfully.
    writes: int = 0
class PersistentDocsCache:
    """Persist get_docs results by index fingerprint, version, and request identity.

    Rows live in a single SQLite table keyed by
    ``(index_fingerprint, version, slug, anchor, max_chars, start_index)``.
    If the database cannot be opened the cache disables itself: ``get`` then
    always misses and ``put`` is a no-op.

    One connection is shared between threads; ``self._lock`` serializes every
    use of that connection and every update of the statistics counters.
    """

    def __init__(
        self,
        cache_path: Path,
        index_path: Path,
        *,
        default_codec: str = DEFAULT_RETRIEVED_DOCS_CACHE_CODEC,
    ) -> None:
        """Open (or create) the cache database and purge rows from other indexes.

        Args:
            cache_path: SQLite cache file; missing parent directories are created.
            index_path: Index file whose resolved path, size, and mtime form the
                fingerprint that scopes every cached row.
            default_codec: Codec name handed to the cache codec when storing rows.

        ``OSError`` and ``sqlite3.Error`` raised during setup are logged as a
        warning and leave the cache disabled instead of propagating.
        """
        self._cache_path = Path(cache_path)
        self._default_codec = default_codec
        # Stays "" when the index file cannot be stat'ed; it is only consulted
        # while a connection is open.
        self._fingerprint = ""
        self._hits = self._misses = self._writes = 0
        # ``check_same_thread=False`` lets multiple threads share the connection,
        # but per the Python sqlite3 docs writes must still be serialized by the
        # application — this lock guards every execute()/commit() and the stats
        # counters they update.
        self._lock = threading.Lock()
        self._conn: sqlite3.Connection | None = None
        try:
            self._fingerprint = self._fingerprint_index(Path(index_path))
            self._cache_path.parent.mkdir(parents=True, exist_ok=True)
            self._conn = sqlite3.connect(str(self._cache_path), check_same_thread=False)
            self._conn.execute("PRAGMA journal_mode = WAL")
            self._conn.execute("PRAGMA synchronous = NORMAL")
            self._conn.execute(
                "CREATE TABLE IF NOT EXISTS retrieved_docs_cache ("
                "index_fingerprint TEXT NOT NULL, version TEXT NOT NULL, slug TEXT NOT NULL, "
                "anchor TEXT NOT NULL, max_chars INTEGER NOT NULL, start_index INTEGER NOT NULL, "
                "result_json TEXT NOT NULL, compression TEXT NOT NULL DEFAULT 'none', "
                "created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, "
                "PRIMARY KEY (index_fingerprint, version, slug, anchor, max_chars, start_index))"
            )
            # Tables created before the ``compression`` column existed need it added.
            self._ensure_compression_column()
            # Rows written against a different index fingerprint can never match
            # a lookup again, so drop them up front.
            self._conn.execute(
                "DELETE FROM retrieved_docs_cache WHERE index_fingerprint != ?",
                (self._fingerprint,),
            )
            self._conn.commit()
        except (OSError, sqlite3.Error) as e:
            conn, self._conn = self._conn, None
            if conn is not None:
                try:
                    conn.close()
                except sqlite3.Error:
                    # Already on the failure path; the cache is disabled either way.
                    pass
            logger.warning("Persistent docs cache disabled: %s", e)

    @property
    def cache_path(self) -> Path:
        """Path of the SQLite cache file this instance was configured with."""
        return self._cache_path

    @staticmethod
    def _fingerprint_index(index_path: Path) -> str:
        """Return ``"<resolved path>:<size>:<mtime_ns>"`` for *index_path*.

        Raises:
            OSError: If the file cannot be stat'ed.
        """
        stat = index_path.stat()
        return f"{index_path.resolve()}:{stat.st_size}:{stat.st_mtime_ns}"

    @staticmethod
    def _anchor_key(anchor: str | None) -> str:
        """Map ``None`` to the no-anchor sentinel; return any other anchor unchanged."""
        return _NO_ANCHOR_KEY if anchor is None else anchor

    @staticmethod
    def _rollback_quietly(conn: sqlite3.Connection) -> None:
        """Abandon any open transaction on *conn*, logging instead of raising."""
        try:
            conn.rollback()
        except sqlite3.Error as e:
            logger.debug("Persistent docs cache rollback failed: %s", e)

    def _ensure_compression_column(self) -> None:
        """Add the ``compression`` column if the existing table lacks it."""
        if self._conn is None:
            return
        # Index 1 of each ``PRAGMA table_info`` row is the column name.
        columns = {
            row[1] for row in self._conn.execute("PRAGMA table_info(retrieved_docs_cache)")
        }
        if "compression" not in columns:
            self._conn.execute(
                "ALTER TABLE retrieved_docs_cache "
                "ADD COLUMN compression TEXT NOT NULL DEFAULT 'none'"
            )

    def stats(self) -> CacheStats:
        """Return a snapshot of the hit/miss/write counters."""
        # Read under the lock so the three values belong to the same moment.
        with self._lock:
            return CacheStats(self._hits, self._misses, self._writes)

    def get(
        self, *, version: str, slug: str, anchor: str | None, max_chars: int, start_index: int
    ) -> GetDocsResult | None:
        """Look up a cached result for the current index fingerprint.

        Returns:
            The stored ``GetDocsResult``, or ``None`` when the cache is disabled,
            no row matches, the query fails, or the stored payload cannot be
            decoded and validated. Every ``None`` outcome counts as a miss.
        """
        with self._lock:
            # Snapshot the connection under the lock: close() nulls it under the
            # same lock, so it cannot disappear between this check and its use.
            conn = self._conn
            if conn is None:
                self._misses += 1
                return None
            try:
                row = conn.execute(
                    "SELECT result_json, compression FROM retrieved_docs_cache "
                    "WHERE index_fingerprint = ? "
                    "AND version = ? AND slug = ? AND anchor = ? AND max_chars = ? "
                    "AND start_index = ?",
                    (
                        self._fingerprint,
                        version,
                        slug,
                        self._anchor_key(anchor),
                        max_chars,
                        start_index,
                    ),
                ).fetchone()
            except sqlite3.Error as e:
                self._misses += 1
                logger.warning("Persistent docs cache read skipped: %s", e)
                return None
            if row is None:
                self._misses += 1
                return None
            try:
                # The payload column may come back as text or as a bytes-like
                # value; the codec is always given bytes.
                payload = row[0].encode("utf-8") if isinstance(row[0], str) else bytes(row[0])
                result_json = decode_cache_payload(payload, row[1])
                result = GetDocsResult.model_validate_json(result_json)
            except (ValidationError, ValueError, TypeError) as e:
                self._misses += 1
                logger.warning("Persistent docs cache entry ignored: %s", e)
                return None
            self._hits += 1
            return result

    def put(self, *, result: GetDocsResult, max_chars: int, start_index: int) -> None:
        """Store *result* under its version/slug/anchor and the given window.

        An existing row with the same key is replaced. Encoding and SQLite
        failures are logged as warnings and the write is skipped; nothing is
        raised for them. Does nothing when the cache is disabled.
        """
        with self._lock:
            # Same reasoning as in get(): take the connection under the lock.
            conn = self._conn
            if conn is None:
                return
            try:
                conn.execute(
                    "INSERT OR REPLACE INTO retrieved_docs_cache "
                    "(index_fingerprint, version, slug, anchor, max_chars, start_index, "
                    "result_json, compression) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                    (
                        self._fingerprint,
                        result.version,
                        result.slug,
                        self._anchor_key(result.anchor),
                        max_chars,
                        start_index,
                        encode_cache_payload(result.model_dump_json(), self._default_codec),
                        self._default_codec,
                    ),
                )
                conn.commit()
            except (sqlite3.Error, ValueError) as e:
                logger.warning("Persistent docs cache write skipped: %s", e)
                # A failed execute()/commit() can leave the implicit transaction
                # open; discard it so a later commit does not flush a write that
                # was reported as skipped.
                self._rollback_quietly(conn)
                return
            self._writes += 1

    def close(self) -> None:
        """Close the connection; later ``get`` calls miss and ``put`` calls no-op."""
        with self._lock:
            # Detach first so the cache is disabled even if close() raises.
            conn, self._conn = self._conn, None
            if conn is not None:
                conn.close()