diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index fcb4743cb3..f0c3cb8f3a 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -3,7 +3,7 @@ import asyncio import json import logging -import threading +import sqlite3 from contextlib import closing from pathlib import Path from typing import Any, Union, cast @@ -56,71 +56,70 @@ def _init_structure_tables(self): Creates the message_structure and turn_usage tables with appropriate indexes for conversation branching and usage analytics. """ - conn = self._get_connection() - - # Message structure with branch support - conn.execute(f""" - CREATE TABLE IF NOT EXISTS message_structure ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id TEXT NOT NULL, - message_id INTEGER NOT NULL, - branch_id TEXT NOT NULL DEFAULT 'main', - message_type TEXT NOT NULL, - sequence_number INTEGER NOT NULL, - user_turn_number INTEGER, - branch_turn_number INTEGER, - tool_name TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) - REFERENCES {self.sessions_table}(session_id) ON DELETE CASCADE, - FOREIGN KEY (message_id) - REFERENCES {self.messages_table}(id) ON DELETE CASCADE - ) - """) - - # Turn-level usage tracking with branch support and full JSON details - conn.execute(f""" - CREATE TABLE IF NOT EXISTS turn_usage ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - session_id TEXT NOT NULL, - branch_id TEXT NOT NULL DEFAULT 'main', - user_turn_number INTEGER NOT NULL, - requests INTEGER DEFAULT 0, - input_tokens INTEGER DEFAULT 0, - output_tokens INTEGER DEFAULT 0, - total_tokens INTEGER DEFAULT 0, - input_tokens_details JSON, - output_tokens_details JSON, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) - REFERENCES {self.sessions_table}(session_id) ON DELETE CASCADE, - UNIQUE(session_id, branch_id, user_turn_number) - ) - """) - - # Indexes - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_structure_session_seq - ON message_structure(session_id, sequence_number) - """) - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_structure_branch - ON message_structure(session_id, branch_id) - """) - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_structure_turn - ON message_structure(session_id, branch_id, user_turn_number) - """) - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_structure_branch_seq - ON message_structure(session_id, branch_id, sequence_number) - """) - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_turn_usage_session_turn - ON turn_usage(session_id, branch_id, user_turn_number) - """) - - conn.commit() + with self._locked_connection() as conn: + # Message structure with branch support + conn.execute(f""" + CREATE TABLE IF NOT EXISTS message_structure ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + message_id INTEGER NOT NULL, + branch_id TEXT NOT NULL DEFAULT 'main', + message_type TEXT NOT NULL, + sequence_number INTEGER NOT NULL, + user_turn_number INTEGER, + branch_turn_number INTEGER, + tool_name TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) + REFERENCES {self.sessions_table}(session_id) ON DELETE CASCADE, + FOREIGN KEY (message_id) + REFERENCES {self.messages_table}(id) ON DELETE CASCADE + ) + """) + + # Turn-level usage tracking with branch support and full JSON details + conn.execute(f""" + CREATE TABLE IF NOT EXISTS turn_usage ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + branch_id TEXT NOT NULL DEFAULT 'main', + user_turn_number INTEGER NOT NULL, + requests INTEGER DEFAULT 0, + input_tokens INTEGER DEFAULT 0, + output_tokens INTEGER DEFAULT 0, + total_tokens INTEGER DEFAULT 0, + input_tokens_details JSON, + output_tokens_details JSON, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) + REFERENCES {self.sessions_table}(session_id) ON DELETE CASCADE, + UNIQUE(session_id, branch_id, user_turn_number) + ) + """) + + # Indexes + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_structure_session_seq + ON message_structure(session_id, sequence_number) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_structure_branch + ON message_structure(session_id, branch_id) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_structure_turn + ON message_structure(session_id, branch_id, user_turn_number) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_structure_branch_seq + ON message_structure(session_id, branch_id, sequence_number) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_turn_usage_session_turn + ON turn_usage(session_id, branch_id, user_turn_number) + """) + + conn.commit() async def add_items(self, items: list[TResponseInputItem]) -> None: """Add items to the session. @@ -128,12 +127,34 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: Args: items: The items to add to the session """ - # Add to base table first - await super().add_items(items) + if not items: + return + + def _add_items_sync(): + """Synchronous helper to add items and structure metadata together.""" + with self._locked_connection() as conn: + # Keep both writes in one critical section so message IDs and metadata stay aligned. + self._insert_items(conn, items) + conn.commit() + try: + self._insert_structure_metadata(conn, items) + conn.commit() + except Exception as e: + conn.rollback() + self._logger.error( + f"Failed to add structure metadata for session {self.session_id}: {e}" + ) + try: + deleted_count = self._cleanup_orphaned_messages_sync(conn) + if deleted_count: + conn.commit() + else: + conn.rollback() + except Exception as cleanup_error: + conn.rollback() + self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}") - # Extract structure metadata with precise sequencing - if items: - await self._add_structure_metadata(items) + await asyncio.to_thread(_add_items_sync) async def get_items( self, @@ -157,9 +178,7 @@ async def get_items( # Get all items for this branch def _get_all_items_sync(): """Synchronous helper to get all items for a branch.""" - conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: with closing(conn.cursor()) as cursor: if session_limit is None: cursor.execute( @@ -202,9 +221,7 @@ def _get_all_items_sync(): def _get_items_sync(): """Synchronous helper to get items for a specific branch.""" - conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: with closing(conn.cursor()) as cursor: # Get message IDs in correct order for this branch if session_limit is None: @@ -273,19 +290,19 @@ def _get_next_turn_number(self, branch_id: str) -> int: Returns: The next available turn number for the specified branch. """ - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT COALESCE(MAX(user_turn_number), 0) - FROM message_structure - WHERE session_id = ? AND branch_id = ? - """, - (self.session_id, branch_id), - ) - result = cursor.fetchone() - max_turn = result[0] if result else 0 - return max_turn + 1 + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT COALESCE(MAX(user_turn_number), 0) + FROM message_structure + WHERE session_id = ? AND branch_id = ? + """, + (self.session_id, branch_id), + ) + result = cursor.fetchone() + max_turn = result[0] if result else 0 + return max_turn + 1 def _get_next_branch_turn_number(self, branch_id: str) -> int: """Get the next branch turn number for a specific branch. @@ -296,19 +313,19 @@ def _get_next_branch_turn_number(self, branch_id: str) -> int: Returns: The next available branch turn number for the specified branch. """ - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT COALESCE(MAX(branch_turn_number), 0) - FROM message_structure - WHERE session_id = ? AND branch_id = ? - """, - (self.session_id, branch_id), - ) - result = cursor.fetchone() - max_turn = result[0] if result else 0 - return max_turn + 1 + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT COALESCE(MAX(branch_turn_number), 0) + FROM message_structure + WHERE session_id = ? AND branch_id = ? + """, + (self.session_id, branch_id), + ) + result = cursor.fetchone() + max_turn = result[0] if result else 0 + return max_turn + 1 def _get_current_turn_number(self) -> int: """Get the current turn number for the current branch. @@ -316,18 +333,18 @@ def _get_current_turn_number(self) -> int: Returns: The current turn number for the active branch. """ - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT COALESCE(MAX(user_turn_number), 0) - FROM message_structure - WHERE session_id = ? AND branch_id = ? - """, - (self.session_id, self._current_branch_id), - ) - result = cursor.fetchone() - return result[0] if result else 0 + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT COALESCE(MAX(user_turn_number), 0) + FROM message_structure + WHERE session_id = ? AND branch_id = ? + """, + (self.session_id, self._current_branch_id), + ) + result = cursor.fetchone() + return result[0] if result else 0 async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None: """Extract structure metadata with branch-aware turn tracking. @@ -344,89 +361,9 @@ async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None def _add_structure_sync(): """Synchronous helper to add structure metadata to database.""" - conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): - # Get the IDs of messages we just inserted, in order - with closing(conn.cursor()) as cursor: - cursor.execute( - f"SELECT id FROM {self.messages_table} " - f"WHERE session_id = ? ORDER BY id DESC LIMIT ?", - (self.session_id, len(items)), - ) - message_ids = [row[0] for row in cursor.fetchall()] - message_ids.reverse() # Match order of items - - # Get current max sequence number (global) - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT COALESCE(MAX(sequence_number), 0) - FROM message_structure - WHERE session_id = ? - """, - (self.session_id,), - ) - seq_start = cursor.fetchone()[0] - - # Get current turn numbers atomically with a single query - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT - COALESCE(MAX(user_turn_number), 0) as max_global_turn, - COALESCE(MAX(branch_turn_number), 0) as max_branch_turn - FROM message_structure - WHERE session_id = ? AND branch_id = ? - """, - (self.session_id, self._current_branch_id), - ) - result = cursor.fetchone() - current_turn = result[0] if result else 0 - current_branch_turn = result[1] if result else 0 - - # Process items and assign turn numbers correctly - structure_data = [] - user_message_count = 0 - - for i, (item, msg_id) in enumerate(zip(items, message_ids)): - msg_type = self._classify_message_type(item) - tool_name = self._extract_tool_name(item) - - # If this is a user message, increment turn counters - if self._is_user_message(item): - user_message_count += 1 - item_turn = current_turn + user_message_count - item_branch_turn = current_branch_turn + user_message_count - else: - # Non-user messages inherit the turn number of the most recent user message - item_turn = current_turn + user_message_count - item_branch_turn = current_branch_turn + user_message_count - - structure_data.append( - ( - self.session_id, - msg_id, - self._current_branch_id, - msg_type, - seq_start + i + 1, # Global sequence - item_turn, # Global turn number - item_branch_turn, # Branch-specific turn number - tool_name, - ) - ) - - with closing(conn.cursor()) as cursor: - cursor.executemany( - """ - INSERT INTO message_structure - (session_id, message_id, branch_id, message_type, sequence_number, - user_turn_number, branch_turn_number, tool_name) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - structure_data, - ) - conn.commit() + with self._locked_connection() as conn: + self._insert_structure_metadata(conn, items) + conn.commit() try: await asyncio.to_thread(_add_structure_sync) @@ -441,6 +378,94 @@ def _add_structure_sync(): self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}") # Don't re-raise - structure metadata is supplementary + def _insert_structure_metadata( + self, + conn: sqlite3.Connection, + items: list[TResponseInputItem], + ) -> None: + # Get the IDs of messages we just inserted, in order. + with closing(conn.cursor()) as cursor: + cursor.execute( + f"SELECT id FROM {self.messages_table} " + f"WHERE session_id = ? ORDER BY id DESC LIMIT ?", + (self.session_id, len(items)), + ) + message_ids = [row[0] for row in cursor.fetchall()] + message_ids.reverse() + + if len(message_ids) != len(items): + raise RuntimeError( + "Failed to resolve inserted message IDs while writing structure metadata" + ) + + # Get current max sequence number (global). + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT COALESCE(MAX(sequence_number), 0) + FROM message_structure + WHERE session_id = ? + """, + (self.session_id,), + ) + seq_start = cursor.fetchone()[0] + + # Get current turn numbers atomically with a single query. + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT + COALESCE(MAX(user_turn_number), 0) as max_global_turn, + COALESCE(MAX(branch_turn_number), 0) as max_branch_turn + FROM message_structure + WHERE session_id = ? AND branch_id = ? + """, + (self.session_id, self._current_branch_id), + ) + result = cursor.fetchone() + current_turn = result[0] if result else 0 + current_branch_turn = result[1] if result else 0 + + # Process items and assign turn numbers correctly. + structure_data = [] + user_message_count = 0 + + for i, (item, msg_id) in enumerate(zip(items, message_ids)): + msg_type = self._classify_message_type(item) + tool_name = self._extract_tool_name(item) + + if self._is_user_message(item): + user_message_count += 1 + item_turn = current_turn + user_message_count + item_branch_turn = current_branch_turn + user_message_count + else: + item_turn = current_turn + user_message_count + item_branch_turn = current_branch_turn + user_message_count + + structure_data.append( + ( + self.session_id, + msg_id, + self._current_branch_id, + msg_type, + seq_start + i + 1, + item_turn, + item_branch_turn, + tool_name, + ) + ) + + with closing(conn.cursor()) as cursor: + cursor.executemany( + """ + INSERT INTO message_structure + (session_id, message_id, branch_id, message_type, sequence_number, + user_turn_number, branch_turn_number, tool_name) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + structure_data, + ) + async def _cleanup_orphaned_messages(self) -> int: """Remove messages that exist in the configured message table but not in message_structure. @@ -450,40 +475,43 @@ async def _cleanup_orphaned_messages(self) -> int: def _cleanup_sync(): """Synchronous helper to cleanup orphaned messages.""" - conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): - with closing(conn.cursor()) as cursor: - # Find messages without structure metadata - cursor.execute( - f""" - SELECT am.id - FROM {self.messages_table} am - LEFT JOIN message_structure ms ON am.id = ms.message_id - WHERE am.session_id = ? AND ms.message_id IS NULL - """, - (self.session_id,), - ) + with self._locked_connection() as conn: + deleted_count = self._cleanup_orphaned_messages_sync(conn) + if deleted_count: + conn.commit() + else: + conn.rollback() + return deleted_count - orphaned_ids = [row[0] for row in cursor.fetchall()] + return await asyncio.to_thread(_cleanup_sync) - if orphaned_ids: - # Delete orphaned messages - placeholders = ",".join("?" * len(orphaned_ids)) - cursor.execute( - f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})", - orphaned_ids, - ) + def _cleanup_orphaned_messages_sync(self, conn: sqlite3.Connection) -> int: + with closing(conn.cursor()) as cursor: + # Find messages without structure metadata. + cursor.execute( + f""" + SELECT am.id + FROM {self.messages_table} am + LEFT JOIN message_structure ms ON am.id = ms.message_id + WHERE am.session_id = ? AND ms.message_id IS NULL + """, + (self.session_id,), + ) - deleted_count = cursor.rowcount - conn.commit() + orphaned_ids = [row[0] for row in cursor.fetchall()] - self._logger.info(f"Cleaned up {deleted_count} orphaned messages") - return deleted_count + if not orphaned_ids: + return 0 - return 0 + placeholders = ",".join("?" * len(orphaned_ids)) + cursor.execute( + f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})", + orphaned_ids, + ) - return await asyncio.to_thread(_cleanup_sync) + deleted_count = cursor.rowcount + self._logger.info(f"Cleaned up {deleted_count} orphaned messages") + return deleted_count def _classify_message_type(self, item: TResponseInputItem) -> str: """Classify the type of a message item. @@ -588,32 +616,32 @@ async def create_branch_from_turn( # Validate the turn exists and contains a user message def _validate_turn(): """Synchronous helper to validate turn exists and contains user message.""" - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - f""" - SELECT am.message_data - FROM message_structure ms - JOIN {self.messages_table} am ON ms.message_id = am.id - WHERE ms.session_id = ? AND ms.branch_id = ? - AND ms.branch_turn_number = ? AND ms.message_type = 'user' - """, - (self.session_id, self._current_branch_id, turn_number), - ) - - result = cursor.fetchone() - if not result: - raise ValueError( - f"Turn {turn_number} does not contain a user message " - f"in branch '{self._current_branch_id}'" + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + f""" + SELECT am.message_data + FROM message_structure ms + JOIN {self.messages_table} am ON ms.message_id = am.id + WHERE ms.session_id = ? AND ms.branch_id = ? + AND ms.branch_turn_number = ? AND ms.message_type = 'user' + """, + (self.session_id, self._current_branch_id, turn_number), ) - message_data = result[0] - try: - content = json.loads(message_data).get("content", "") - return content[:50] + "..." if len(content) > 50 else content - except Exception: - return "Unable to parse content" + result = cursor.fetchone() + if not result: + raise ValueError( + f"Turn {turn_number} does not contain a user message " + f"in branch '{self._current_branch_id}'" + ) + + message_data = result[0] + try: + content = json.loads(message_data).get("content", "") + return content[:50] + "..." if len(content) > 50 else content + except Exception: + return "Unable to parse content" turn_content = await asyncio.to_thread(_validate_turn) @@ -670,19 +698,19 @@ async def switch_to_branch(self, branch_id: str) -> None: # Validate branch exists def _validate_branch(): """Synchronous helper to validate branch exists.""" - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT COUNT(*) FROM message_structure - WHERE session_id = ? AND branch_id = ? - """, - (self.session_id, branch_id), - ) + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT COUNT(*) FROM message_structure + WHERE session_id = ? AND branch_id = ? + """, + (self.session_id, branch_id), + ) - count = cursor.fetchone()[0] - if count == 0: - raise ValueError(f"Branch '{branch_id}' does not exist") + count = cursor.fetchone()[0] + if count == 0: + raise ValueError(f"Branch '{branch_id}' does not exist") await asyncio.to_thread(_validate_branch) @@ -721,9 +749,7 @@ async def delete_branch(self, branch_id: str, force: bool = False) -> None: def _delete_sync(): """Synchronous helper to delete branch and associated data.""" - conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: with closing(conn.cursor()) as cursor: # First verify the branch exists cursor.execute( @@ -784,37 +810,37 @@ async def list_branches(self) -> list[dict[str, Any]]: def _list_branches_sync(): """Synchronous helper to list all branches.""" - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT - ms.branch_id, - COUNT(*) as message_count, - COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns, - MIN(ms.created_at) as created_at - FROM message_structure ms - WHERE ms.session_id = ? - GROUP BY ms.branch_id - ORDER BY created_at - """, - (self.session_id,), - ) - - branches = [] - for row in cursor.fetchall(): - branch_id, msg_count, user_turns, created_at = row - branches.append( - { - "branch_id": branch_id, - "message_count": msg_count, - "user_turns": user_turns, - "is_current": branch_id == self._current_branch_id, - "created_at": created_at, - } + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT + ms.branch_id, + COUNT(*) as message_count, + COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns, + MIN(ms.created_at) as created_at + FROM message_structure ms + WHERE ms.session_id = ? + GROUP BY ms.branch_id + ORDER BY created_at + """, + (self.session_id,), ) - return branches + branches = [] + for row in cursor.fetchall(): + branch_id, msg_count, user_turns, created_at = row + branches.append( + { + "branch_id": branch_id, + "message_count": msg_count, + "user_turns": user_turns, + "is_current": branch_id == self._current_branch_id, + "created_at": created_at, + } + ) + + return branches return await asyncio.to_thread(_list_branches_sync) @@ -828,9 +854,7 @@ async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_numbe def _copy_sync(): """Synchronous helper to copy messages to new branch.""" - conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: with closing(conn.cursor()) as cursor: # Get all messages before the branch point cursor.execute( @@ -921,41 +945,43 @@ async def get_conversation_turns(self, branch_id: str | None = None) -> list[dic def _get_turns_sync(): """Synchronous helper to get conversation turns.""" - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - f""" - SELECT - ms.branch_turn_number, - am.message_data, - ms.created_at - FROM message_structure ms - JOIN {self.messages_table} am ON ms.message_id = am.id - WHERE ms.session_id = ? AND ms.branch_id = ? - AND ms.message_type = 'user' - ORDER BY ms.branch_turn_number - """, - (self.session_id, branch_id), - ) + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + f""" + SELECT + ms.branch_turn_number, + am.message_data, + ms.created_at + FROM message_structure ms + JOIN {self.messages_table} am ON ms.message_id = am.id + WHERE ms.session_id = ? AND ms.branch_id = ? + AND ms.message_type = 'user' + ORDER BY ms.branch_turn_number + """, + (self.session_id, branch_id), + ) - turns = [] - for row in cursor.fetchall(): - turn_num, message_data, created_at = row - try: - content = json.loads(message_data).get("content", "") - turns.append( - { - "turn": turn_num, - "content": content[:100] + "..." if len(content) > 100 else content, - "full_content": content, - "timestamp": created_at, - "can_branch": True, - } - ) - except (json.JSONDecodeError, AttributeError): - continue + turns = [] + for row in cursor.fetchall(): + turn_num, message_data, created_at = row + try: + content = json.loads(message_data).get("content", "") + turns.append( + { + "turn": turn_num, + "content": ( + content[:100] + "..." if len(content) > 100 else content + ), + "full_content": content, + "timestamp": created_at, + "can_branch": True, + } + ) + except (json.JSONDecodeError, AttributeError): + continue - return turns + return turns return await asyncio.to_thread(_get_turns_sync) @@ -976,42 +1002,42 @@ async def find_turns_by_content( def _search_sync(): """Synchronous helper to search turns by content.""" - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - f""" - SELECT - ms.branch_turn_number, - am.message_data, - ms.created_at - FROM message_structure ms - JOIN {self.messages_table} am ON ms.message_id = am.id - WHERE ms.session_id = ? AND ms.branch_id = ? - AND ms.message_type = 'user' - AND am.message_data LIKE ? - ORDER BY ms.branch_turn_number - """, - (self.session_id, branch_id, f"%{search_term}%"), - ) + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + f""" + SELECT + ms.branch_turn_number, + am.message_data, + ms.created_at + FROM message_structure ms + JOIN {self.messages_table} am ON ms.message_id = am.id + WHERE ms.session_id = ? AND ms.branch_id = ? + AND ms.message_type = 'user' + AND am.message_data LIKE ? + ORDER BY ms.branch_turn_number + """, + (self.session_id, branch_id, f"%{search_term}%"), + ) - matches = [] - for row in cursor.fetchall(): - turn_num, message_data, created_at = row - try: - content = json.loads(message_data).get("content", "") - matches.append( - { - "turn": turn_num, - "content": content, - "full_content": content, - "timestamp": created_at, - "can_branch": True, - } - ) - except (json.JSONDecodeError, AttributeError): - continue + matches = [] + for row in cursor.fetchall(): + turn_num, message_data, created_at = row + try: + content = json.loads(message_data).get("content", "") + matches.append( + { + "turn": turn_num, + "content": content, + "full_content": content, + "timestamp": created_at, + "can_branch": True, + } + ) + except (json.JSONDecodeError, AttributeError): + continue - return matches + return matches return await asyncio.to_thread(_search_sync) @@ -1031,25 +1057,25 @@ async def get_conversation_by_turns( def _get_conversation_sync(): """Synchronous helper to get conversation by turns.""" - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT user_turn_number, message_type, tool_name - FROM message_structure - WHERE session_id = ? AND branch_id = ? - ORDER BY sequence_number - """, - (self.session_id, branch_id), - ) + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT user_turn_number, message_type, tool_name + FROM message_structure + WHERE session_id = ? AND branch_id = ? + ORDER BY sequence_number + """, + (self.session_id, branch_id), + ) - turns: dict[int, list[dict[str, str | None]]] = {} - for row in cursor.fetchall(): - turn_num, msg_type, tool_name = row - if turn_num not in turns: - turns[turn_num] = [] - turns[turn_num].append({"type": msg_type, "tool_name": tool_name}) - return turns + turns: dict[int, list[dict[str, str | None]]] = {} + for row in cursor.fetchall(): + turn_num, msg_type, tool_name = row + if turn_num not in turns: + turns[turn_num] = [] + turns[turn_num].append({"type": msg_type, "tool_name": tool_name}) + return turns return await asyncio.to_thread(_get_conversation_sync) @@ -1067,47 +1093,47 @@ async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, def _get_tool_usage_sync(): """Synchronous helper to get tool usage statistics.""" - conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT tool_name, SUM(usage_count), user_turn_number - FROM ( - SELECT tool_name, 1 AS usage_count, user_turn_number - FROM message_structure - WHERE session_id = ? AND branch_id = ? AND message_type IN ( - 'tool_call', 'function_call', 'computer_call', 'file_search_call', - 'web_search_call', 'code_interpreter_call', 'tool_search_call', - 'custom_tool_call', 'mcp_call', 'mcp_approval_request' - ) - - UNION ALL + with self._locked_connection() as conn: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT tool_name, SUM(usage_count), user_turn_number + FROM ( + SELECT tool_name, 1 AS usage_count, user_turn_number + FROM message_structure + WHERE session_id = ? AND branch_id = ? AND message_type IN ( + 'tool_call', 'function_call', 'computer_call', 'file_search_call', + 'web_search_call', 'code_interpreter_call', 'tool_search_call', + 'custom_tool_call', 'mcp_call', 'mcp_approval_request' + ) - SELECT ms.tool_name, 1 AS usage_count, ms.user_turn_number - FROM message_structure ms - WHERE ms.session_id = ? AND ms.branch_id = ? - AND ms.message_type = 'tool_search_output' - AND NOT EXISTS ( - SELECT 1 - FROM message_structure calls - WHERE calls.session_id = ms.session_id - AND calls.branch_id = ms.branch_id - AND calls.user_turn_number = ms.user_turn_number - AND calls.tool_name = ms.tool_name - AND calls.message_type = 'tool_search_call' - ) + UNION ALL + + SELECT ms.tool_name, 1 AS usage_count, ms.user_turn_number + FROM message_structure ms + WHERE ms.session_id = ? AND ms.branch_id = ? + AND ms.message_type = 'tool_search_output' + AND NOT EXISTS ( + SELECT 1 + FROM message_structure calls + WHERE calls.session_id = ms.session_id + AND calls.branch_id = ms.branch_id + AND calls.user_turn_number = ms.user_turn_number + AND calls.tool_name = ms.tool_name + AND calls.message_type = 'tool_search_call' + ) + ) + GROUP BY tool_name, user_turn_number + ORDER BY user_turn_number + """, + ( + self.session_id, + branch_id, + self.session_id, + branch_id, + ), ) - GROUP BY tool_name, user_turn_number - ORDER BY user_turn_number - """, - ( - self.session_id, - branch_id, - self.session_id, - branch_id, - ), - ) - return cursor.fetchall() + return cursor.fetchall() return await asyncio.to_thread(_get_tool_usage_sync) @@ -1123,9 +1149,7 @@ async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int def _get_usage_sync(): """Synchronous helper to get session usage data.""" - conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: if branch_id: # Branch-specific usage query = """ @@ -1191,47 +1215,46 @@ async def get_turn_usage( def _get_turn_usage_sync(): """Synchronous helper to get turn usage statistics.""" - conn = self._get_connection() - - if user_turn_number is not None: - query = """ - SELECT requests, input_tokens, output_tokens, total_tokens, - input_tokens_details, output_tokens_details - FROM turn_usage - WHERE session_id = ? AND branch_id = ? AND user_turn_number = ? - """ - - with closing(conn.cursor()) as cursor: - cursor.execute(query, (self.session_id, branch_id, user_turn_number)) - row = cursor.fetchone() - - if row: - # Parse JSON details if present - input_details = None - output_details = None - - if row[4]: # input_tokens_details - try: - input_details = json.loads(row[4]) - except json.JSONDecodeError: - pass + with self._locked_connection() as conn: + if user_turn_number is not None: + query = """ + SELECT requests, input_tokens, output_tokens, total_tokens, + input_tokens_details, output_tokens_details + FROM turn_usage + WHERE session_id = ? AND branch_id = ? AND user_turn_number = ? + """ - if row[5]: # output_tokens_details - try: - output_details = json.loads(row[5]) - except json.JSONDecodeError: - pass + with closing(conn.cursor()) as cursor: + cursor.execute(query, (self.session_id, branch_id, user_turn_number)) + row = cursor.fetchone() + + if row: + # Parse JSON details if present + input_details = None + output_details = None + + if row[4]: # input_tokens_details + try: + input_details = json.loads(row[4]) + except json.JSONDecodeError: + pass + + if row[5]: # output_tokens_details + try: + output_details = json.loads(row[5]) + except json.JSONDecodeError: + pass + + return { + "requests": row[0], + "input_tokens": row[1], + "output_tokens": row[2], + "total_tokens": row[3], + "input_tokens_details": input_details, + "output_tokens_details": output_details, + } + return {} - return { - "requests": row[0], - "input_tokens": row[1], - "output_tokens": row[2], - "total_tokens": row[3], - "input_tokens_details": input_details, - "output_tokens_details": output_details, - } - return {} - else: query = """ SELECT user_turn_number, requests, input_tokens, output_tokens, total_tokens, input_tokens_details, output_tokens_details @@ -1287,9 +1310,7 @@ async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: U def _update_sync(): """Synchronous helper to update turn usage data.""" - conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: # Serialize token details as JSON input_details_json = None output_details_json = None diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 92c9630c9b..d0ca2557a2 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -4,7 +4,10 @@ import json import sqlite3 import threading +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path +from typing import ClassVar from ..items import TResponseInputItem from .session import SessionABC @@ -20,6 +23,9 @@ class SQLiteSession(SessionABC): """ session_settings: SessionSettings | None = None + _file_locks: ClassVar[dict[Path, threading.RLock]] = {} + _file_lock_counts: ClassVar[dict[Path, int]] = {} + _file_locks_guard: ClassVar[threading.Lock] = threading.Lock() def __init__( self, @@ -46,21 +52,66 @@ def __init__( self.sessions_table = sessions_table self.messages_table = messages_table self._local = threading.local() - self._lock = threading.Lock() # For in-memory databases, we need a shared connection to avoid thread isolation # For file databases, we use thread-local connections for better concurrency self._is_memory_db = str(db_path) == ":memory:" + self._lock_path: Path | None = None + self._lock_released = False if self._is_memory_db: - self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False) - self._shared_connection.execute("PRAGMA journal_mode=WAL") - self._init_db_for_connection(self._shared_connection) + self._lock = threading.RLock() else: - # For file databases, initialize the schema once since it persists - init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - init_conn.execute("PRAGMA journal_mode=WAL") - self._init_db_for_connection(init_conn) - init_conn.close() + self._lock_path, self._lock = self._acquire_file_lock(Path(self.db_path)) + + try: + if self._is_memory_db: + self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False) + self._shared_connection.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(self._shared_connection) + else: + # For file databases, initialize the schema once since it persists + with self._lock: + init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + init_conn.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(init_conn) + init_conn.close() + except Exception: + if self._lock_path is not None and not self._lock_released: + self._release_file_lock(self._lock_path) + self._lock_released = True + raise + + @classmethod + def _acquire_file_lock(cls, db_path: Path) -> tuple[Path, threading.RLock]: + """Return the path key and process-local lock for sessions sharing one SQLite file.""" + lock_path = db_path.expanduser().resolve() + with cls._file_locks_guard: + lock = cls._file_locks.get(lock_path) + if lock is None: + lock = threading.RLock() + cls._file_locks[lock_path] = lock + cls._file_lock_counts[lock_path] = 0 + cls._file_lock_counts[lock_path] += 1 + return lock_path, lock + + @classmethod + def _release_file_lock(cls, lock_path: Path) -> None: + """Drop the shared lock for a file-backed DB once the last session closes.""" + with cls._file_locks_guard: + ref_count = cls._file_lock_counts.get(lock_path) + if ref_count is None: + return + if ref_count <= 1: + cls._file_lock_counts.pop(lock_path, None) + cls._file_locks.pop(lock_path, None) + else: + cls._file_lock_counts[lock_path] = ref_count - 1 + + @contextmanager + def _locked_connection(self) -> Iterator[sqlite3.Connection]: + """Serialize sqlite3 access while each operation runs in a worker thread.""" + with self._lock: + yield self._get_connection() def _get_connection(self) -> sqlite3.Connection: """Get a database connection.""" @@ -114,6 +165,31 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: conn.commit() + def _insert_items(self, conn: sqlite3.Connection, items: list[TResponseInputItem]) -> None: + conn.execute( + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + """, + (self.session_id,), + ) + + message_data = [(self.session_id, json.dumps(item)) for item in items] + conn.executemany( + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) + """, + message_data, + ) + + conn.execute( + f""" + UPDATE {self.sessions_table} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = ? + """, + (self.session_id,), + ) + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. @@ -127,8 +203,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: session_limit = resolve_session_limit(limit, self.session_settings) def _get_items_sync(): - conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: if session_limit is None: # Fetch all items in chronological order cursor = conn.execute( @@ -180,36 +255,8 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: return def _add_items_sync(): - conn = self._get_connection() - - with self._lock if self._is_memory_db else threading.Lock(): - # Ensure session exists - conn.execute( - f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) - """, - (self.session_id,), - ) - - # Add items - message_data = [(self.session_id, json.dumps(item)) for item in items] - conn.executemany( - f""" - INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) - """, - message_data, - ) - - # Update session timestamp - conn.execute( - f""" - UPDATE {self.sessions_table} - SET updated_at = CURRENT_TIMESTAMP - WHERE session_id = ? - """, - (self.session_id,), - ) - + with self._locked_connection() as conn: + self._insert_items(conn, items) conn.commit() await asyncio.to_thread(_add_items_sync) @@ -222,8 +269,7 @@ async def pop_item(self) -> TResponseInputItem | None: """ def _pop_item_sync(): - conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: # Use DELETE with RETURNING to atomically delete and return the most recent item cursor = conn.execute( f""" @@ -259,8 +305,7 @@ async def clear_session(self) -> None: """Clear all items for this session.""" def _clear_session_sync(): - conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._locked_connection() as conn: conn.execute( f"DELETE FROM {self.messages_table} WHERE session_id = ?", (self.session_id,), @@ -281,3 +326,6 @@ def close(self) -> None: else: if hasattr(self._local, "connection"): self._local.connection.close() + if self._lock_path is not None and not self._lock_released: + self._release_file_lock(self._lock_path) + self._lock_released = True diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index 7be57e6b00..b61c5235f0 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -1,5 +1,9 @@ """Tests for AdvancedSQLiteSession functionality.""" +import asyncio +import json +import tempfile +from pathlib import Path from typing import Any, Optional, cast import pytest @@ -1343,3 +1347,52 @@ async def test_runner_with_session_settings_override(agent: Agent): assert len(history_items) == 2 session.close() + + +async def test_concurrent_add_items_preserves_message_structure_for_file_db(): + """Concurrent add_items calls should keep agent_messages and message_structure aligned.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "advanced_concurrent.db" + session = AdvancedSQLiteSession( + session_id="advanced_concurrent", + db_path=db_path, + create_tables=True, + ) + + async def add_batch(worker_id: int) -> list[str]: + contents = [f"worker-{worker_id}-message-{index}" for index in range(10)] + await session.add_items([{"role": "user", "content": content} for content in contents]) + return contents + + expected_batches = await asyncio.gather(*(add_batch(worker_id) for worker_id in range(8))) + expected_contents = {content for batch in expected_batches for content in batch} + + retrieved_items = await session.get_items() + retrieved_contents = { + content + for item in retrieved_items + for content in [item.get("content")] + if isinstance(content, str) + } + + assert retrieved_contents == expected_contents + assert len(retrieved_items) == len(expected_contents) + + with session._locked_connection() as conn: + rows = conn.execute( + f""" + SELECT m.message_data + FROM {session.messages_table} m + JOIN message_structure s ON s.message_id = m.id + WHERE m.session_id = ? + ORDER BY s.sequence_number ASC + """, + (session.session_id,), + ).fetchall() + + structured_contents = {json.loads(message_data).get("content") for (message_data,) in rows} + + assert structured_contents == expected_contents + assert len(rows) == len(expected_contents) + + session.close() diff --git a/tests/test_session.py b/tests/test_session.py index aaa80ec7aa..8ede928812 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -537,6 +537,36 @@ def add_item(item): session.close() +@pytest.mark.asyncio +async def test_sqlite_session_file_lock_is_shared_across_instances(): + """File-backed sessions pointing at the same DB path should reuse one process-local lock.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_shared_lock.db" + lock_path = db_path.resolve() + + session_1 = SQLiteSession("session_1", db_path) + session_2 = SQLiteSession("session_2", db_path) + + assert session_1._lock is session_2._lock + assert SQLiteSession._file_lock_counts[lock_path] == 2 + + await asyncio.gather( + session_1.add_items([{"role": "user", "content": "session_1"}]), + session_2.add_items([{"role": "user", "content": "session_2"}]), + ) + + assert [item.get("content") for item in await session_1.get_items()] == ["session_1"] + assert [item.get("content") for item in await session_2.get_items()] == ["session_2"] + + session_1.close() + assert SQLiteSession._file_lock_counts[lock_path] == 1 + assert lock_path in SQLiteSession._file_locks + + session_2.close() + assert lock_path not in SQLiteSession._file_lock_counts + assert lock_path not in SQLiteSession._file_locks + + @pytest.mark.asyncio async def test_session_add_items_exception_propagates_in_streamed(): """Test that exceptions from session.add_items are properly propagated