From 49085df86ea8456fa7686e73e9fb0efa6673ace0 Mon Sep 17 00:00:00 2001 From: Lucas Meireles Date: Mon, 24 Nov 2025 09:48:32 -0300 Subject: [PATCH 1/5] Fix empty response from root agent after agent as tool response and persist usage_metadata in session store --- .../adk/sessions/database_session_service.py | 26 +++++-- src/google/adk/tools/agent_tool.py | 76 ++++++++++++++++--- 2 files changed, 87 insertions(+), 15 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index a352918211..680d0ed5c6 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -46,6 +46,7 @@ from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship +from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime from sqlalchemy.types import PickleType @@ -335,11 +336,16 @@ def from_event(cls, session: Session, event: Event) -> StorageEvent: ) if event.custom_metadata: storage_event.custom_metadata = event.custom_metadata - if event.usage_metadata: - storage_event.usage_metadata = event.usage_metadata.model_dump( - exclude_none=True, mode="json" - ) - if event.citation_metadata: + + if hasattr(event, 'usage_metadata') and event.usage_metadata is not None: + try: + usage_meta = event.usage_metadata + if hasattr(usage_meta, 'model_dump'): + storage_event.usage_metadata = usage_meta.model_dump(exclude_none=False, mode="json") + except Exception as e: + logger.error(f"[StorageEvent.from_event] Erro ao salvar usage_metadata: {e}") + + if hasattr(event, 'citation_metadata') and event.citation_metadata: storage_event.citation_metadata = event.citation_metadata.model_dump( exclude_none=True, mode="json" ) @@ -727,7 +733,15 @@ async def append_event(self, session: Session, event: Event) -> Event: else: update_time = datetime.fromtimestamp(event.timestamp) storage_session.update_time = update_time - sql_session.add(StorageEvent.from_event(session, event)) + storage_event = StorageEvent.from_event(session, event) + + sql_session.add(storage_event) + + # Forçar SQLAlchemy a detectar mudanças em campos MutableDict/DynamicJSON + if storage_event.usage_metadata is not None: + flag_modified(storage_event, "usage_metadata") + if storage_event.citation_metadata is not None: + flag_modified(storage_event, "citation_metadata") await sql_session.commit() await sql_session.refresh(storage_session) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 46d8616619..5541a400f1 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -136,7 +136,10 @@ async def run_async( else: content = types.Content( role='user', - parts=[types.Part.from_text(text=args['request'])], + parts=[ + types.Part.from_text( + text=str(args) if isinstance(args, str) else __import__('json').dumps(args)) + ], ) invocation_context = tool_context._invocation_context parent_app_name = ( @@ -161,40 +164,95 @@ async def run_async( state_dict = { k: v for k, v in tool_context.state.to_dict().items() - if not k.startswith('_adk') # Filter out adk internal states + if not k.startswith('_adk') } - session = await runner.session_service.create_session( + sub_agent_session = await runner.session_service.create_session( app_name=child_app_name, user_id=tool_context._invocation_context.user_id, state=state_dict, ) - last_content = None + # Collect all text chunks from streaming response instead of just last content + chunks: list[str] = [] + sub_agent_events = [] + + def iter_text_parts(parts): + """Safely iterate over parts and extract text, skipping None values.""" + for p in parts or []: + if hasattr(p, "text") and p.text is not None: + yield p.text + async with Aclosing( runner.run_async( - user_id=session.user_id, session_id=session.id, new_message=content + user_id=sub_agent_session.user_id, session_id=sub_agent_session.id, new_message=content ) ) as agen: async for event in agen: - # Forward state delta to parent session. if event.actions.state_delta: tool_context.state.update(event.actions.state_delta) + # Collect text chunks from all events, not just the last one if event.content: - last_content = event.content + chunks.extend(iter_text_parts(event.content.parts)) + sub_agent_events.append(event) + + if sub_agent_events and hasattr(tool_context, '_invocation_context'): + from ..sessions import Session + from ..events.event import Event + + main_session = tool_context._invocation_context.session + if main_session and hasattr(tool_context._invocation_context, 'session_service'): + session_service = tool_context._invocation_context.session_service + parent_agent_name = tool_context._invocation_context.agent.name if hasattr(tool_context._invocation_context, 'agent') else "root_agent" + + for sub_event in sub_agent_events: + + try: + if hasattr(sub_event, 'branch') and sub_event.branch: + event_branch = sub_event.branch + else: + event_branch = f"{parent_agent_name}.{self.agent.name}" + + copied_event = Event( + id=sub_event.id, + invocation_id=sub_event.invocation_id, + author=sub_event.author, + branch=event_branch, + actions=sub_event.actions, + content=sub_event.content, + long_running_tool_ids=sub_event.long_running_tool_ids, + partial=sub_event.partial, + turn_complete=sub_event.turn_complete, + error_code=sub_event.error_code, + error_message=sub_event.error_message, + interrupted=sub_event.interrupted, + grounding_metadata=sub_event.grounding_metadata, + custom_metadata=sub_event.custom_metadata, + usage_metadata=sub_event.usage_metadata, + citation_metadata=getattr(sub_event, 'citation_metadata', None), + ) + + await session_service.append_event(main_session, copied_event) + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.warning(f"Erro ao copiar evento do sub-agente {self.agent.name} para sessão principal: {e}") # Clean up runner resources (especially MCP sessions) # to avoid "Attempted to exit cancel scope in a different task" errors await runner.close() - if not last_content: + # Merge all collected chunks into final text + merged_text = "".join(chunks) + + if not merged_text: return '' - merged_text = '\n'.join(p.text for p in last_content.parts if p.text) if isinstance(self.agent, LlmAgent) and self.agent.output_schema: tool_result = self.agent.output_schema.model_validate_json( merged_text ).model_dump(exclude_none=True) else: tool_result = merged_text + return tool_result @override From 0e983dd2aade9a36cec2abe4b517fc61b0e78d16 Mon Sep 17 00:00:00 2001 From: Lucas Meireles Date: Mon, 24 Nov 2025 10:05:25 -0300 Subject: [PATCH 2/5] Add tests for usage_metadata persistence and empty response handling --- .../tools/test_agent_tool_new_features.py | 217 ++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 tests/unittests/tools/test_agent_tool_new_features.py diff --git a/tests/unittests/tools/test_agent_tool_new_features.py b/tests/unittests/tools/test_agent_tool_new_features.py new file mode 100644 index 0000000000..35c66a4d18 --- /dev/null +++ b/tests/unittests/tools/test_agent_tool_new_features.py @@ -0,0 +1,217 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for new features added to AgentTool and DatabaseSessionService.""" + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import Agent +from google.adk.agents.run_config import RunConfig +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.agent_tool import AgentTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from google.genai.types import Part +from pytest import mark + +from .. import testing_utils + +function_call_no_schema = Part.from_function_call( + name='tool_agent', args={'request': 'test1'} +) + + +@mark.asyncio +async def test_agent_tool_handles_dict_args(): + """Test that AgentTool handles dictionary arguments correctly (non-request key).""" + + mock_model = testing_utils.MockModel.create( + responses=['response to dict arg'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + # Create invocation context + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', + user_id='test_user' + ) + + invocation_context = InvocationContext( + invocation_id='test_invocation', + agent=tool_agent, + session=session, + session_service=session_service, + artifact_service=InMemoryArtifactService(), + memory_service=InMemoryMemoryService(), + plugin_manager=PluginManager(plugins=[]), + run_config=RunConfig(), + ) + + tool_context = ToolContext(invocation_context=invocation_context) + agent_tool = AgentTool(agent=tool_agent) + + # Test with dict argument that doesn't have 'request' key + result = await agent_tool.run_async( + args={"custom_key": "custom_value"}, + tool_context=tool_context + ) + + assert result is not None + assert 'response to dict arg' in str(result) + + +@mark.asyncio +async def test_database_session_service_persists_usage_metadata(): + """Test that DatabaseSessionService correctly persists usage_metadata with flag_modified.""" + from google.adk.sessions.database_session_service import DatabaseSessionService + import tempfile + import os + + # Create temporary database + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + temp_db.close() + db_url = f"sqlite+aiosqlite:///{temp_db.name}" + + try: + service = DatabaseSessionService(db_url) + + # Create session + session = await service.create_session( + app_name="test_app", + user_id="user123" + ) + + # Create event with usage_metadata + event = Event( + id="evt1", + invocation_id="inv1", + author="model", + actions=EventActions(), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=100, + candidates_token_count=50, + total_token_count=150 + ) + ) + + # Persist event + await service.append_event(session, event) + + # Retrieve session and verify usage_metadata was persisted + retrieved_session = await service.get_session( + app_name="test_app", + user_id="user123", + session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) > 0 + + # Find the event with usage_metadata + found_usage_metadata = False + for evt in retrieved_session.events: + if evt.usage_metadata is not None: + assert evt.usage_metadata.total_token_count == 150 + assert evt.usage_metadata.prompt_token_count == 100 + assert evt.usage_metadata.candidates_token_count == 50 + found_usage_metadata = True + break + + assert found_usage_metadata, "usage_metadata was not persisted correctly" + + finally: + # Cleanup + if os.path.exists(temp_db.name): + os.unlink(temp_db.name) + + +@mark.asyncio +async def test_database_session_service_persists_citation_metadata(): + """Test that DatabaseSessionService correctly persists citation_metadata.""" + from google.adk.sessions.database_session_service import DatabaseSessionService + import tempfile + import os + + # Create temporary database + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + temp_db.close() + db_url = f"sqlite+aiosqlite:///{temp_db.name}" + + try: + service = DatabaseSessionService(db_url) + + # Create session + session = await service.create_session( + app_name="test_app", + user_id="user123" + ) + + # Create event with citation_metadata + event = Event( + id="evt1", + invocation_id="inv1", + author="model", + actions=EventActions(), + citation_metadata=types.CitationMetadata( + citations=[ + types.Citation( + start_index=0, + end_index=10, + uri="https://example.com", + title="Example Source" + ) + ] + ) + ) + + # Persist event + await service.append_event(session, event) + + # Retrieve session and verify citation_metadata was persisted + retrieved_session = await service.get_session( + app_name="test_app", + user_id="user123", + session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) > 0 + + # Find the event with citation_metadata + found_citation_metadata = False + for evt in retrieved_session.events: + if evt.citation_metadata is not None: + assert len(evt.citation_metadata.citations) == 1 + assert evt.citation_metadata.citations[0].uri == "https://example.com" + assert evt.citation_metadata.citations[0].title == "Example Source" + found_citation_metadata = True + break + + assert found_citation_metadata, "citation_metadata was not persisted correctly" + + finally: + # Cleanup + if os.path.exists(temp_db.name): + os.unlink(temp_db.name) + +# Made with Bob From 80e5b33b8b4d995e6322f3c44796ff5793cb8be8 Mon Sep 17 00:00:00 2001 From: Lucas Martins Meireles <72349802+marttinslucas@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:35:45 -0300 Subject: [PATCH 3/5] Update src/google/adk/sessions/database_session_service.py change log message language Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk/sessions/database_session_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 680d0ed5c6..e93d705592 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -343,7 +343,7 @@ def from_event(cls, session: Session, event: Event) -> StorageEvent: if hasattr(usage_meta, 'model_dump'): storage_event.usage_metadata = usage_meta.model_dump(exclude_none=False, mode="json") except Exception as e: - logger.error(f"[StorageEvent.from_event] Erro ao salvar usage_metadata: {e}") + logger.error(f"[StorageEvent.from_event] Error while saving usage_metadata: {e}") if hasattr(event, 'citation_metadata') and event.citation_metadata: storage_event.citation_metadata = event.citation_metadata.model_dump( From 0b63466611a5180dd1f4352220c5c084e4d79c75 Mon Sep 17 00:00:00 2001 From: Lucas Martins Meireles <72349802+marttinslucas@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:37:22 -0300 Subject: [PATCH 4/5] Update src/google/adk/tools/agent_tool.py improve import of json Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/google/adk/tools/agent_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 5541a400f1..520ac1209a 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -138,7 +138,7 @@ async def run_async( role='user', parts=[ types.Part.from_text( - text=str(args) if isinstance(args, str) else __import__('json').dumps(args)) + text=str(args) if isinstance(args, str) else json.dumps(args)) ], ) invocation_context = tool_context._invocation_context From 23d3d3e44c4b67c28369ce68879b6eb6e7901c48 Mon Sep 17 00:00:00 2001 From: Lucas Meireles Date: Mon, 24 Nov 2025 10:48:36 -0300 Subject: [PATCH 5/5] Apply review fixes: English logs, global imports, model_copy usage --- .../adk/sessions/database_session_service.py | 22 ++-- src/google/adk/tools/agent_tool.py | 73 ++++++------ .../tools/test_agent_tool_new_features.py | 111 +++++++++--------- 3 files changed, 102 insertions(+), 104 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index e93d705592..a056ed0b39 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -336,16 +336,20 @@ def from_event(cls, session: Session, event: Event) -> StorageEvent: ) if event.custom_metadata: storage_event.custom_metadata = event.custom_metadata - - if hasattr(event, 'usage_metadata') and event.usage_metadata is not None: + + if hasattr(event, "usage_metadata") and event.usage_metadata is not None: try: usage_meta = event.usage_metadata - if hasattr(usage_meta, 'model_dump'): - storage_event.usage_metadata = usage_meta.model_dump(exclude_none=False, mode="json") + if hasattr(usage_meta, "model_dump"): + storage_event.usage_metadata = usage_meta.model_dump( + exclude_none=False, mode="json" + ) except Exception as e: - logger.error(f"[StorageEvent.from_event] Error while saving usage_metadata: {e}") - - if hasattr(event, 'citation_metadata') and event.citation_metadata: + logger.error( + f"[StorageEvent.from_event] Error while saving usage_metadata: {e}" + ) + + if hasattr(event, "citation_metadata") and event.citation_metadata: storage_event.citation_metadata = event.citation_metadata.model_dump( exclude_none=True, mode="json" ) @@ -734,9 +738,9 @@ async def append_event(self, session: Session, event: Event) -> Event: update_time = datetime.fromtimestamp(event.timestamp) storage_session.update_time = update_time storage_event = StorageEvent.from_event(session, event) - + sql_session.add(storage_event) - + # Forçar SQLAlchemy a detectar mudanças em campos MutableDict/DynamicJSON if storage_event.usage_metadata is not None: flag_modified(storage_event, "usage_metadata") diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 520ac1209a..9b98b13764 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -14,6 +14,8 @@ from __future__ import annotations +import json +import logging from typing import Any from typing import TYPE_CHECKING @@ -23,7 +25,9 @@ from . import _automatic_function_calling_util from ..agents.common_configs import AgentRefConfig +from ..events.event import Event from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..sessions import Session from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService from .base_tool import BaseTool @@ -34,6 +38,8 @@ if TYPE_CHECKING: from ..agents.base_agent import BaseAgent +logger = logging.getLogger(__name__) + class AgentTool(BaseTool): """A tool that wraps an agent. @@ -138,7 +144,8 @@ async def run_async( role='user', parts=[ types.Part.from_text( - text=str(args) if isinstance(args, str) else json.dumps(args)) + text=str(args) if isinstance(args, str) else json.dumps(args) + ) ], ) invocation_context = tool_context._invocation_context @@ -175,16 +182,18 @@ async def run_async( # Collect all text chunks from streaming response instead of just last content chunks: list[str] = [] sub_agent_events = [] - + def iter_text_parts(parts): """Safely iterate over parts and extract text, skipping None values.""" for p in parts or []: - if hasattr(p, "text") and p.text is not None: + if hasattr(p, 'text') and p.text is not None: yield p.text - + async with Aclosing( runner.run_async( - user_id=sub_agent_session.user_id, session_id=sub_agent_session.id, new_message=content + user_id=sub_agent_session.user_id, + session_id=sub_agent_session.id, + new_message=content, ) ) as agen: async for event in agen: @@ -196,54 +205,42 @@ def iter_text_parts(parts): sub_agent_events.append(event) if sub_agent_events and hasattr(tool_context, '_invocation_context'): - from ..sessions import Session - from ..events.event import Event - main_session = tool_context._invocation_context.session - if main_session and hasattr(tool_context._invocation_context, 'session_service'): + if main_session and hasattr( + tool_context._invocation_context, 'session_service' + ): session_service = tool_context._invocation_context.session_service - parent_agent_name = tool_context._invocation_context.agent.name if hasattr(tool_context._invocation_context, 'agent') else "root_agent" - + parent_agent_name = ( + tool_context._invocation_context.agent.name + if hasattr(tool_context._invocation_context, 'agent') + else 'root_agent' + ) + for sub_event in sub_agent_events: try: if hasattr(sub_event, 'branch') and sub_event.branch: event_branch = sub_event.branch else: - event_branch = f"{parent_agent_name}.{self.agent.name}" - - copied_event = Event( - id=sub_event.id, - invocation_id=sub_event.invocation_id, - author=sub_event.author, - branch=event_branch, - actions=sub_event.actions, - content=sub_event.content, - long_running_tool_ids=sub_event.long_running_tool_ids, - partial=sub_event.partial, - turn_complete=sub_event.turn_complete, - error_code=sub_event.error_code, - error_message=sub_event.error_message, - interrupted=sub_event.interrupted, - grounding_metadata=sub_event.grounding_metadata, - custom_metadata=sub_event.custom_metadata, - usage_metadata=sub_event.usage_metadata, - citation_metadata=getattr(sub_event, 'citation_metadata', None), - ) - + event_branch = f'{parent_agent_name}.{self.agent.name}' + + copied_event = sub_event.model_copy(update={'branch': event_branch}) + await session_service.append_event(main_session, copied_event) except Exception as e: - import logging - logger = logging.getLogger(__name__) - logger.warning(f"Erro ao copiar evento do sub-agente {self.agent.name} para sessão principal: {e}") + logger.warning( + "Error copying sub-agent event from '%s' to main session: %s", + self.agent.name, + e, + ) # Clean up runner resources (especially MCP sessions) # to avoid "Attempted to exit cancel scope in a different task" errors await runner.close() # Merge all collected chunks into final text - merged_text = "".join(chunks) - + merged_text = ''.join(chunks) + if not merged_text: return '' if isinstance(self.agent, LlmAgent) and self.agent.output_schema: @@ -252,7 +249,7 @@ def iter_text_parts(parts): ).model_dump(exclude_none=True) else: tool_result = merged_text - + return tool_result @override diff --git a/tests/unittests/tools/test_agent_tool_new_features.py b/tests/unittests/tools/test_agent_tool_new_features.py index 35c66a4d18..31dec231dc 100644 --- a/tests/unittests/tools/test_agent_tool_new_features.py +++ b/tests/unittests/tools/test_agent_tool_new_features.py @@ -32,32 +32,31 @@ from .. import testing_utils function_call_no_schema = Part.from_function_call( - name='tool_agent', args={'request': 'test1'} + name="tool_agent", args={"request": "test1"} ) @mark.asyncio async def test_agent_tool_handles_dict_args(): """Test that AgentTool handles dictionary arguments correctly (non-request key).""" - + mock_model = testing_utils.MockModel.create( - responses=['response to dict arg'] + responses=["response to dict arg"] ) - + tool_agent = Agent( - name='tool_agent', + name="tool_agent", model=mock_model, ) - + # Create invocation context session_service = InMemorySessionService() session = await session_service.create_session( - app_name='test_app', - user_id='test_user' + app_name="test_app", user_id="test_user" ) - + invocation_context = InvocationContext( - invocation_id='test_invocation', + invocation_id="test_invocation", agent=tool_agent, session=session, session_service=session_service, @@ -66,41 +65,40 @@ async def test_agent_tool_handles_dict_args(): plugin_manager=PluginManager(plugins=[]), run_config=RunConfig(), ) - + tool_context = ToolContext(invocation_context=invocation_context) agent_tool = AgentTool(agent=tool_agent) - + # Test with dict argument that doesn't have 'request' key result = await agent_tool.run_async( - args={"custom_key": "custom_value"}, - tool_context=tool_context + args={"custom_key": "custom_value"}, tool_context=tool_context ) - + assert result is not None - assert 'response to dict arg' in str(result) + assert "response to dict arg" in str(result) @mark.asyncio async def test_database_session_service_persists_usage_metadata(): """Test that DatabaseSessionService correctly persists usage_metadata with flag_modified.""" - from google.adk.sessions.database_session_service import DatabaseSessionService - import tempfile import os - + import tempfile + + from google.adk.sessions.database_session_service import DatabaseSessionService + # Create temporary database - temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") temp_db.close() db_url = f"sqlite+aiosqlite:///{temp_db.name}" - + try: service = DatabaseSessionService(db_url) - + # Create session session = await service.create_session( - app_name="test_app", - user_id="user123" + app_name="test_app", user_id="user123" ) - + # Create event with usage_metadata event = Event( id="evt1", @@ -110,23 +108,21 @@ async def test_database_session_service_persists_usage_metadata(): usage_metadata=types.GenerateContentResponseUsageMetadata( prompt_token_count=100, candidates_token_count=50, - total_token_count=150 - ) + total_token_count=150, + ), ) - + # Persist event await service.append_event(session, event) - + # Retrieve session and verify usage_metadata was persisted retrieved_session = await service.get_session( - app_name="test_app", - user_id="user123", - session_id=session.id + app_name="test_app", user_id="user123", session_id=session.id ) - + assert retrieved_session is not None assert len(retrieved_session.events) > 0 - + # Find the event with usage_metadata found_usage_metadata = False for evt in retrieved_session.events: @@ -136,9 +132,9 @@ async def test_database_session_service_persists_usage_metadata(): assert evt.usage_metadata.candidates_token_count == 50 found_usage_metadata = True break - + assert found_usage_metadata, "usage_metadata was not persisted correctly" - + finally: # Cleanup if os.path.exists(temp_db.name): @@ -148,24 +144,24 @@ async def test_database_session_service_persists_usage_metadata(): @mark.asyncio async def test_database_session_service_persists_citation_metadata(): """Test that DatabaseSessionService correctly persists citation_metadata.""" - from google.adk.sessions.database_session_service import DatabaseSessionService - import tempfile import os - + import tempfile + + from google.adk.sessions.database_session_service import DatabaseSessionService + # Create temporary database - temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") temp_db.close() db_url = f"sqlite+aiosqlite:///{temp_db.name}" - + try: service = DatabaseSessionService(db_url) - + # Create session session = await service.create_session( - app_name="test_app", - user_id="user123" + app_name="test_app", user_id="user123" ) - + # Create event with citation_metadata event = Event( id="evt1", @@ -178,25 +174,23 @@ async def test_database_session_service_persists_citation_metadata(): start_index=0, end_index=10, uri="https://example.com", - title="Example Source" + title="Example Source", ) ] - ) + ), ) - + # Persist event await service.append_event(session, event) - + # Retrieve session and verify citation_metadata was persisted retrieved_session = await service.get_session( - app_name="test_app", - user_id="user123", - session_id=session.id + app_name="test_app", user_id="user123", session_id=session.id ) - + assert retrieved_session is not None assert len(retrieved_session.events) > 0 - + # Find the event with citation_metadata found_citation_metadata = False for evt in retrieved_session.events: @@ -206,12 +200,15 @@ async def test_database_session_service_persists_citation_metadata(): assert evt.citation_metadata.citations[0].title == "Example Source" found_citation_metadata = True break - - assert found_citation_metadata, "citation_metadata was not persisted correctly" - + + assert ( + found_citation_metadata + ), "citation_metadata was not persisted correctly" + finally: # Cleanup if os.path.exists(temp_db.name): os.unlink(temp_db.name) + # Made with Bob