diff --git a/contributing/samples/redis_agent_memory/README.md b/contributing/samples/redis_agent_memory/README.md new file mode 100644 index 0000000..7ade099 --- /dev/null +++ b/contributing/samples/redis_agent_memory/README.md @@ -0,0 +1,153 @@ +# Redis Agent Memory Sample + +This sample demonstrates the **complete two-tier memory architecture** using Redis Agent Memory Server with ADK: + +1. **RedisWorkingMemorySessionService** - Session management with auto-summarization +2. **RedisLongTermMemoryService** - Persistent long-term memory with semantic search + +## Architecture + +``` +┌────────────────────────────────────────────────────────────────┐ +│ ADK Agent │ +├──────────────────────────────┬─────────────────────────────────┤ +│ TIER 1: Working Memory │ TIER 2: Long-Term Memory │ +├──────────────────────────────┼─────────────────────────────────┤ +│ • Current session messages │ • Extracted facts & preferences │ +│ • Auto-summarization │ • Semantic vector search │ +│ • Context window management │ • Cross-session persistence │ +│ • TTL support │ • Recency-boosted retrieval │ +├──────────────────────────────┴─────────────────────────────────┤ +│ Agent Memory Server API │ +├────────────────────────────────────────────────────────────────┤ +│ Redis Stack │ +└────────────────────────────────────────────────────────────────┘ +``` + +## Example Flow + +``` +User Message + │ + ▼ +┌─────────────┐ store ┌──────────────────────┐ +│ ADK Agent │─────────────▶│ Working Memory │ +└─────────────┘ │ (current session) │ + │ └──────────┬───────────┘ + │ │ extract + │ search ▼ + │ ┌──────────────────────┐ + └──────────────────────▶│ Long-Term Memory │ + │ (all sessions) │ + └──────────────────────┘ +``` + +## Prerequisites + +- Python 3.10+ +- Docker (for Redis Stack and Agent Memory Server) + +## Setup + +### 1. Install Dependencies + +```bash +pip install "google-adk-community[redis-agent-memory]" +``` + +> **Important**: The server is NOT installed via pip - it's a separate service that must be running. The pip package only installs the client to communicate with it. + +### 2. Start Redis Stack + +```bash +docker run -d --name redis-stack -p 6379:6379 redis/redis-stack:latest +``` + +### 3. Start Agent Memory Server + +```bash +docker run -d --name agent-memory-server -p 8000:8000 \ + -e REDIS_URL=redis://host.docker.internal:6379 \ + -e OPENAI_API_KEY=your-openai-key \ + redislabs/agent-memory-server:latest \ + agent-memory api --host 0.0.0.0 --port 8000 --task-backend=asyncio +``` + +> **Note**: The memory server requires an OpenAI API key for embeddings by default. See the [Agent Memory Server docs](https://redis.github.io/agent-memory-server/) for alternative embedding providers. + +### 4. Verify Setup + +```bash +curl http://localhost:8000/health +``` + +### 5. Configure Environment + +Create `.env` in this directory: + +```bash +GOOGLE_API_KEY=your-google-api-key +REDIS_MEMORY_SERVER_URL=http://localhost:8000 +REDIS_MEMORY_NAMESPACE=adk_agent_memory +REDIS_MEMORY_EXTRACTION_STRATEGY=discrete +REDIS_MEMORY_CONTEXT_WINDOW=8000 +REDIS_MEMORY_RECENCY_BOOST=true +``` + +## Usage + +```bash +python main.py +``` + +Open http://localhost:8080 in your browser. + +## Test Conversation + +**Session 1** - Share information: +``` +User: Hi, I'm Nitin. I'm a Machine Learning Engineer working on ML projects. +User: I love coffee, especially Berliner Frühstück Coffee from Berliner Kaffeerösterei. +User: My favorite programming language is Python. +``` + +**Session 2** - Test memory recall: +``` +User: What do you remember about me? +User: What's my favorite coffee? +``` + +## Features + +| Feature | Working Memory (Tier 1) | Long-Term Memory (Tier 2) | +|---------|------------------------|---------------------------| +| Scope | Current session | All sessions | +| Auto-summarization | ✅ Yes | No | +| Semantic search | No | ✅ Yes | +| Fact extraction | Background | ✅ Persistent | +| TTL support | ✅ Yes | No | + +## Configuration + +| Variable | Default | Description | +|----------|---------|-------------| +| `REDIS_MEMORY_SERVER_URL` | `http://localhost:8000` | Memory server URL | +| `REDIS_MEMORY_NAMESPACE` | `adk_agent_memory` | Namespace for isolation | +| `REDIS_MEMORY_EXTRACTION_STRATEGY` | `discrete` | `discrete`, `summary`, `preferences` | +| `REDIS_MEMORY_CONTEXT_WINDOW` | `8000` | Max tokens before summarization | +| `REDIS_MEMORY_RECENCY_BOOST` | `true` | Boost recent memories in search | + +## Memory Server Configuration + +The Redis Agent Memory Server has important settings that affect memory extraction: + +| Setting | Default | Description | +|---------|---------|-------------| +| `EXTRACTION_DEBOUNCE_SECONDS` | `300` (5 min) | Time between extraction runs per session | +| `LONG_TERM_MEMORY` | `true` | Enable long-term memory storage | +| `ENABLE_DISCRETE_MEMORY_EXTRACTION` | `true` | Enable fact extraction from messages | + +**Note on Debouncing**: The memory server debounces extraction to avoid constantly re-extracting +from the same conversation. For testing, you can reduce `EXTRACTION_DEBOUNCE_SECONDS` to `5` in +the memory server's `.env` file. + diff --git a/contributing/samples/redis_agent_memory/main.py b/contributing/samples/redis_agent_memory/main.py new file mode 100644 index 0000000..748c997 --- /dev/null +++ b/contributing/samples/redis_agent_memory/main.py @@ -0,0 +1,133 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Full Redis Memory Sample: Working Memory Sessions + Long-Term Memory. + +This sample demonstrates using BOTH Redis Agent Memory Server services: +1. RedisWorkingMemorySessionService - For session management with auto-summarization +2. RedisLongTermMemoryService - For persistent long-term memory search + +This provides the complete two-tier memory architecture: +- Working Memory (Tier 1): Session state, messages, auto-summarization +- Long-Term Memory (Tier 2): Persistent facts, preferences, semantic search +""" + +import os +from urllib.parse import urlparse + +from dotenv import load_dotenv +from fastapi import FastAPI +from google.adk.cli.fast_api import get_fast_api_app +from google.adk.cli.service_registry import get_service_registry +import uvicorn + +from google.adk_community.memory import RedisLongTermMemoryService +from google.adk_community.memory import RedisLongTermMemoryServiceConfig +from google.adk_community.sessions import RedisWorkingMemorySessionService +from google.adk_community.sessions import RedisWorkingMemorySessionServiceConfig + +load_dotenv() + + +def parse_base_url(uri: str) -> str: + """Parse URI to extract base URL.""" + parsed = urlparse(uri) + location = parsed.netloc + parsed.path + return ( + location + if location.startswith(("http://", "https://")) + else f"http://{location}" + ) + + +def redis_session_factory(uri: str, **kwargs): + """Factory function for creating RedisWorkingMemorySessionService from URI.""" + base_url = parse_base_url(uri) + config = RedisWorkingMemorySessionServiceConfig( + api_base_url=base_url, + default_namespace=os.getenv("REDIS_MEMORY_NAMESPACE", "adk_agent_memory"), + model_name=os.getenv("REDIS_MEMORY_MODEL_NAME", "gpt-4o"), + context_window_max=int(os.getenv("REDIS_MEMORY_CONTEXT_WINDOW", "8000")), + extraction_strategy=os.getenv( + "REDIS_MEMORY_EXTRACTION_STRATEGY", "discrete" + ), + ) + return RedisWorkingMemorySessionService(config=config) + + +def redis_memory_factory(uri: str, **kwargs): + """Factory function for creating RedisLongTermMemoryService from URI.""" + base_url = parse_base_url(uri) + config = RedisLongTermMemoryServiceConfig( + api_base_url=base_url, + default_namespace=os.getenv("REDIS_MEMORY_NAMESPACE", "adk_agent_memory"), + extraction_strategy=os.getenv( + "REDIS_MEMORY_EXTRACTION_STRATEGY", "discrete" + ), + recency_boost=os.getenv("REDIS_MEMORY_RECENCY_BOOST", "true").lower() + == "true", + semantic_weight=float(os.getenv("REDIS_MEMORY_SEMANTIC_WEIGHT", "0.7")), + recency_weight=float(os.getenv("REDIS_MEMORY_RECENCY_WEIGHT", "0.3")), + ) + return RedisLongTermMemoryService(config=config) + + +# Register both service factories +registry = get_service_registry() +registry.register_session_service("redis-working-memory", redis_session_factory) +registry.register_memory_service("redis-long-term-memory", redis_memory_factory) + +# Build URIs from environment +server_url = ( + os.getenv("REDIS_MEMORY_SERVER_URL", "http://localhost:8000") + .replace("http://", "") + .replace("https://", "") +) +SESSION_SERVICE_URI = f"redis-working-memory://{server_url}" +MEMORY_SERVICE_URI = f"redis-long-term-memory://{server_url}" + +# Create the FastAPI app with both services +app: FastAPI = get_fast_api_app( + agents_dir=".", + session_service_uri=SESSION_SERVICE_URI, + memory_service_uri=MEMORY_SERVICE_URI, + web=True, +) + + +if __name__ == "__main__": + port = int(os.environ.get("PORT", 8080)) + namespace = os.getenv("REDIS_MEMORY_NAMESPACE", "adk_agent_memory") + server = os.getenv("REDIS_MEMORY_SERVER_URL", "http://localhost:8000") + extraction = os.getenv("REDIS_MEMORY_EXTRACTION_STRATEGY", "discrete") + context_window = os.getenv("REDIS_MEMORY_CONTEXT_WINDOW", "8000") + + print(f""" +Starting Redis Agent Memory Sample +======================== +ADK Server: http://localhost:{port} +Memory Server: {server} +Namespace: {namespace} +Extraction Strategy: {extraction} +Context Window: {context_window} tokens + +Services: + - Session: RedisWorkingMemorySessionService (auto-summarization) + - Memory: RedisLongTermMemoryService (semantic search) + +Two-Tier Architecture: + Tier 1 (Working Memory): Session messages, state, auto-summarization + Tier 2 (Long-Term Memory): Extracted facts, preferences, semantic search +""") + uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/contributing/samples/redis_agent_memory/redis_agent_memory_agent/__init__.py b/contributing/samples/redis_agent_memory/redis_agent_memory_agent/__init__.py new file mode 100644 index 0000000..c48963c --- /dev/null +++ b/contributing/samples/redis_agent_memory/redis_agent_memory_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/redis_agent_memory/redis_agent_memory_agent/agent.py b/contributing/samples/redis_agent_memory/redis_agent_memory_agent/agent.py new file mode 100644 index 0000000..791a5f5 --- /dev/null +++ b/contributing/samples/redis_agent_memory/redis_agent_memory_agent/agent.py @@ -0,0 +1,75 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent with full Redis memory: Working Memory + Long-Term Memory. + +This agent demonstrates the complete two-tier memory architecture: +- Working Memory: Automatic session summarization when context grows large +- Long-Term Memory: Persistent facts extracted and searchable across sessions +""" + +from datetime import datetime + +from google.adk import Agent +from google.adk.agents.callback_context import CallbackContext +from google.adk.tools import load_memory +from google.adk.tools import preload_memory + + +def before_agent(callback_context: CallbackContext): + """Update state before agent runs.""" + callback_context.state["_time"] = datetime.now().isoformat() + + +async def after_agent(callback_context: CallbackContext): + """Store session to long-term memory after agent completes.""" + # This triggers memory extraction to long-term memory + await callback_context.add_session_to_memory() + + +root_agent = Agent( + model="gemini-2.5-flash", + name="redis_agent_memory_agent", + description=( + "Agent with full two-tier Redis memory: working memory for sessions," + " long-term memory for persistence." + ), + before_agent_callback=before_agent, + after_agent_callback=after_agent, + instruction="""You are a helpful assistant with a powerful two-tier memory system. + +## Your Memory Capabilities + +1. **Working Memory** (automatic): Your current conversation is automatically managed. + When the conversation gets long, older messages are summarized to keep context efficient. + +2. **Long-Term Memory** (persistent): Important facts and preferences are automatically + extracted and stored. You can search this memory across sessions. + +## How to Use Memory + +- Use `load_memory` to search for information from past conversations +- When users share personal info (name, preferences, facts), acknowledge it - + it will be automatically saved to long-term memory +- If a search doesn't find results, try different keywords + +## Conversation Guidelines + +- Be conversational and remember details the user shares +- Reference past interactions when relevant +- Ask clarifying questions to learn more about the user + +Current time: {_time}""", + tools=[preload_memory, load_memory], +) diff --git a/pyproject.toml b/pyproject.toml index 11afcd8..c62eb85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK "google-adk", # Google ADK "httpx>=0.27.0, <1.0.0", # For OpenMemory service + "pydantic>=2.0, <3.0.0", # For data validation/models "redis>=5.0.0, <6.0.0", # Redis for session storage # go/keep-sorted end "orjson>=3.11.3", @@ -45,7 +46,9 @@ test = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", ] - +redis-agent-memory = [ + "agent-memory-client>=0.13.0; python_version >= '3.10'", +] [tool.pyink] # Format py files following Google style-guide diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..98c0e33 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -15,4 +15,5 @@ from . import memory from . import sessions from . import version + __version__ = version.__version__ diff --git a/src/google/adk_community/memory/__init__.py b/src/google/adk_community/memory/__init__.py index 1f3442c..60610d3 100644 --- a/src/google/adk_community/memory/__init__.py +++ b/src/google/adk_community/memory/__init__.py @@ -16,9 +16,12 @@ from .open_memory_service import OpenMemoryService from .open_memory_service import OpenMemoryServiceConfig +from .redis_long_term_memory_service import RedisLongTermMemoryService +from .redis_long_term_memory_service import RedisLongTermMemoryServiceConfig __all__ = [ "OpenMemoryService", "OpenMemoryServiceConfig", + "RedisLongTermMemoryService", + "RedisLongTermMemoryServiceConfig", ] - diff --git a/src/google/adk_community/memory/redis_long_term_memory_service.py b/src/google/adk_community/memory/redis_long_term_memory_service.py new file mode 100644 index 0000000..4d0e264 --- /dev/null +++ b/src/google/adk_community/memory/redis_long_term_memory_service.py @@ -0,0 +1,302 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redis Long-Term Memory Service for ADK. + +This module provides integration with the Redis Agent Memory Server, +offering production-grade long-term memory with automatic summarization, +topic/entity extraction, and recency-boosted search. + +Note: The classes were renamed from RedisAgentMemoryService to +RedisLongTermMemoryService to better reflect their purpose of managing +long-term memory via the Agent Memory Server. +""" + +from __future__ import annotations + +from functools import cached_property +import logging +from typing import Literal +from typing import Optional +from typing import TYPE_CHECKING + +from google.adk.memory.base_memory_service import BaseMemoryService +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry +from google.genai import types +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override + +from google.adk_community.memory.utils import extract_text_from_event + +if TYPE_CHECKING: + from google.adk.sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + + +class RedisLongTermMemoryServiceConfig(BaseModel): + """Configuration for Redis Long-Term Memory Service. + + Attributes: + api_base_url: Base URL of the Agent Memory Server. + timeout: HTTP request timeout in seconds. + default_namespace: Default namespace for memory operations. + search_top_k: Maximum number of memories to retrieve per search. + distance_threshold: Maximum distance threshold for search results (0.0-1.0). + recency_boost: Enable recency-aware re-ranking of search results. + semantic_weight: Weight for semantic similarity in recency boosting (0.0-1.0). + recency_weight: Weight for recency score in recency boosting (0.0-1.0). + freshness_weight: Weight for freshness component within recency score. + novelty_weight: Weight for novelty component within recency score. + half_life_last_access_days: Half-life in days for last_accessed decay. + half_life_created_days: Half-life in days for created_at decay. + extraction_strategy: Memory extraction strategy (discrete, summary, preferences, custom). + extraction_strategy_config: Additional configuration for the extraction strategy. + model_name: Model name for context window management and summarization. + context_window_max: Maximum context window tokens (overrides model default). + """ + + api_base_url: str = Field(default="http://localhost:8000") + timeout: float = Field(default=30.0, gt=0.0) + default_namespace: Optional[str] = None + search_top_k: int = Field(default=10, ge=1) + distance_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) + recency_boost: bool = True + semantic_weight: float = Field(default=0.8, ge=0.0, le=1.0) + recency_weight: float = Field(default=0.2, ge=0.0, le=1.0) + freshness_weight: float = Field(default=0.6, ge=0.0, le=1.0) + novelty_weight: float = Field(default=0.4, ge=0.0, le=1.0) + half_life_last_access_days: float = Field(default=7.0, gt=0.0) + half_life_created_days: float = Field(default=30.0, gt=0.0) + extraction_strategy: Literal[ + "discrete", "summary", "preferences", "custom" + ] = "discrete" + extraction_strategy_config: dict = Field(default_factory=dict) + model_name: Optional[str] = None + context_window_max: Optional[int] = Field(default=None, ge=1) + + +class RedisLongTermMemoryService(BaseMemoryService): + """Long-term memory service implementation using Redis Agent Memory Server. + + This service provides production-grade memory capabilities including: + - Two-tier memory architecture (working memory + long-term memory) + - Automatic memory extraction (semantic facts, episodic events, preferences) + - Topic and entity extraction + - Auto-summarization when context window is exceeded + - Recency-boosted semantic search + - Deduplication and memory compaction + - https://github.com/redis/agent-memory-server + Requires the `agent-memory-client` package to be installed. + + Example: + ```python + from google.adk_community.memory import ( + RedisLongTermMemoryService, + RedisLongTermMemoryServiceConfig, + ) + + config = RedisLongTermMemoryServiceConfig( + api_base_url="http://localhost:8000", + default_namespace="my_app", + recency_boost=True, + ) + memory_service = RedisLongTermMemoryService(config=config) + + # Use with ADK agent + agent = Agent( + name="my_agent", + memory_service=memory_service, + ) + ``` + """ + + def __init__(self, config: Optional[RedisLongTermMemoryServiceConfig] = None): + """Initialize the Redis Long-Term Memory Service. + + Args: + config: Configuration for the service. If None, uses defaults. + + Raises: + ImportError: If agent-memory-client package is not installed. + """ + self._config = config or RedisLongTermMemoryServiceConfig() + + @cached_property + def _client(self): + """Lazily initialize and return the MemoryAPIClient.""" + try: + from agent_memory_client import MemoryAPIClient + from agent_memory_client import MemoryClientConfig + except ImportError as e: + raise ImportError( + "agent-memory-client package is required for" + " RedisLongTermMemoryService. Install it with: pip install" + " agent-memory-client" + ) from e + + client_config = MemoryClientConfig( + base_url=self._config.api_base_url, + timeout=self._config.timeout, + default_namespace=self._config.default_namespace, + default_model_name=self._config.model_name, + default_context_window_max=self._config.context_window_max, + ) + return MemoryAPIClient(client_config) + + def _build_working_memory(self, session: "Session"): + """Convert ADK Session to WorkingMemory for the Agent Memory Server.""" + from agent_memory_client.models import MemoryMessage + from agent_memory_client.models import MemoryStrategyConfig + from agent_memory_client.models import WorkingMemory + + messages = [] + for event in session.events: + text = extract_text_from_event(event) + if not text: + continue + role = "user" if event.author == "user" else "assistant" + messages.append(MemoryMessage(role=role, content=text)) + + strategy_config = MemoryStrategyConfig( + strategy=self._config.extraction_strategy, + config=self._config.extraction_strategy_config, + ) + + return WorkingMemory( + session_id=session.id, + namespace=self._config.default_namespace or session.app_name, + user_id=session.user_id, + messages=messages, + long_term_memory_strategy=strategy_config, + ) + + @override + async def add_session_to_memory(self, session: "Session"): + """Add a session's events to the Agent Memory Server. + + Converts ADK Session events to WorkingMemory messages and stores them + in the Agent Memory Server. The server will automatically: + - Extract semantic and episodic memories based on the configured strategy + - Perform topic and entity extraction + - Summarize context when the token limit is exceeded + - Promote memories to long-term storage via background tasks + + Args: + session: The ADK Session containing events to store. + """ + try: + working_memory = self._build_working_memory(session) + + if not working_memory.messages: + logger.debug("No messages to store for session %s", session.id) + return + + response = await self._client.put_working_memory( + session_id=session.id, + memory=working_memory, + user_id=session.user_id, + ) + + logger.info( + "Stored %d messages for session %s (context: %.1f%% used)", + len(working_memory.messages), + session.id, + response.context_percentage_total_used or 0, + ) + + except Exception as e: + logger.error( + "Failed to add session %s to memory: %s", + session.id, + e, + ) + + def _build_recency_config(self): + """Build RecencyConfig from service configuration.""" + from agent_memory_client.models import RecencyConfig + + return RecencyConfig( + recency_boost=self._config.recency_boost, + semantic_weight=self._config.semantic_weight, + recency_weight=self._config.recency_weight, + freshness_weight=self._config.freshness_weight, + novelty_weight=self._config.novelty_weight, + half_life_last_access_days=self._config.half_life_last_access_days, + half_life_created_days=self._config.half_life_created_days, + ) + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + """Search for memories using the Agent Memory Server. + + Performs semantic search against long-term memory with optional + recency boosting. Results are filtered by namespace (derived from + app_name) and user_id. + + Args: + app_name: The application name (used as namespace if not configured). + user_id: The user ID to filter memories. + query: The search query for semantic matching. + + Returns: + SearchMemoryResponse containing matching MemoryEntry objects. + """ + try: + recency_config = ( + self._build_recency_config() if self._config.recency_boost else None + ) + + namespace = self._config.default_namespace or app_name + + results = await self._client.search_long_term_memory( + text=query, + namespace={"eq": namespace}, + user_id={"eq": user_id}, + distance_threshold=self._config.distance_threshold, + recency=recency_config, + limit=self._config.search_top_k, + ) + + memories = [] + for record in results.memories: + content = types.Content(parts=[types.Part(text=record.text)]) + memory_entry = MemoryEntry(content=content) + memories.append(memory_entry) + + logger.info( + "Found %d memories for query '%s' (namespace=%s, user=%s)", + len(memories), + query[:50], + namespace, + user_id, + ) + return SearchMemoryResponse(memories=memories) + + except Exception as e: + logger.error("Failed to search memories: %s", e) + return SearchMemoryResponse(memories=[]) + + async def close(self): + """Close the memory service and cleanup resources.""" + if ( + "_client" in self.__dict__ + ): # Check for initialized client without triggering cached_property + await self._client.close() + # Clear the cached property + del self._client diff --git a/src/google/adk_community/memory/utils.py b/src/google/adk_community/memory/utils.py index 0b78206..48058ce 100644 --- a/src/google/adk_community/memory/utils.py +++ b/src/google/adk_community/memory/utils.py @@ -33,9 +33,8 @@ def extract_text_from_event(event) -> str: # Filter out thought parts and only extract text # This prevents metadata like thoughtSignature from being stored text_parts = [ - part.text - for part in event.content.parts + part.text + for part in event.content.parts if part.text and not part.thought ] return ' '.join(text_parts) - diff --git a/src/google/adk_community/sessions/__init__.py b/src/google/adk_community/sessions/__init__.py index 90bf28d..340ae34 100644 --- a/src/google/adk_community/sessions/__init__.py +++ b/src/google/adk_community/sessions/__init__.py @@ -15,5 +15,11 @@ """Community session services for ADK.""" from .redis_session_service import RedisSessionService +from .redis_working_memory_session_service import RedisWorkingMemorySessionService +from .redis_working_memory_session_service import RedisWorkingMemorySessionServiceConfig -__all__ = ["RedisSessionService"] +__all__ = [ + "RedisSessionService", + "RedisWorkingMemorySessionService", + "RedisWorkingMemorySessionServiceConfig", +] diff --git a/src/google/adk_community/sessions/redis_working_memory_session_service.py b/src/google/adk_community/sessions/redis_working_memory_session_service.py new file mode 100644 index 0000000..7ebdbec --- /dev/null +++ b/src/google/adk_community/sessions/redis_working_memory_session_service.py @@ -0,0 +1,440 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redis Working Memory Session Service for ADK. + +This module provides session management using the Redis Agent Memory Server's +Working Memory API, offering automatic context summarization and background +memory extraction. +""" + +from __future__ import annotations + +from functools import cached_property +import logging +import time +from typing import Any +from typing import Literal +from typing import Optional +import uuid + +from google.adk.events.event import Event +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.base_session_service import ListSessionsResponse +from google.adk.sessions.session import Session +from google.genai import types +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override + +from google.adk_community.memory.utils import extract_text_from_event + +logger = logging.getLogger("google_adk." + __name__) + + +class RedisWorkingMemorySessionServiceConfig(BaseModel): + """Configuration for Redis Working Memory Session Service. + + Attributes: + api_base_url: Base URL of the Agent Memory Server. + timeout: HTTP request timeout in seconds. + default_namespace: Default namespace for session operations. + model_name: Model name for context window management and summarization. + context_window_max: Maximum context window tokens. + extraction_strategy: Memory extraction strategy. + extraction_strategy_config: Additional config for extraction strategy. + session_ttl_seconds: Optional TTL for session expiration. + """ + + api_base_url: str = Field(default="http://localhost:8000") + timeout: float = Field(default=30.0, gt=0.0) + default_namespace: Optional[str] = None + model_name: Optional[str] = None + context_window_max: Optional[int] = Field(default=None, ge=1) + extraction_strategy: Literal[ + "discrete", "summary", "preferences", "custom" + ] = "discrete" + extraction_strategy_config: dict = Field(default_factory=dict) + session_ttl_seconds: Optional[int] = Field(default=None, ge=1) + + +class RedisWorkingMemorySessionService(BaseSessionService): + """Session service using Redis Agent Memory Server's Working Memory API. + + This service provides session management backed by Agent Memory Server: + - Session storage in Working Memory + - Automatic context summarization when token limit exceeded + - Background memory extraction to Long-Term Memory + - Incremental message appending + - https://github.com/redis/agent-memory-server + + Requires the `agent-memory-client` package to be installed. + + Example: + ```python + from google.adk_community.sessions import ( + RedisWorkingMemorySessionService, + RedisWorkingMemorySessionServiceConfig, + ) + + config = RedisWorkingMemorySessionServiceConfig( + api_base_url="http://localhost:8000", + default_namespace="my_app", + ) + session_service = RedisWorkingMemorySessionService(config=config) + + # Use with ADK runner + runner = Runner( + agent=agent, + session_service=session_service, + ) + ``` + """ + + def __init__( + self, config: Optional[RedisWorkingMemorySessionServiceConfig] = None + ): + """Initialize the Redis Working Memory Session Service. + + Args: + config: Configuration for the service. If None, uses defaults. + """ + self._config = config or RedisWorkingMemorySessionServiceConfig() + + @cached_property + def _client(self): + """Lazily initialize and return the MemoryAPIClient.""" + try: + from agent_memory_client import MemoryAPIClient + from agent_memory_client import MemoryClientConfig + except ImportError as e: + raise ImportError( + "agent-memory-client package is required for " + "RedisWorkingMemorySessionService. " + "Install it with: pip install agent-memory-client" + ) from e + + client_config = MemoryClientConfig( + base_url=self._config.api_base_url, + timeout=self._config.timeout, + default_namespace=self._config.default_namespace, + default_model_name=self._config.model_name, + default_context_window_max=self._config.context_window_max, + ) + return MemoryAPIClient(client_config) + + def _get_namespace(self, app_name: str) -> str: + """Get namespace from config or app_name.""" + return self._config.default_namespace or app_name + + def _event_to_message(self, event: Event): + """Convert ADK Event to MemoryMessage.""" + from datetime import datetime + from datetime import timezone + + from agent_memory_client.models import MemoryMessage + + text = extract_text_from_event(event) + if not text: + return None + + role = "user" if event.author == "user" else "assistant" + # Convert event timestamp (float) to datetime for MemoryMessage + created_at = datetime.fromtimestamp(event.timestamp, tz=timezone.utc) + return MemoryMessage(role=role, content=text, created_at=created_at) + + def _working_memory_response_to_session( + self, + response, + app_name: str, + user_id: str, + ) -> Session: + """Convert WorkingMemoryResponse to ADK Session.""" + events = [] + for msg in response.messages or []: + author = "user" if msg.role == "user" else response.session_id + content = types.Content(parts=[types.Part(text=msg.content)]) + # Preserve original message timestamp if available + timestamp = ( + msg.created_at.timestamp() + if hasattr(msg, "created_at") and msg.created_at + else time.time() + ) + event = Event( + author=author, + content=content, + timestamp=timestamp, + ) + events.append(event) + + return Session( + id=response.session_id, + app_name=app_name, + user_id=user_id, + events=events, + state=response.data or {}, + last_update_time=time.time(), + ) + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + """Create a new session in Working Memory. + + Uses get_or_create_working_memory to prevent accidental overwrites + of existing sessions. + + Args: + app_name: Application name (used as namespace if not configured). + user_id: User identifier. + state: Initial session state. + session_id: Optional session ID (generated if not provided). + + Returns: + The created Session. + """ + from agent_memory_client.models import MemoryStrategyConfig + + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else str(uuid.uuid4()) + ) + namespace = self._get_namespace(app_name) + + strategy_config = MemoryStrategyConfig( + strategy=self._config.extraction_strategy, + config=self._config.extraction_strategy_config, + ) + + # Use get_or_create to prevent accidental overwrites + created, working_memory = await self._client.get_or_create_working_memory( + session_id=session_id, + namespace=namespace, + user_id=user_id, + long_term_memory_strategy=strategy_config, + ) + + if not created: + logger.warning( + "Session %s already exists in namespace %s, returning existing", + session_id, + namespace, + ) + # Return existing session data + return self._working_memory_response_to_session( + working_memory, app_name, user_id + ) + + # Update with initial state and TTL if provided + if state or self._config.session_ttl_seconds: + if state: + working_memory.data = state + if self._config.session_ttl_seconds: + working_memory.ttl_seconds = self._config.session_ttl_seconds + await self._client.put_working_memory( + session_id=session_id, + memory=working_memory, + user_id=user_id, + ) + + logger.info("Created session %s in namespace %s", session_id, namespace) + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=state or {}, + events=[], + last_update_time=time.time(), + ) + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + """Retrieve a session from Working Memory. + + Uses get_or_create_working_memory and checks if session was newly created + to determine if it exists. Passes model_name and context_window_max to + enable automatic context summarization when token limit is exceeded. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session ID to retrieve. + config: Optional configuration for filtering events. + + Returns: + The Session if found, None otherwise. + """ + from agent_memory_client.exceptions import MemoryNotFoundError + + try: + namespace = self._get_namespace(app_name) + # Use get_or_create to avoid deprecated get_working_memory + created, response = await self._client.get_or_create_working_memory( + session_id=session_id, + namespace=namespace, + user_id=user_id, + model_name=self._config.model_name, + context_window_max=self._config.context_window_max, + ) + + # If session was just created, it means it didn't exist before + # Delete it and return None to maintain get_session semantics + if created: + await self._client.delete_working_memory( + session_id=session_id, + namespace=namespace, + user_id=user_id, + ) + return None + + session = self._working_memory_response_to_session( + response, app_name, user_id + ) + + if config: + if config.num_recent_events: + session.events = session.events[-config.num_recent_events :] + if config.after_timestamp: + session.events = [ + e for e in session.events if e.timestamp > config.after_timestamp + ] + + return session + + except MemoryNotFoundError: + return None + except Exception as e: + logger.error("Failed to get session %s: %s", session_id, e) + return None + + @override + async def list_sessions( + self, *, app_name: str, user_id: str + ) -> ListSessionsResponse: + """List all sessions for a user from Working Memory. + + Args: + app_name: Application name. + user_id: User identifier. + + Returns: + ListSessionsResponse containing sessions (without events). + """ + try: + namespace = self._get_namespace(app_name) + + # SDK method: list_sessions returns SessionListResponse + # with sessions: list[str] (session IDs only) + response = await self._client.list_sessions( + namespace=namespace, + user_id=user_id, + ) + + sessions = [] + for session_id in response.sessions: + session = Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state={}, + events=[], + last_update_time=time.time(), + ) + sessions.append(session) + + return ListSessionsResponse(sessions=sessions) + + except Exception as e: + logger.error("Failed to list sessions: %s", e) + return ListSessionsResponse(sessions=[]) + + @override + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + """Delete a session from Working Memory. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session ID to delete. + """ + try: + namespace = self._get_namespace(app_name) + await self._client.delete_working_memory( + session_id=session_id, + namespace=namespace, + user_id=user_id, + ) + logger.info("Deleted session %s", session_id) + except Exception as e: + logger.error("Failed to delete session %s: %s", session_id, e) + + @override + async def append_event(self, session: Session, event: Event) -> Event: + """Append an event to the session in Working Memory. + + Uses the incremental append API to add a single message without + resending the full conversation history. + + Args: + session: The session to append to. + event: The event to append. + + Returns: + The appended event. + """ + await super().append_event(session=session, event=event) + session.last_update_time = event.timestamp + + try: + message = self._event_to_message(event) + if message: + namespace = self._get_namespace(session.app_name) + await self._client.append_messages_to_working_memory( + session_id=session.id, + messages=[message], + namespace=namespace, + user_id=session.user_id, + ) + logger.debug("Appended message to session %s", session.id) + except Exception as e: + logger.error("Failed to append event to session %s: %s", session.id, e) + + return event + + async def close(self): + """Close the session service and cleanup resources.""" + if ( + "_client" in self.__dict__ + ): # Check for initialized client without triggering cached_property + await self._client.close() + # Clear the cached property + del self._client diff --git a/src/google/adk_community/sessions/utils.py b/src/google/adk_community/sessions/utils.py index bc53d2b..132c773 100644 --- a/src/google/adk_community/sessions/utils.py +++ b/src/google/adk_community/sessions/utils.py @@ -20,18 +20,18 @@ def _json_serializer(obj): - """Fallback serializer to handle non-JSON-compatible types.""" - if isinstance(obj, set): - return list(obj) - if isinstance(obj, bytes): - try: - return base64.b64encode(obj).decode("ascii") - except Exception: - return repr(obj) - if isinstance(obj, (datetime.datetime, datetime.date)): - return obj.isoformat() - if isinstance(obj, uuid.UUID): - return str(obj) - if isinstance(obj, Decimal): - return float(obj) + """Fallback serializer to handle non-JSON-compatible types.""" + if isinstance(obj, set): + return list(obj) + if isinstance(obj, bytes): + try: + return base64.b64encode(obj).decode("ascii") + except Exception: + return repr(obj) + if isinstance(obj, (datetime.datetime, datetime.date)): + return obj.isoformat() + if isinstance(obj, uuid.UUID): return str(obj) + if isinstance(obj, Decimal): + return float(obj) + return str(obj) diff --git a/tests/unittests/memory/test_redis_long_term_memory_service.py b/tests/unittests/memory/test_redis_long_term_memory_service.py new file mode 100644 index 0000000..a2ca3f4 --- /dev/null +++ b/tests/unittests/memory/test_redis_long_term_memory_service.py @@ -0,0 +1,360 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.events.event import Event +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +from google.adk_community.memory.redis_long_term_memory_service import RedisLongTermMemoryService +from google.adk_community.memory.redis_long_term_memory_service import RedisLongTermMemoryServiceConfig + + +# Create mock classes for agent_memory_client.models +class MockMemoryMessage: + + def __init__(self, role, content): + self.role = role + self.content = content + + +class MockMemoryStrategyConfig: + + def __init__(self, strategy, config=None): + self.strategy = strategy + self.config = config + + +class MockWorkingMemory: + + def __init__( + self, session_id, namespace, user_id, messages, long_term_memory_strategy + ): + self.session_id = session_id + self.namespace = namespace + self.user_id = user_id + self.messages = messages + self.long_term_memory_strategy = long_term_memory_strategy + + +class MockRecencyConfig: + + def __init__( + self, + recency_boost, + semantic_weight, + recency_weight, + freshness_weight, + novelty_weight, + half_life_last_access_days, + half_life_created_days, + ): + self.recency_boost = recency_boost + self.semantic_weight = semantic_weight + self.recency_weight = recency_weight + self.freshness_weight = freshness_weight + self.novelty_weight = novelty_weight + self.half_life_last_access_days = half_life_last_access_days + self.half_life_created_days = half_life_created_days + + +MOCK_APP_NAME = "test-app" +MOCK_USER_ID = "test-user" +MOCK_SESSION_ID = "session-1" + +MOCK_SESSION = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id=MOCK_SESSION_ID, + last_update_time=1000, + events=[ + Event( + id="event-1", + invocation_id="inv-1", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="Hello, I like Python.")] + ), + ), + Event( + id="event-2", + invocation_id="inv-2", + author="model", + timestamp=12346, + content=types.Content( + parts=[ + types.Part(text="Python is a great programming language.") + ] + ), + ), + # Empty event, should be ignored + Event( + id="event-3", + invocation_id="inv-3", + author="user", + timestamp=12347, + ), + ], +) + +MOCK_SESSION_WITH_EMPTY_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id=MOCK_SESSION_ID, + last_update_time=1000, +) + + +class TestRedisLongTermMemoryService: + """Tests for RedisLongTermMemoryService.""" + + @pytest.fixture(autouse=True) + def mock_agent_memory_models(self): + """Mock agent_memory_client.models module.""" + mock_models = MagicMock() + mock_models.MemoryMessage = MockMemoryMessage + mock_models.MemoryStrategyConfig = MockMemoryStrategyConfig + mock_models.WorkingMemory = MockWorkingMemory + mock_models.RecencyConfig = MockRecencyConfig + + with patch.dict(sys.modules, {"agent_memory_client.models": mock_models}): + yield mock_models + + @pytest.fixture + def mock_memory_client(self): + """Create a mock MemoryAPIClient.""" + mock_client = MagicMock() + mock_client.put_working_memory = AsyncMock() + mock_client.search_long_term_memory = AsyncMock() + mock_client.close = AsyncMock() + return mock_client + + @pytest.fixture + def memory_service(self, mock_memory_client): + """Create RedisLongTermMemoryService with mocked client.""" + service = RedisLongTermMemoryService() + # Inject the mock client by setting it in __dict__ to bypass cached_property + service.__dict__["_client"] = mock_memory_client + return service + + @pytest.fixture + def memory_service_with_config(self, mock_memory_client): + """Create RedisLongTermMemoryService with custom config.""" + config = RedisLongTermMemoryServiceConfig( + default_namespace="custom_namespace", + search_top_k=5, + recency_boost=True, + extraction_strategy="preferences", + ) + service = RedisLongTermMemoryService(config=config) + # Inject the mock client by setting it in __dict__ to bypass cached_property + service.__dict__["_client"] = mock_memory_client + return service + + @pytest.mark.asyncio + async def test_add_session_to_memory_success( + self, memory_service, mock_memory_client + ): + """Test successful addition of session to memory.""" + mock_response = MagicMock() + mock_response.context_percentage_total_used = 25.0 + mock_memory_client.put_working_memory.return_value = mock_response + + await memory_service.add_session_to_memory(MOCK_SESSION) + + mock_memory_client.put_working_memory.assert_called_once() + call_args = mock_memory_client.put_working_memory.call_args + assert call_args.kwargs["session_id"] == MOCK_SESSION_ID + assert call_args.kwargs["user_id"] == MOCK_USER_ID + + working_memory = call_args.kwargs["memory"] + assert len(working_memory.messages) == 2 + assert working_memory.messages[0].role == "user" + assert working_memory.messages[0].content == "Hello, I like Python." + assert working_memory.messages[1].role == "assistant" + assert ( + working_memory.messages[1].content + == "Python is a great programming language." + ) + + @pytest.mark.asyncio + async def test_add_session_filters_empty_events( + self, memory_service, mock_memory_client + ): + """Test that events without content are filtered out.""" + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS) + + mock_memory_client.put_working_memory.assert_not_called() + + @pytest.mark.asyncio + async def test_add_session_uses_config_namespace( + self, memory_service_with_config, mock_memory_client + ): + """Test that namespace from config is used.""" + mock_response = MagicMock() + mock_response.context_percentage_total_used = 10.0 + mock_memory_client.put_working_memory.return_value = mock_response + + await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + + call_args = mock_memory_client.put_working_memory.call_args + working_memory = call_args.kwargs["memory"] + assert working_memory.namespace == "custom_namespace" + + @pytest.mark.asyncio + async def test_add_session_uses_extraction_strategy( + self, memory_service_with_config, mock_memory_client + ): + """Test that extraction strategy from config is used.""" + mock_response = MagicMock() + mock_response.context_percentage_total_used = 10.0 + mock_memory_client.put_working_memory.return_value = mock_response + + await memory_service_with_config.add_session_to_memory(MOCK_SESSION) + + call_args = mock_memory_client.put_working_memory.call_args + working_memory = call_args.kwargs["memory"] + assert working_memory.long_term_memory_strategy.strategy == "preferences" + + @pytest.mark.asyncio + async def test_add_session_error_handling( + self, memory_service, mock_memory_client + ): + """Test error handling during memory addition.""" + mock_memory_client.put_working_memory.side_effect = Exception("API Error") + + # Should not raise exception, just log error + await memory_service.add_session_to_memory(MOCK_SESSION) + + @pytest.mark.asyncio + async def test_search_memory_success( + self, memory_service, mock_memory_client + ): + """Test successful memory search.""" + mock_memory = MagicMock() + mock_memory.text = "Python is a great language" + mock_results = MagicMock() + mock_results.memories = [mock_memory] + mock_memory_client.search_long_term_memory.return_value = mock_results + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="Python programming" + ) + + mock_memory_client.search_long_term_memory.assert_called_once() + call_args = mock_memory_client.search_long_term_memory.call_args + assert call_args.kwargs["text"] == "Python programming" + assert call_args.kwargs["namespace"] == {"eq": MOCK_APP_NAME} + assert call_args.kwargs["user_id"] == {"eq": MOCK_USER_ID} + + assert len(result.memories) == 1 + assert ( + result.memories[0].content.parts[0].text == "Python is a great language" + ) + + @pytest.mark.asyncio + async def test_search_memory_with_recency_boost( + self, memory_service, mock_memory_client + ): + """Test that recency config is passed when enabled.""" + mock_results = MagicMock() + mock_results.memories = [] + mock_memory_client.search_long_term_memory.return_value = mock_results + + await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="test query" + ) + + call_args = mock_memory_client.search_long_term_memory.call_args + recency = call_args.kwargs["recency"] + assert recency is not None + assert recency.recency_boost is True + assert recency.semantic_weight == 0.8 + assert recency.recency_weight == 0.2 + + @pytest.mark.asyncio + async def test_search_memory_without_recency_boost(self, mock_memory_client): + """Test that recency config is None when disabled.""" + config = RedisLongTermMemoryServiceConfig(recency_boost=False) + service = RedisLongTermMemoryService(config=config) + # Inject the mock client by setting it in __dict__ to bypass cached_property + service.__dict__["_client"] = mock_memory_client + + mock_results = MagicMock() + mock_results.memories = [] + mock_memory_client.search_long_term_memory.return_value = mock_results + + await service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="test query" + ) + + call_args = mock_memory_client.search_long_term_memory.call_args + assert call_args.kwargs["recency"] is None + + @pytest.mark.asyncio + async def test_search_memory_respects_top_k( + self, memory_service_with_config, mock_memory_client + ): + """Test that config.search_top_k is used.""" + mock_results = MagicMock() + mock_results.memories = [] + mock_memory_client.search_long_term_memory.return_value = mock_results + + await memory_service_with_config.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="test query" + ) + + call_args = mock_memory_client.search_long_term_memory.call_args + assert call_args.kwargs["limit"] == 5 + + @pytest.mark.asyncio + async def test_search_memory_error_handling( + self, memory_service, mock_memory_client + ): + """Test graceful error handling during memory search.""" + mock_memory_client.search_long_term_memory.side_effect = Exception( + "API Error" + ) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="test query" + ) + + assert len(result.memories) == 0 + + @pytest.mark.asyncio + async def test_close(self, memory_service, mock_memory_client): + """Test closing the service.""" + await memory_service.close() + + mock_memory_client.close.assert_called_once() + assert ( + not hasattr(memory_service, "client") + or "client" not in memory_service.__dict__ + ) + + def test_import_error_handling(self): + """Test that ImportError is raised when agent-memory-client is not installed.""" + service = RedisLongTermMemoryService() + + with patch.dict("sys.modules", {"agent_memory_client": None}): + with pytest.raises(ImportError, match="agent-memory-client"): + # Access the client property which will trigger the import + _ = service._client diff --git a/tests/unittests/sessions/test_redis_working_memory_session_service.py b/tests/unittests/sessions/test_redis_working_memory_session_service.py new file mode 100644 index 0000000..67f4748 --- /dev/null +++ b/tests/unittests/sessions/test_redis_working_memory_session_service.py @@ -0,0 +1,269 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for RedisWorkingMemorySessionService.""" + +import time +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.events.event import Event +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.session import Session +from google.genai import types +import pytest +import pytest_asyncio + + +class TestRedisWorkingMemorySessionServiceConfig: + """Test cases for RedisWorkingMemorySessionServiceConfig.""" + + def test_default_config(self): + """Test default configuration values.""" + from google.adk_community.sessions import RedisWorkingMemorySessionServiceConfig + + config = RedisWorkingMemorySessionServiceConfig() + + assert config.api_base_url == "http://localhost:8000" + assert config.timeout == 30.0 + assert config.default_namespace is None + assert config.model_name is None + assert config.context_window_max is None + assert config.extraction_strategy == "discrete" + assert config.extraction_strategy_config == {} + + def test_custom_config(self): + """Test custom configuration values.""" + from google.adk_community.sessions import RedisWorkingMemorySessionServiceConfig + + config = RedisWorkingMemorySessionServiceConfig( + api_base_url="http://custom:9000", + timeout=60.0, + default_namespace="my_namespace", + model_name="gpt-4", + context_window_max=8000, + extraction_strategy="summary", + extraction_strategy_config={"key": "value"}, + ) + + assert config.api_base_url == "http://custom:9000" + assert config.timeout == 60.0 + assert config.default_namespace == "my_namespace" + assert config.model_name == "gpt-4" + assert config.context_window_max == 8000 + assert config.extraction_strategy == "summary" + assert config.extraction_strategy_config == {"key": "value"} + + +class TestRedisWorkingMemorySessionService: + """Test cases for RedisWorkingMemorySessionService.""" + + @pytest_asyncio.fixture + async def mock_client(self): + """Create a mock MemoryAPIClient.""" + mock = AsyncMock() + mock.close = AsyncMock() + return mock + + @pytest_asyncio.fixture + async def service(self, mock_client): + """Create a RedisWorkingMemorySessionService with mocked client.""" + from google.adk_community.sessions import RedisWorkingMemorySessionService + from google.adk_community.sessions import RedisWorkingMemorySessionServiceConfig + + config = RedisWorkingMemorySessionServiceConfig( + api_base_url="http://localhost:8000", + default_namespace="test_namespace", + ) + svc = RedisWorkingMemorySessionService(config=config) + # Inject mock client + svc.__dict__["_client"] = mock_client + return svc + + @pytest.mark.asyncio + async def test_create_session(self, service, mock_client): + """Test session creation.""" + mock_wm = MagicMock() + mock_wm.session_id = "generated_id" + mock_wm.messages = [] + mock_wm.data = {} + mock_client.get_or_create_working_memory = AsyncMock( + return_value=(True, mock_wm) + ) + mock_client.put_working_memory = AsyncMock() + + session = await service.create_session( + app_name="test_app", + user_id="test_user", + state={"key": "value"}, + ) + + assert session.app_name == "test_app" + assert session.user_id == "test_user" + assert session.state == {"key": "value"} + assert session.events == [] + assert session.id is not None + mock_client.get_or_create_working_memory.assert_called_once() + # put_working_memory called to update state + mock_client.put_working_memory.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_with_custom_id(self, service, mock_client): + """Test session creation with custom session ID.""" + mock_wm = MagicMock() + mock_wm.session_id = "custom_session_id" + mock_wm.messages = [] + mock_wm.data = {} + mock_client.get_or_create_working_memory = AsyncMock( + return_value=(True, mock_wm) + ) + + session = await service.create_session( + app_name="test_app", + user_id="test_user", + session_id="custom_session_id", + ) + + assert session.id == "custom_session_id" + mock_client.get_or_create_working_memory.assert_called_once() + + @pytest.mark.asyncio + async def test_get_session(self, service, mock_client): + """Test session retrieval.""" + mock_response = MagicMock() + mock_response.session_id = "test_session" + mock_response.messages = [] + mock_response.data = {"key": "value"} + # Return (created=False, response) to indicate existing session + mock_client.get_or_create_working_memory = AsyncMock( + return_value=(False, mock_response) + ) + + session = await service.get_session( + app_name="test_app", + user_id="test_user", + session_id="test_session", + ) + + assert session is not None + assert session.id == "test_session" + assert session.state == {"key": "value"} + mock_client.get_or_create_working_memory.assert_called_once() + + @pytest.mark.asyncio + async def test_get_session_not_found(self, service, mock_client): + """Test session retrieval when session doesn't exist.""" + # Return (created=True, response) to indicate new session was created + mock_response = MagicMock() + mock_response.session_id = "nonexistent" + mock_response.messages = [] + mock_response.data = {} + mock_client.get_or_create_working_memory = AsyncMock( + return_value=(True, mock_response) + ) + mock_client.delete_working_memory = AsyncMock() + + session = await service.get_session( + app_name="test_app", + user_id="test_user", + session_id="nonexistent", + ) + + assert session is None + + @pytest.mark.asyncio + async def test_list_sessions(self, service, mock_client): + """Test listing sessions.""" + mock_response = MagicMock() + mock_response.sessions = ["session1", "session2", "session3"] + mock_client.list_sessions = AsyncMock(return_value=mock_response) + + result = await service.list_sessions( + app_name="test_app", + user_id="test_user", + ) + + assert len(result.sessions) == 3 + assert result.sessions[0].id == "session1" + assert result.sessions[1].id == "session2" + assert result.sessions[2].id == "session3" + + @pytest.mark.asyncio + async def test_delete_session(self, service, mock_client): + """Test session deletion.""" + mock_client.delete_working_memory = AsyncMock() + + await service.delete_session( + app_name="test_app", + user_id="test_user", + session_id="test_session", + ) + + mock_client.delete_working_memory.assert_called_once() + + @pytest.mark.asyncio + async def test_append_event(self, service, mock_client): + """Test appending an event to a session.""" + mock_client.append_messages_to_working_memory = AsyncMock() + + session = Session( + id="test_session", + app_name="test_app", + user_id="test_user", + state={}, + events=[], + last_update_time=time.time(), + ) + + event = Event( + author="user", + content=types.Content(parts=[types.Part(text="Hello")]), + timestamp=time.time(), + ) + + result = await service.append_event(session=session, event=event) + + assert result == event + mock_client.append_messages_to_working_memory.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_existing_returns_existing( + self, service, mock_client + ): + """Test that creating a session with existing ID returns existing session.""" + mock_wm = MagicMock() + mock_wm.session_id = "existing_session" + mock_wm.messages = [] + mock_wm.data = {"existing": "data"} + # created=False means session already exists + mock_client.get_or_create_working_memory = AsyncMock( + return_value=(False, mock_wm) + ) + + session = await service.create_session( + app_name="test_app", + user_id="test_user", + session_id="existing_session", + ) + + assert session.id == "existing_session" + assert session.state == {"existing": "data"} + + @pytest.mark.asyncio + async def test_close(self, service, mock_client): + """Test closing the service.""" + await service.close() + + mock_client.close.assert_called_once()