-
Notifications
You must be signed in to change notification settings - Fork 705
Expand file tree
/
Copy pathsqlbot_cache.py
More file actions
160 lines (141 loc) · 5.76 KB
/
sqlbot_cache.py
File metadata and controls
160 lines (141 loc) · 5.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import re
from fastapi_cache import FastAPICache
from functools import partial, wraps
from typing import Optional, Any, Dict, Tuple
from inspect import signature
from common.core.config import settings
from common.utils.utils import SQLBotLogUtil
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache as original_cache
def custom_key_builder(
func: Any,
namespace: str = "",
*,
args: Tuple[Any, ...] = (),
kwargs: Dict[str, Any],
cacheName: str,
keyExpression: Optional[str] = None,
) -> str | list[str]:
try:
base_key = f"{namespace}:{cacheName}:"
if keyExpression:
sig = signature(func)
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()
# 支持args[0]格式
if keyExpression.startswith("args["):
if match := re.match(r"args\[(\d+)\]", keyExpression):
index = int(match.group(1))
value = bound_args.args[index]
if isinstance(value, list):
return [f"{base_key}{v}" for v in value]
return f"{base_key}{value}"
# 支持属性路径格式
parts = keyExpression.split('.')
if not bound_args.arguments.get(parts[0]):
return f"{base_key}{parts[0]}"
value = bound_args.arguments[parts[0]]
for part in parts[1:]:
value = getattr(value, part)
if isinstance(value, list):
return [f"{base_key}{v}" for v in value]
return f"{base_key}{value}"
# 默认使用第一个参数作为key
return f"{base_key}{args[0] if args else 'default'}"
except Exception as e:
SQLBotLogUtil.error(f"Key builder error: {str(e)}")
raise ValueError(f"Invalid cache key generation: {e}") from e
def cache(
    expire: int = 60 * 60 * 24,
    namespace: str = "",
    *,
    cacheName: str,  # required: logical cache name embedded in every key
    keyExpression: Optional[str] = None,
):
    """Decorator factory caching an async function's result through fastapi-cache.

    The key is computed by ``custom_key_builder`` from the call's arguments
    (``namespace``, ``cacheName`` and ``keyExpression`` select the key). When
    ``settings.CACHE_TYPE`` is empty or ``"none"`` (case-insensitive), or the
    cache backend has not been initialised, the function is called directly.

    Args:
        expire: Entry lifetime in seconds (default: one day).
        namespace: Namespace used in the key and passed to fastapi-cache.
        cacheName: Logical cache name; mandatory.
        keyExpression: Expression choosing the key value from the arguments.
    """
    def decorator(func):
        # Key builder with the per-decorator options bound once, up front.
        build_key = partial(
            custom_key_builder,
            cacheName=cacheName,
            keyExpression=keyExpression,
        )

        @wraps(func)
        async def wrapper(*args, **kwargs):
            cache_type = settings.CACHE_TYPE
            caching_enabled = (
                bool(cache_type)
                and cache_type.lower() != "none"
                and is_cache_initialized()
            )
            if not caching_enabled:
                return await func(*args, **kwargs)

            ns = str(namespace) if namespace else ""
            key = build_key(func=func, namespace=ns, args=args, kwargs=kwargs)
            # fastapi-cache receives a key builder that ignores its inputs and
            # returns the key precomputed above.
            cached_func = original_cache(
                expire=expire,
                namespace=ns,
                key_builder=lambda *_, **__: key,
            )(func)
            return await cached_func(*args, **kwargs)

        return wrapper

    return decorator
async def _evict_cache_keys(keys: list[str]) -> None:
    """Remove ``keys`` from the configured FastAPICache backend, logging each removal."""
    if not keys:
        return
    backend = FastAPICache.get_backend()
    use_redis = settings.CACHE_TYPE.lower() == "redis"
    for key in keys:
        if use_redis:
            # Redis DEL returns the number of removed keys and is a no-op for
            # absent keys, so no GET round-trip is needed beforehand.
            removed = await backend.redis.delete(key)
        else:
            # Keep the original existence check for non-Redis backends;
            # NOTE(review): presumably clear(key=...) does not tolerate absent
            # keys there — verify against the installed fastapi-cache version.
            removed = bool(await backend.get(key))
            if removed:
                await backend.clear(key=key)
        if removed:
            SQLBotLogUtil.debug(f"Cache cleared: {key}")


