diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index e990015c76..3a81d2ae02 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -31,6 +31,7 @@ from fastapi import FastAPI from fastapi import HTTPException from fastapi import Query +from fastapi import Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from fastapi.responses import StreamingResponse @@ -210,6 +211,13 @@ class RunEvalRequest(common.BaseModel): eval_metrics: list[EvalMetric] +class UpdateMemoryRequest(common.BaseModel): + """Request to add a session to the memory service.""" + + session_id: str + """The ID of the session to add to memory.""" + + class RunEvalResult(common.BaseModel): eval_set_file: str eval_set_id: str @@ -1144,6 +1152,41 @@ async def delete_artifact( filename=artifact_name, ) + @app.patch("/apps/{app_name}/users/{user_id}/memory") + async def patch_memory( + app_name: str, user_id: str, update_memory_request: UpdateMemoryRequest + ) -> None: + """Adds all events from a given session to the memory service. + + Args: + app_name: The name of the application. + user_id: The ID of the user. + update_memory_request: The memory request for the update + + Raises: + HTTPException: If the memory service is not configured or the request is invalid. + """ + if not self.memory_service: + raise HTTPException( + status_code=400, detail="Memory service is not configured." + ) + if ( + update_memory_request is None + or update_memory_request.session_id is None + ): + raise HTTPException( + status_code=400, detail="Update memory request is invalid." + ) + + session = await self.session_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=update_memory_request.session_id, + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + await self.memory_service.add_session_to_memory(session) + @app.post("/run", response_model_exclude_none=True) async def run_agent(req: RunAgentRequest) -> list[Event]: session = await self.session_service.get_session( diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d1e8dcabc2..4bcf2f119f 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -22,6 +22,7 @@ import time from typing import Any from typing import Optional +from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch @@ -344,7 +345,7 @@ async def delete_artifact(self, app_name, user_id, session_id, filename): @pytest.fixture def mock_memory_service(): """Create a mock memory service.""" - return MagicMock() + return AsyncMock() @pytest.fixture @@ -939,5 +940,18 @@ def test_a2a_disabled_by_default(test_app): logger.info("A2A disabled by default test passed") +def test_patch_memory(test_app, create_test_session, mock_memory_service): + """Test adding a session to memory.""" + info = create_test_session + url = f"/apps/{info['app_name']}/users/{info['user_id']}/memory" + payload = {"session_id": info["session_id"]} + response = test_app.patch(url, json=payload) + + # Verify the response + assert response.status_code == 200 + mock_memory_service.add_session_to_memory.assert_called_once() + logger.info("Add session to memory test completed successfully") + + if __name__ == "__main__": pytest.main(["-xvs", __file__])