Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
self._init_structure_tables()
self._current_branch_id = "main"
self._logger = logger or logging.getLogger(__name__)
# Note: _file_db_lock is inherited from SQLiteSession base class for thread safety

def _init_structure_tables(self):
"""Add structure and usage tracking tables.
Expand Down Expand Up @@ -158,8 +159,8 @@ async def get_items(
def _get_all_items_sync():
"""Synchronous helper to get all items for a branch."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
# Use the shared instance lock inherited from base class for disk-based DBs
with self._lock if self._is_memory_db else self._file_db_lock:
with closing(conn.cursor()) as cursor:
if session_limit is None:
cursor.execute(
Expand Down Expand Up @@ -203,8 +204,8 @@ def _get_all_items_sync():
def _get_items_sync():
"""Synchronous helper to get items for a specific branch."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
# Use the instance lock for disk-based DBs to ensure consistent locking
with self._lock if self._is_memory_db else self._file_db_lock:
with closing(conn.cursor()) as cursor:
# Get message IDs in correct order for this branch
if session_limit is None:
Expand Down Expand Up @@ -345,8 +346,8 @@ async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None
def _add_structure_sync():
"""Synchronous helper to add structure metadata to database."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
# Use the instance lock for disk-based DBs to ensure consistent locking
with self._lock if self._is_memory_db else self._file_db_lock:
# Get the IDs of messages we just inserted, in order
with closing(conn.cursor()) as cursor:
cursor.execute(
Expand Down Expand Up @@ -451,8 +452,8 @@ async def _cleanup_orphaned_messages(self) -> int:
def _cleanup_sync():
"""Synchronous helper to cleanup orphaned messages."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
# Use the instance lock for disk-based DBs to ensure consistent locking
with self._lock if self._is_memory_db else self._file_db_lock:
with closing(conn.cursor()) as cursor:
# Find messages without structure metadata
cursor.execute(
Expand Down Expand Up @@ -722,8 +723,8 @@ async def delete_branch(self, branch_id: str, force: bool = False) -> None:
def _delete_sync():
"""Synchronous helper to delete branch and associated data."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
# Use the instance lock for disk-based DBs to ensure consistent locking
with self._lock if self._is_memory_db else self._file_db_lock:
with closing(conn.cursor()) as cursor:
# First verify the branch exists
cursor.execute(
Expand Down Expand Up @@ -829,8 +830,8 @@ async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_numbe
def _copy_sync():
"""Synchronous helper to copy messages to new branch."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
# Use the instance lock for disk-based DBs to ensure consistent locking
with self._lock if self._is_memory_db else self._file_db_lock:
with closing(conn.cursor()) as cursor:
# Get all messages before the branch point
cursor.execute(
Expand Down Expand Up @@ -1124,8 +1125,8 @@ async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int
def _get_usage_sync():
"""Synchronous helper to get session usage data."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
# Use the instance lock for disk-based DBs to ensure consistent locking
with self._lock if self._is_memory_db else self._file_db_lock:
if branch_id:
# Branch-specific usage
query = """
Expand Down Expand Up @@ -1288,8 +1289,8 @@ async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: U
def _update_sync():
"""Synchronous helper to update turn usage data."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
# Use the instance lock for disk-based DBs to ensure consistent locking
with self._lock if self._is_memory_db else self._file_db_lock:
# Serialize token details as JSON
input_details_json = None
output_details_json = None
Expand Down
10 changes: 6 additions & 4 deletions src/agents/memory/sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def __init__(
self.messages_table = messages_table
self._local = threading.local()
self._lock = threading.Lock()
# Shared lock for disk-based databases to ensure thread safety across base and derived classes
self._file_db_lock = threading.Lock()

# For in-memory databases, we need a shared connection to avoid thread isolation
# For file databases, we use thread-local connections for better concurrency
Expand Down Expand Up @@ -128,7 +130,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

def _get_items_sync():
conn = self._get_connection()
with self._lock if self._is_memory_db else threading.Lock():
with self._lock if self._is_memory_db else self._file_db_lock:
if session_limit is None:
# Fetch all items in chronological order
cursor = conn.execute(
Expand Down Expand Up @@ -182,7 +184,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
def _add_items_sync():
conn = self._get_connection()

with self._lock if self._is_memory_db else threading.Lock():
with self._lock if self._is_memory_db else self._file_db_lock:
# Ensure session exists
conn.execute(
f"""
Expand Down Expand Up @@ -223,7 +225,7 @@ async def pop_item(self) -> TResponseInputItem | None:

def _pop_item_sync():
conn = self._get_connection()
with self._lock if self._is_memory_db else threading.Lock():
with self._lock if self._is_memory_db else self._file_db_lock:
# Use DELETE with RETURNING to atomically delete and return the most recent item
cursor = conn.execute(
f"""
Expand Down Expand Up @@ -260,7 +262,7 @@ async def clear_session(self) -> None:

def _clear_session_sync():
conn = self._get_connection()
with self._lock if self._is_memory_db else threading.Lock():
with self._lock if self._is_memory_db else self._file_db_lock:
conn.execute(
f"DELETE FROM {self.messages_table} WHERE session_id = ?",
(self.session_id,),
Expand Down
113 changes: 113 additions & 0 deletions tests/extensions/memory/test_advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,3 +1343,116 @@ async def test_runner_with_session_settings_override(agent: Agent):
assert len(history_items) == 2

session.close()


async def test_concurrent_file_db_access_regression():
"""Regression test: concurrent access to file-backed DB should be thread-safe.

This test verifies that the shared _file_db_lock prevents race conditions
when multiple coroutines access the same file-backed AdvancedSQLiteSession.
Previously, threading.Lock() was created fresh in each method call,
making concurrent access unsafe.
"""
import tempfile
import asyncio

with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
db_path = tmp.name

try:
session = AdvancedSQLiteSession(
session_id="concurrent_test",
db_path=db_path,
create_tables=True,
)

async def add_items_task(task_id: int, count: int) -> int:
"""Task that adds items and returns the count added."""
items_added = 0
for i in range(count):
item: TResponseInputItem = {
"role": "user",
"content": f"Task {task_id} message {i}",
}
await session.add_items([item])
items_added += 1
return items_added

tasks = [
add_items_task(1, 10),
add_items_task(2, 10),
add_items_task(3, 10),
]

results = await asyncio.gather(*tasks)

assert sum(results) == 30, f"Expected 30 items added, got {sum(results)}"

retrieved = await session.get_items()
assert len(retrieved) == 30, f"Expected 30 items retrieved, got {len(retrieved)}"

contents = [item.get("content", "") for item in retrieved]
for task_id in range(1, 4):
for i in range(10):
expected = f"Task {task_id} message {i}"
assert expected in contents, f"Missing item: {expected}"

session.close()
finally:
# Clean up temp file
import os
try:
os.unlink(db_path)
except FileNotFoundError:
pass


async def test_concurrent_mixed_operations_file_db():
"""Test concurrent mixed operations (add, get, branch) on file-backed DB."""
import tempfile
import asyncio

with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
db_path = tmp.name

try:
session = AdvancedSQLiteSession(
session_id="concurrent_mixed_test",
db_path=db_path,
create_tables=True,
)

async def mixed_operations(task_id: int) -> tuple[int, int]:
"""Perform mixed operations and return (added, retrieved)."""
added = 0
for i in range(5):
await session.add_items([{"role": "user", "content": f"T{task_id}-{i}"}])
added += 1

items = await session.get_items()
retrieved = len(items)

await session.add_items([{"role": "user", "content": f"T{task_id}-extra"}])
added += 1

return added, retrieved

tasks = [mixed_operations(i) for i in range(4)]
results = await asyncio.gather(*tasks)

final_items = await session.get_items()
assert len(final_items) == 24, f"Expected 24 items, got {len(final_items)}" # 4 tasks * (5 + 1)

all_contents = [item.get("content", "") for item in final_items]
for task_id in range(4):
for i in range(5):
assert f"T{task_id}-{i}" in all_contents
assert f"T{task_id}-extra" in all_contents

session.close()
finally:
import os
try:
os.unlink(db_path)
except FileNotFoundError:
pass