|
| 1 | +import asyncio |
1 | 2 | from redis import asyncio as aioredis |
2 | 3 | from packaging.version import parse as parse_version |
| 4 | +from weakref import WeakKeyDictionary |
3 | 5 |
|
4 | 6 | import bittensor.utils.btlogging as btul |
5 | | - |
6 | 7 | from subvortex.core.database.database_utils import decode_value |
7 | 8 |
|
8 | 9 |
|
9 | 10 | class Database: |
10 | 11 | def __init__(self, settings): |
11 | 12 | self.models = {} |
12 | 13 | self.settings = settings |
13 | | - self.database = None |
| 14 | + self._clients = WeakKeyDictionary() # Cache clients per event loop |
14 | 15 |
|
15 | | - async def connect(self): |
16 | | - self.database = aioredis.StrictRedis( |
| 16 | + def _new_client(self): |
| 17 | + return aioredis.StrictRedis( |
17 | 18 | host=self.settings.database_host, |
18 | 19 | port=self.settings.database_port, |
19 | 20 | db=self.settings.database_index, |
20 | 21 | password=self.settings.database_password, |
21 | 22 | ) |
22 | 23 |
|
23 | | - btul.logging.info("Connected to Redis", prefix=self.settings.logging_name) |
| 24 | + def _get_loop(self): |
| 25 | + return asyncio.get_running_loop() |
| 26 | + |
| 27 | + async def get_client(self): |
| 28 | + loop = self._get_loop() |
| 29 | + |
| 30 | + if loop in self._clients: |
| 31 | + return self._clients[loop] |
| 32 | + |
| 33 | + client = self._new_client() |
| 34 | + self._clients[loop] = client |
| 35 | + |
| 36 | + btul.logging.info( |
| 37 | + "Created new Redis client for event loop", prefix=self.settings.logging_name |
| 38 | + ) |
| 39 | + return client |
24 | 40 |
|
25 | 41 | async def is_connection_alive(self) -> bool: |
| 42 | + client = await self.get_client() |
| 43 | + |
26 | 44 | try: |
27 | | - pong = await self.database.ping() |
| 45 | + pong = await client.ping() |
28 | 46 | return pong is True |
29 | 47 | except Exception as e: |
30 | | - btul.logging.warning(f"Redis connection check failed: {e}") |
| 48 | + btul.logging.warning( |
| 49 | + f"Redis connection check failed: {e}", prefix=self.settings.logging_name |
| 50 | + ) |
31 | 51 | return False |
32 | 52 |
|
33 | 53 | async def ensure_connection(self): |
34 | | - if self.database is None or not await self.is_connection_alive(): |
| 54 | + client = await self.get_client() |
| 55 | + |
| 56 | + if not await self.is_connection_alive(): |
35 | 57 | btul.logging.warning( |
36 | | - "Reconnecting to Redis...", |
| 58 | + "Redis ping failed, but client will be reused", |
37 | 59 | prefix=self.settings.logging_name, |
38 | 60 | ) |
39 | | - await self.connect() |
| 61 | + # You may optionally recreate here if needed |
40 | 62 |
|
41 | 63 | async def wait_until_ready(self, name: str): |
42 | | - # Ensure the connection is ip and running |
43 | 64 | await self.ensure_connection() |
44 | 65 |
|
| 66 | + client = await self.get_client() |
| 67 | + |
45 | 68 | message_key = self._key(f"state:{name}") |
46 | 69 | stream_key = self._key(f"state:{name}:stream") |
47 | 70 | last_id = "$" |
48 | 71 |
|
49 | 72 | try: |
50 | | - # Step 1: check the message key first |
51 | | - snapshot = await self.database.get(message_key) |
| 73 | + snapshot = await client.get(message_key) |
52 | 74 | if snapshot and snapshot.decode() == "ready": |
53 | | - btul.logging.trace( |
| 75 | + btul.logging.debug( |
54 | 76 | f"{name} is already ready (via message key)", |
55 | 77 | prefix=self.settings.logging_name, |
56 | 78 | ) |
57 | 79 | return |
58 | 80 |
|
59 | | - # Step 2: wait for stream messages |
60 | | - btul.logging.trace( |
| 81 | + btul.logging.debug( |
61 | 82 | f"Waiting on stream: {stream_key}", prefix=self.settings.logging_name |
62 | 83 | ) |
63 | 84 | while True: |
64 | | - entries = await self.database.xread({stream_key: last_id}, block=0) |
| 85 | + entries = await client.xread({stream_key: last_id}, block=0) |
65 | 86 | if not entries: |
66 | 87 | continue |
67 | 88 |
|
68 | 89 | for stream_key, messages in entries: |
69 | | - btul.logging.trace( |
| 90 | + btul.logging.debug( |
70 | 91 | f"Received stream message: {messages}", |
71 | 92 | prefix=self.settings.logging_name, |
72 | 93 | ) |
73 | 94 | for msg_id, fields in messages: |
74 | | - state = fields.get("state".encode(), b"").decode() |
| 95 | + state = fields.get(b"state", b"").decode() |
75 | 96 | if state == "ready": |
76 | | - btul.logging.trace( |
| 97 | + btul.logging.debug( |
77 | 98 | f"{name} is now ready (via stream)", |
78 | 99 | prefix=self.settings.logging_name, |
79 | 100 | ) |
80 | 101 | return |
81 | | - last_id = msg_id # move forward |
| 102 | + last_id = msg_id |
82 | 103 | except Exception as err: |
83 | 104 | btul.logging.warning( |
84 | 105 | f"Failed to read the state of {name}: {err}", |
85 | 106 | prefix=self.settings.logging_name, |
86 | 107 | ) |
87 | 108 |
|
88 | 109 | async def _get_migration_status(self, model_name: str): |
89 | | - """ |
90 | | - Returns: |
91 | | - - latest_version: the 'new' version |
92 | | - - active_versions: versions marked 'dual' or 'new', |
93 | | - or fallback to latest if none are active. |
94 | | - """ |
95 | | - # Ensure the connection is ip and running |
96 | 110 | await self.ensure_connection() |
| 111 | + |
| 112 | + client = await self.get_client() |
97 | 113 |
|
98 | 114 | latest = None |
99 | 115 | active = [] |
100 | 116 |
|
101 | 117 | all_versions = sorted(self.models[model_name].keys(), key=parse_version) |
102 | 118 |
|
103 | 119 | for version in all_versions: |
104 | | - mode = await self.database.get(f"migration_mode:{version}") |
| 120 | + mode = await client.get(f"migration_mode:{version}") |
105 | 121 | mode = decode_value(mode) |
106 | 122 |
|
107 | 123 | if mode == "new": |
|
0 commit comments