def clear_cache(
    namespace: str = "",
    *,
    cacheName: str,
    keyExpression: Optional[str] = None,
):
    """Decorator factory that invalidates entries written by ``cache``.

    Keys are built with ``custom_key_builder`` from the call's arguments, so
    use the same ``namespace``/``cacheName``/``keyExpression`` as the matching
    ``cache`` decorator. A list-valued key expression evicts one key per
    element. When caching is disabled (``settings.CACHE_TYPE`` empty or
    ``"none"``) or not initialised, the function is simply called.

    Eviction happens after the wrapped coroutine finishes (whether it returned
    or raised). Evicting first would leave a window in which a concurrent or
    nested cached read repopulates the entry with pre-write data that then
    stays stale until it expires.

    Raises:
        ValueError: From ``custom_key_builder`` if the key cannot be built
            (raised before the wrapped function runs).
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            if not settings.CACHE_TYPE or settings.CACHE_TYPE.lower() == "none" or not is_cache_initialized():
                return await func(*args, **kwargs)
            # Resolve keys from the arguments as passed, before func can mutate them.
            cache_key = custom_key_builder(
                func=func,
                namespace=str(namespace) if namespace else "",
                args=args,
                kwargs=kwargs,
                cacheName=cacheName,
                keyExpression=keyExpression,
            )
            key_list = cache_key if isinstance(cache_key, list) else [cache_key]
            try:
                return await func(*args, **kwargs)
            finally:
                await _evict_cache_keys(key_list)

        return wrapper

    return decorator
def init_sqlbot_cache():
    """Initialise FastAPICache according to ``settings.CACHE_TYPE``.

    * ``"memory"``: ``InMemoryBackend`` (single-process only).
    * ``"redis"``: ``RedisBackend`` on ``settings.CACHE_REDIS_URL``, falling
      back to ``redis://localhost:6379/0``, registered with the
      ``"sqlbot-cache"`` prefix.
    * anything else (including unset): caching stays disabled, and a warning
      is logged.

    The setting is compared case-insensitively, the same way the ``cache`` /
    ``clear_cache`` decorators interpret it. Previously "Redis" or "Memory"
    left the backend uninitialised while the decorators treated caching as
    configured.
    """
    cache_type: str = (settings.CACHE_TYPE or "").lower()
    if cache_type == "memory":
        FastAPICache.init(InMemoryBackend())
        SQLBotLogUtil.info("SQLBot 使用内存缓存, 仅支持单进程模式")
    elif cache_type == "redis":
        # Imported lazily so the redis client is only needed when configured.
        from fastapi_cache.backends.redis import RedisBackend
        import redis.asyncio as redis
        from redis.asyncio.connection import ConnectionPool

        redis_url = settings.CACHE_REDIS_URL or "redis://localhost:6379/0"
        pool = ConnectionPool.from_url(url=redis_url)
        redis_client = redis.Redis(connection_pool=pool)
        FastAPICache.init(RedisBackend(redis_client), prefix="sqlbot-cache")
        SQLBotLogUtil.info("SQLBot 使用Redis缓存, 可使用多进程模式")
    else:
        SQLBotLogUtil.warning("SQLBot 未启用缓存, 可使用多进程模式")
def is_cache_initialized() -> bool:
    """Return True when FastAPICache has a usable backend configured.

    Reads FastAPICache's private ``_backend`` / ``_prefix`` class attributes
    (presumably set by ``FastAPICache.init``) and then confirms via
    ``FastAPICache.get_backend()``. The function never raises: any failure is
    logged at debug level and reported as False.
    """
    # A missing attribute and a None value both mean "not initialised".
    if getattr(FastAPICache, "_backend", None) is None or getattr(FastAPICache, "_prefix", None) is None:
        return False
    try:
        return FastAPICache.get_backend() is not None
    except Exception as e:  # covers AssertionError/AttributeError listed previously
        SQLBotLogUtil.debug(f"缓存初始化检查失败: {str(e)}")
        return False