diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py
index 83c289bdf8..d2409e4fcd 100644
--- a/src/agents/extensions/memory/advanced_sqlite_session.py
+++ b/src/agents/extensions/memory/advanced_sqlite_session.py
@@ -487,30 +487,26 @@ def _cleanup_sync():
 
     def _cleanup_orphaned_messages_sync(self, conn: sqlite3.Connection) -> int:
         with closing(conn.cursor()) as cursor:
-            # Find messages without structure metadata.
+            # Delete branch-orphaned messages in a single set-based statement so
+            # we never bind one parameter per orphan id. The previous IN (?, ?, ...)
+            # form raised "too many SQL variables" once the cleanup batch grew past
+            # SQLITE_MAX_VARIABLE_NUMBER (999 on older SQLite builds, 32766 on newer).
             cursor.execute(
                 f"""
-                SELECT am.id
-                FROM {self.messages_table} am
-                LEFT JOIN message_structure ms ON am.id = ms.message_id
-                WHERE am.session_id = ? AND ms.message_id IS NULL
+                DELETE FROM {self.messages_table}
+                WHERE id IN (
+                    SELECT am.id
+                    FROM {self.messages_table} am
+                    LEFT JOIN message_structure ms ON am.id = ms.message_id
+                    WHERE am.session_id = ? AND ms.message_id IS NULL
+                )
             """,
                 (self.session_id,),
             )
 
-            orphaned_ids = [row[0] for row in cursor.fetchall()]
-
-            if not orphaned_ids:
-                return 0
-
-            placeholders = ",".join("?" * len(orphaned_ids))
-            cursor.execute(
-                f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})",
-                orphaned_ids,
-            )
-
             deleted_count = cursor.rowcount
-            self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
+            if deleted_count > 0:
+                self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
             return deleted_count
 
     def _classify_message_type(self, item: TResponseInputItem) -> str:
@@ -786,14 +782,21 @@ def _delete_sync():
 
                     structure_deleted = cursor.rowcount
 
+                    # Drop any base messages that this branch was the only
+                    # reference for. Without this step, branch-only rows become
+                    # invisible to advanced reads (which join through
+                    # `message_structure`) but linger in the base table.
+                    orphans_deleted = self._cleanup_orphaned_messages_sync(conn)
+
                     conn.commit()
 
-                    return usage_deleted, structure_deleted
+                    return usage_deleted, structure_deleted, orphans_deleted
 
-        usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)
+        usage_deleted, structure_deleted, orphans_deleted = await asyncio.to_thread(_delete_sync)
 
         self._logger.info(
-            f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries"  # noqa: E501
+            f"Deleted branch '{branch_id}': {structure_deleted} message entries, "
+            f"{usage_deleted} usage entries, {orphans_deleted} orphaned base rows"
         )
 
     async def list_branches(self) -> list[dict[str, Any]]:
diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py
index ad4b5c4d86..ff6c53820d 100644
--- a/tests/extensions/memory/test_advanced_sqlite_session.py
+++ b/tests/extensions/memory/test_advanced_sqlite_session.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import json
+import sqlite3
 import tempfile
 from pathlib import Path
 from typing import Any, cast
@@ -1422,3 +1423,105 @@ async def test_output_tokens_details_persisted_when_input_details_missing():
     assert turn_usage["output_tokens_details"] == {"reasoning_tokens": 42}
     assert turn_usage["input_tokens_details"] is None
     session.close()
+
+
+async def test_delete_branch_removes_branch_only_base_messages():
+    """Regression for #3346.
+
+    `delete_branch` used to remove only the `turn_usage` and `message_structure`
+    rows for a branch, leaving the underlying messages in the base table when
+    they were only referenced by that branch. Those rows became invisible to
+    advanced reads (which join through `message_structure`) but still padded
+    the on-disk database. After the fix, branch-only messages are dropped along
+    with their structure rows, while messages shared with another branch are
+    preserved.
+    """
+    session = AdvancedSQLiteSession(
+        session_id="delete_branch_orphan_repro",
+        create_tables=True,
+    )
+
+    try:
+        await session.add_items(
+            [
+                {"role": "user", "content": "main question"},
+                {"role": "assistant", "content": "main answer"},
+            ]
+        )
+
+        await session.create_branch_from_turn(1, "branch_only")
+        await session.add_items(
+            [
+                {"role": "user", "content": "branch-only question"},
+                {"role": "assistant", "content": "branch-only answer"},
+            ]
+        )
+
+        await session.delete_branch("branch_only", force=True)
+
+        with session._locked_connection() as conn:
+            message_rows = conn.execute(
+                f"SELECT id FROM {session.messages_table} WHERE session_id = ? ORDER BY id",
+                (session.session_id,),
+            ).fetchall()
+            structure_rows = conn.execute(
+                "SELECT branch_id FROM message_structure WHERE session_id = ?",
+                (session.session_id,),
+            ).fetchall()
+
+        # Main-branch messages remain. Rows shared between branches were copied
+        # at `create_branch_from_turn`, so the two new branch-only rows were the
+        # only references for ids 3 and 4 — both should be gone now.
+        assert [row[0] for row in message_rows] == [1, 2]
+        assert all(row[0] == "main" for row in structure_rows)
+        assert await session.get_items(branch_id="main") == [
+            {"role": "user", "content": "main question"},
+            {"role": "assistant", "content": "main answer"},
+        ]
+    finally:
+        session.close()
+
+
+@pytest.mark.skipif(
+    not hasattr(sqlite3.Connection, "setlimit"),
+    reason="sqlite3.Connection.setlimit requires Python 3.11+",
+)
+async def test_cleanup_orphaned_messages_exceeds_sqlite_variable_limit():
+    """Regression for codex review on #3380.
+
+    The previous orphan cleanup built `DELETE ... WHERE id IN (?, ?, ...)` with
+    one bound parameter per orphan id, so deleting a branch with more orphans
+    than `SQLITE_MAX_VARIABLE_NUMBER` raised `OperationalError: too many SQL
+    variables`. We drop the per-connection variable limit so the failure
+    reproduces with a small number of rows; the set-based DELETE binds only
+    `session_id`, so the orphan count no longer matters.
+    """
+    session = AdvancedSQLiteSession(
+        session_id="cleanup_under_low_limit",
+        create_tables=True,
+    )
+
+    try:
+        low_limit = 5
+
+        with session._locked_connection() as conn:
+            conn.setlimit(sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER, low_limit)
+
+            orphan_count = low_limit * 4
+            conn.executemany(
+                f"INSERT INTO {session.messages_table} (session_id, message_data) VALUES (?, ?)",
+                [(session.session_id, "{}") for _ in range(orphan_count)],
+            )
+            conn.commit()
+
+        deleted = await session._cleanup_orphaned_messages()
+        assert deleted == orphan_count
+
+        with session._locked_connection() as conn:
+            (remaining,) = conn.execute(
+                f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
+                (session.session_id,),
+            ).fetchone()
+        assert remaining == 0
+    finally:
+        session.close()