diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index a352918211..a056ed0b39 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -46,6 +46,7 @@ from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship +from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime from sqlalchemy.types import PickleType @@ -335,11 +336,20 @@ def from_event(cls, session: Session, event: Event) -> StorageEvent: ) if event.custom_metadata: storage_event.custom_metadata = event.custom_metadata - if event.usage_metadata: - storage_event.usage_metadata = event.usage_metadata.model_dump( - exclude_none=True, mode="json" - ) - if event.citation_metadata: + + if hasattr(event, "usage_metadata") and event.usage_metadata is not None: + try: + usage_meta = event.usage_metadata + if hasattr(usage_meta, "model_dump"): + storage_event.usage_metadata = usage_meta.model_dump( + exclude_none=False, mode="json" + ) + except Exception as e: + logger.error( + f"[StorageEvent.from_event] Error while saving usage_metadata: {e}" + ) + + if hasattr(event, "citation_metadata") and event.citation_metadata: storage_event.citation_metadata = event.citation_metadata.model_dump( exclude_none=True, mode="json" ) @@ -727,7 +737,15 @@ async def append_event(self, session: Session, event: Event) -> Event: else: update_time = datetime.fromtimestamp(event.timestamp) storage_session.update_time = update_time - sql_session.add(StorageEvent.from_event(session, event)) + storage_event = StorageEvent.from_event(session, event) + + sql_session.add(storage_event) + + # Forçar SQLAlchemy a detectar mudanças em campos MutableDict/DynamicJSON + if storage_event.usage_metadata is not None: + flag_modified(storage_event, "usage_metadata") + if storage_event.citation_metadata is not None: + flag_modified(storage_event, "citation_metadata") await sql_session.commit() await sql_session.refresh(storage_session) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 46d8616619..9b98b13764 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -14,6 +14,8 @@ from __future__ import annotations +import json +import logging from typing import Any from typing import TYPE_CHECKING @@ -23,7 +25,9 @@ from . import _automatic_function_calling_util from ..agents.common_configs import AgentRefConfig +from ..events.event import Event from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..sessions import Session from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService from .base_tool import BaseTool @@ -34,6 +38,8 @@ if TYPE_CHECKING: from ..agents.base_agent import BaseAgent +logger = logging.getLogger(__name__) + class AgentTool(BaseTool): """A tool that wraps an agent. @@ -136,7 +142,11 @@ async def run_async( else: content = types.Content( role='user', - parts=[types.Part.from_text(text=args['request'])], + parts=[ + types.Part.from_text( + text=str(args) if isinstance(args, str) else json.dumps(args) + ) + ], ) invocation_context = tool_context._invocation_context parent_app_name = ( @@ -161,40 +171,85 @@ async def run_async( state_dict = { k: v for k, v in tool_context.state.to_dict().items() - if not k.startswith('_adk') # Filter out adk internal states + if not k.startswith('_adk') } - session = await runner.session_service.create_session( + sub_agent_session = await runner.session_service.create_session( app_name=child_app_name, user_id=tool_context._invocation_context.user_id, state=state_dict, ) - last_content = None + # Collect all text chunks from streaming response instead of just last content + chunks: list[str] = [] + sub_agent_events = [] + + def iter_text_parts(parts): + """Safely iterate over parts and extract text, skipping None values.""" + for p in parts or []: + if hasattr(p, 'text') and p.text is not None: + yield p.text + async with Aclosing( runner.run_async( - user_id=session.user_id, session_id=session.id, new_message=content + user_id=sub_agent_session.user_id, + session_id=sub_agent_session.id, + new_message=content, ) ) as agen: async for event in agen: - # Forward state delta to parent session. if event.actions.state_delta: tool_context.state.update(event.actions.state_delta) + # Collect text chunks from all events, not just the last one if event.content: - last_content = event.content + chunks.extend(iter_text_parts(event.content.parts)) + sub_agent_events.append(event) + + if sub_agent_events and hasattr(tool_context, '_invocation_context'): + main_session = tool_context._invocation_context.session + if main_session and hasattr( + tool_context._invocation_context, 'session_service' + ): + session_service = tool_context._invocation_context.session_service + parent_agent_name = ( + tool_context._invocation_context.agent.name + if hasattr(tool_context._invocation_context, 'agent') + else 'root_agent' + ) + + for sub_event in sub_agent_events: + + try: + if hasattr(sub_event, 'branch') and sub_event.branch: + event_branch = sub_event.branch + else: + event_branch = f'{parent_agent_name}.{self.agent.name}' + + copied_event = sub_event.model_copy(update={'branch': event_branch}) + + await session_service.append_event(main_session, copied_event) + except Exception as e: + logger.warning( + "Error copying sub-agent event from '%s' to main session: %s", + self.agent.name, + e, + ) # Clean up runner resources (especially MCP sessions) # to avoid "Attempted to exit cancel scope in a different task" errors await runner.close() - if not last_content: + # Merge all collected chunks into final text + merged_text = ''.join(chunks) + + if not merged_text: return '' - merged_text = '\n'.join(p.text for p in last_content.parts if p.text) if isinstance(self.agent, LlmAgent) and self.agent.output_schema: tool_result = self.agent.output_schema.model_validate_json( merged_text ).model_dump(exclude_none=True) else: tool_result = merged_text + return tool_result @override diff --git a/tests/unittests/tools/test_agent_tool_new_features.py b/tests/unittests/tools/test_agent_tool_new_features.py new file mode 100644 index 0000000000..31dec231dc --- /dev/null +++ b/tests/unittests/tools/test_agent_tool_new_features.py @@ -0,0 +1,214 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for new features added to AgentTool and DatabaseSessionService.""" + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import Agent +from google.adk.agents.run_config import RunConfig +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.agent_tool import AgentTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from google.genai.types import Part +from pytest import mark + +from .. import testing_utils + +function_call_no_schema = Part.from_function_call( + name="tool_agent", args={"request": "test1"} +) + + +@mark.asyncio +async def test_agent_tool_handles_dict_args(): + """Test that AgentTool handles dictionary arguments correctly (non-request key).""" + + mock_model = testing_utils.MockModel.create( + responses=["response to dict arg"] + ) + + tool_agent = Agent( + name="tool_agent", + model=mock_model, + ) + + # Create invocation context + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="test_app", user_id="test_user" + ) + + invocation_context = InvocationContext( + invocation_id="test_invocation", + agent=tool_agent, + session=session, + session_service=session_service, + artifact_service=InMemoryArtifactService(), + memory_service=InMemoryMemoryService(), + plugin_manager=PluginManager(plugins=[]), + run_config=RunConfig(), + ) + + tool_context = ToolContext(invocation_context=invocation_context) + agent_tool = AgentTool(agent=tool_agent) + + # Test with dict argument that doesn't have 'request' key + result = await agent_tool.run_async( + args={"custom_key": "custom_value"}, tool_context=tool_context + ) + + assert result is not None + assert "response to dict arg" in str(result) + + +@mark.asyncio +async def test_database_session_service_persists_usage_metadata(): + """Test that DatabaseSessionService correctly persists usage_metadata with flag_modified.""" + import os + import tempfile + + from google.adk.sessions.database_session_service import DatabaseSessionService + + # Create temporary database + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_db.close() + db_url = f"sqlite+aiosqlite:///{temp_db.name}" + + try: + service = DatabaseSessionService(db_url) + + # Create session + session = await service.create_session( + app_name="test_app", user_id="user123" + ) + + # Create event with usage_metadata + event = Event( + id="evt1", + invocation_id="inv1", + author="model", + actions=EventActions(), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=100, + candidates_token_count=50, + total_token_count=150, + ), + ) + + # Persist event + await service.append_event(session, event) + + # Retrieve session and verify usage_metadata was persisted + retrieved_session = await service.get_session( + app_name="test_app", user_id="user123", session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) > 0 + + # Find the event with usage_metadata + found_usage_metadata = False + for evt in retrieved_session.events: + if evt.usage_metadata is not None: + assert evt.usage_metadata.total_token_count == 150 + assert evt.usage_metadata.prompt_token_count == 100 + assert evt.usage_metadata.candidates_token_count == 50 + found_usage_metadata = True + break + + assert found_usage_metadata, "usage_metadata was not persisted correctly" + + finally: + # Cleanup + if os.path.exists(temp_db.name): + os.unlink(temp_db.name) + + +@mark.asyncio +async def test_database_session_service_persists_citation_metadata(): + """Test that DatabaseSessionService correctly persists citation_metadata.""" + import os + import tempfile + + from google.adk.sessions.database_session_service import DatabaseSessionService + + # Create temporary database + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_db.close() + db_url = f"sqlite+aiosqlite:///{temp_db.name}" + + try: + service = DatabaseSessionService(db_url) + + # Create session + session = await service.create_session( + app_name="test_app", user_id="user123" + ) + + # Create event with citation_metadata + event = Event( + id="evt1", + invocation_id="inv1", + author="model", + actions=EventActions(), + citation_metadata=types.CitationMetadata( + citations=[ + types.Citation( + start_index=0, + end_index=10, + uri="https://example.com", + title="Example Source", + ) + ] + ), + ) + + # Persist event + await service.append_event(session, event) + + # Retrieve session and verify citation_metadata was persisted + retrieved_session = await service.get_session( + app_name="test_app", user_id="user123", session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) > 0 + + # Find the event with citation_metadata + found_citation_metadata = False + for evt in retrieved_session.events: + if evt.citation_metadata is not None: + assert len(evt.citation_metadata.citations) == 1 + assert evt.citation_metadata.citations[0].uri == "https://example.com" + assert evt.citation_metadata.citations[0].title == "Example Source" + found_citation_metadata = True + break + + assert ( + found_citation_metadata + ), "citation_metadata was not persisted correctly" + + finally: + # Cleanup + if os.path.exists(temp_db.name): + os.unlink(temp_db.name) + + +# Made with Bob