From cac8b21903512b78b2296adc1b89ca3bcaf9ca76 Mon Sep 17 00:00:00 2001 From: "Xql.dev" <205849173+KirobotDev@users.noreply.github.com> Date: Fri, 3 Apr 2026 03:28:43 +0200 Subject: [PATCH] Fix: shared threading.Lock for file-backed SQLite sessions - Add shared _file_db_lock in base SQLiteSession class - Replace fresh threading.Lock() calls with shared instance lock - Ensure AdvancedSQLiteSession inherits the same lock - Add regression tests for concurrent file DB access --- .../memory/advanced_sqlite_session.py | 33 ++--- src/agents/memory/sqlite_session.py | 10 +- .../memory/test_advanced_sqlite_session.py | 113 ++++++++++++++++++ 3 files changed, 136 insertions(+), 20 deletions(-) diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index fcb4743cb3..4b1e657d37 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -49,6 +49,7 @@ def __init__( self._init_structure_tables() self._current_branch_id = "main" self._logger = logger or logging.getLogger(__name__) + # Note: _file_db_lock is inherited from SQLiteSession base class for thread safety def _init_structure_tables(self): """Add structure and usage tracking tables. 
@@ -158,8 +159,8 @@ async def get_items( def _get_all_items_sync(): """Synchronous helper to get all items for a branch.""" conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + # Use the shared instance lock inherited from base class for disk-based DBs + with self._lock if self._is_memory_db else self._file_db_lock: with closing(conn.cursor()) as cursor: if session_limit is None: cursor.execute( @@ -203,8 +204,8 @@ def _get_all_items_sync(): def _get_items_sync(): """Synchronous helper to get items for a specific branch.""" conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + # Use the instance lock for disk-based DBs to ensure consistent locking + with self._lock if self._is_memory_db else self._file_db_lock: with closing(conn.cursor()) as cursor: # Get message IDs in correct order for this branch if session_limit is None: @@ -345,8 +346,8 @@ async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None def _add_structure_sync(): """Synchronous helper to add structure metadata to database.""" conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + # Use the instance lock for disk-based DBs to ensure consistent locking + with self._lock if self._is_memory_db else self._file_db_lock: # Get the IDs of messages we just inserted, in order with closing(conn.cursor()) as cursor: cursor.execute( @@ -451,8 +452,8 @@ async def _cleanup_orphaned_messages(self) -> int: def _cleanup_sync(): """Synchronous helper to cleanup orphaned messages.""" conn = self._get_connection() - # TODO: Refactor 
SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + # Use the instance lock for disk-based DBs to ensure consistent locking + with self._lock if self._is_memory_db else self._file_db_lock: with closing(conn.cursor()) as cursor: # Find messages without structure metadata cursor.execute( @@ -722,8 +723,8 @@ async def delete_branch(self, branch_id: str, force: bool = False) -> None: def _delete_sync(): """Synchronous helper to delete branch and associated data.""" conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + # Use the instance lock for disk-based DBs to ensure consistent locking + with self._lock if self._is_memory_db else self._file_db_lock: with closing(conn.cursor()) as cursor: # First verify the branch exists cursor.execute( @@ -829,8 +830,8 @@ async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_numbe def _copy_sync(): """Synchronous helper to copy messages to new branch.""" conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + # Use the instance lock for disk-based DBs to ensure consistent locking + with self._lock if self._is_memory_db else self._file_db_lock: with closing(conn.cursor()) as cursor: # Get all messages before the branch point cursor.execute( @@ -1124,8 +1125,8 @@ async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int def _get_usage_sync(): """Synchronous helper to get session usage data.""" conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else 
threading.Lock(): + # Use the instance lock for disk-based DBs to ensure consistent locking + with self._lock if self._is_memory_db else self._file_db_lock: if branch_id: # Branch-specific usage query = """ @@ -1288,8 +1289,8 @@ async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: U def _update_sync(): """Synchronous helper to update turn usage data.""" conn = self._get_connection() - # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + # Use the instance lock for disk-based DBs to ensure consistent locking + with self._lock if self._is_memory_db else self._file_db_lock: # Serialize token details as JSON input_details_json = None output_details_json = None diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 92c9630c9b..2d2cf291dc 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -47,6 +47,8 @@ def __init__( self.messages_table = messages_table self._local = threading.local() self._lock = threading.Lock() + # Shared lock for disk-based databases to ensure thread safety across base and derived classes + self._file_db_lock = threading.Lock() # For in-memory databases, we need a shared connection to avoid thread isolation # For file databases, we use thread-local connections for better concurrency @@ -128,7 +130,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: def _get_items_sync(): conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock if self._is_memory_db else self._file_db_lock: if session_limit is None: # Fetch all items in chronological order cursor = conn.execute( @@ -182,7 +184,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: def _add_items_sync(): conn = self._get_connection() - with self._lock if self._is_memory_db 
else threading.Lock(): + with self._lock if self._is_memory_db else self._file_db_lock: # Ensure session exists conn.execute( f""" @@ -223,7 +225,7 @@ async def pop_item(self) -> TResponseInputItem | None: def _pop_item_sync(): conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock if self._is_memory_db else self._file_db_lock: # Use DELETE with RETURNING to atomically delete and return the most recent item cursor = conn.execute( f""" @@ -260,7 +262,7 @@ async def clear_session(self) -> None: def _clear_session_sync(): conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock if self._is_memory_db else self._file_db_lock: conn.execute( f"DELETE FROM {self.messages_table} WHERE session_id = ?", (self.session_id,), diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index 7be57e6b00..bd6079f6e5 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -1343,3 +1343,116 @@ async def test_runner_with_session_settings_override(agent: Agent): assert len(history_items) == 2 session.close() + + +async def test_concurrent_file_db_access_regression(): + """Regression test: concurrent access to file-backed DB should be thread-safe. + + This test verifies that the shared _file_db_lock prevents race conditions + when multiple coroutines access the same file-backed AdvancedSQLiteSession. + Previously, threading.Lock() was created fresh in each method call, + making concurrent access unsafe. 
+ """ + import tempfile + import asyncio + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + + try: + session = AdvancedSQLiteSession( + session_id="concurrent_test", + db_path=db_path, + create_tables=True, + ) + + async def add_items_task(task_id: int, count: int) -> int: + """Task that adds items and returns the count added.""" + items_added = 0 + for i in range(count): + item: TResponseInputItem = { + "role": "user", + "content": f"Task {task_id} message {i}", + } + await session.add_items([item]) + items_added += 1 + return items_added + + tasks = [ + add_items_task(1, 10), + add_items_task(2, 10), + add_items_task(3, 10), + ] + + results = await asyncio.gather(*tasks) + + assert sum(results) == 30, f"Expected 30 items added, got {sum(results)}" + + retrieved = await session.get_items() + assert len(retrieved) == 30, f"Expected 30 items retrieved, got {len(retrieved)}" + + contents = [item.get("content", "") for item in retrieved] + for task_id in range(1, 4): + for i in range(10): + expected = f"Task {task_id} message {i}" + assert expected in contents, f"Missing item: {expected}" + + session.close() + finally: + # Clean up temp file + import os + try: + os.unlink(db_path) + except FileNotFoundError: + pass + + +async def test_concurrent_mixed_operations_file_db(): + """Test concurrent mixed operations (add, get) on file-backed DB.""" + import tempfile + import asyncio + + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + + try: + session = AdvancedSQLiteSession( + session_id="concurrent_mixed_test", + db_path=db_path, + create_tables=True, + ) + + async def mixed_operations(task_id: int) -> tuple[int, int]: + """Perform mixed operations and return (added, retrieved).""" + added = 0 + for i in range(5): + await session.add_items([{"role": "user", "content": f"T{task_id}-{i}"}]) + added += 1 + + items = await session.get_items() + retrieved = len(items) + + await
session.add_items([{"role": "user", "content": f"T{task_id}-extra"}]) + added += 1 + + return added, retrieved + + tasks = [mixed_operations(i) for i in range(4)] + results = await asyncio.gather(*tasks) + + final_items = await session.get_items() + assert len(final_items) == 24, f"Expected 24 items, got {len(final_items)}" # 4 tasks * (5 + 1) + + all_contents = [item.get("content", "") for item in final_items] + for task_id in range(4): + for i in range(5): + assert f"T{task_id}-{i}" in all_contents + assert f"T{task_id}-extra" in all_contents + + session.close() + finally: + import os + try: + os.unlink(db_path) + except FileNotFoundError: + pass