Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,15 +785,17 @@ def _delete_sync():
)

structure_deleted = cursor.rowcount
messages_deleted = self._cleanup_orphaned_messages_sync(conn)

conn.commit()

return usage_deleted, structure_deleted
return usage_deleted, structure_deleted, messages_deleted

usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)
usage_deleted, structure_deleted, messages_deleted = await asyncio.to_thread(_delete_sync)

self._logger.info(
f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries" # noqa: E501
f"Deleted branch '{branch_id}': {structure_deleted} message entries, "
f"{usage_deleted} usage entries, {messages_deleted} orphaned messages"
)

async def list_branches(self) -> list[dict[str, Any]]:
Expand Down
49 changes: 49 additions & 0 deletions tests/extensions/memory/test_advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,55 @@ async def test_branch_deletion_with_force():
session.close()


async def test_delete_branch_removes_branch_only_messages():
session_id = "delete_branch_orphans_test"
session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)

await session.add_items(
[
{"role": "user", "content": "Main question"},
{"role": "assistant", "content": "Main answer"},
]
)
await session.create_branch_from_turn(1, "branch_only")
await session.add_items(
[
{"role": "user", "content": "Branch-only question"},
{"role": "assistant", "content": "Branch-only answer"},
]
)

await session.delete_branch("branch_only", force=True)

with session._locked_connection() as conn:
message_rows = conn.execute(
f"""
SELECT message_data
FROM {session.messages_table}
WHERE session_id = ?
ORDER BY id
""",
(session.session_id,),
).fetchall()
structure_rows = conn.execute(
"""
SELECT branch_id, message_id
FROM message_structure
WHERE session_id = ?
ORDER BY message_id
""",
(session.session_id,),
).fetchall()

assert [json.loads(row[0])["content"] for row in message_rows] == [
"Main question",
"Main answer",
]
assert {row[0] for row in structure_rows} == {"main"}

session.close()


async def test_get_items_with_parameters():
"""Test get_items with new parameters (include_inactive, branch_id)."""
session_id = "get_items_params_test"
Expand Down