diff --git a/contributing/samples/redis_vl_search/README.md b/contributing/samples/redis_vl_search/README.md new file mode 100644 index 0000000..246e960 --- /dev/null +++ b/contributing/samples/redis_vl_search/README.md @@ -0,0 +1,215 @@ +# RedisVL Search Agent + +This sample demonstrates using Redis search tools to give an ADK agent +access to a Redis-based knowledge base with multiple search capabilities. + +## What This Sample Shows + +- Setting up a Redis vector index with a schema +- Using 3 Redis search tools in one agent (4th requires Redis 8.4+): + - **RedisVectorSearchTool**: Semantic similarity search (KNN) - finds conceptually similar content + - **RedisTextSearchTool**: Full-text keyword search (BM25) - matches exact terms and phrases + - **RedisRangeSearchTool**: Distance threshold search - returns ALL docs within a relevance radius + - **RedisHybridSearchTool**: Combined vector + text search (requires Redis 8.4+ and redis-py 7.1+) +- Integrating RedisVL with an ADK agent + +## Prerequisites + +1. **Redis Stack** running locally (or Redis Cloud with Search capability) + ```bash + # Using Docker + docker run -d --name redis-stack -p 6379:6379 redis/redis-stack:latest + ``` + +2. **No API keys needed for embeddings** - uses Redis' open-source `redis/langcache-embed-v2` model (768 dimensions) + +## Setup + +1. Install dependencies: + ```bash + pip install "google-adk-community[redis-vl]" + ``` + +2. Download NLTK stopwords (required for keyword search): + ```bash + python -c "import nltk; nltk.download('stopwords')" + ``` + +3. Set environment variables (or create a `.env` file): + ```bash + export REDIS_URL=redis://localhost:6379 + export GOOGLE_API_KEY=your-google-api-key # For Gemini LLM + ``` + +4. Load sample data into Redis: + ```bash + cd contributing/samples/redis_vl_search + python load_data.py + ``` + +5. 
Run the agent: + ```bash + cd contributing/samples/redis_vl_search + adk web + ``` + +## Files + +| File | Description | +|------|-------------| +| `schema.yaml` | Redis index schema defining document structure | +| `load_data.py` | Script to populate Redis with sample documents | +| `redis_vl_search_agent/agent.py` | Agent definition with all Redis search tools | + +## How It Works + +1. **Schema Definition** (`schema.yaml`): Defines the index structure with fields + for title, content, URL, category, and a vector embedding field. + +2. **Data Loading** (`load_data.py`): Populates Redis with sample documents about + Redis and ADK, embedding the content using Redis' langcache-embed-v2 model. + +3. **Agent** (`redis_vl_search_agent/agent.py`): Creates an agent with access to + multiple search tools for different use cases. + +## Search Tools + +### semantic_search (RedisVectorSearchTool) +**Best for:** Conceptual questions, natural language queries, finding similar content. + +**How it works:** Converts query to vector embedding, finds K nearest neighbors by cosine similarity. + +**Returns:** Top-K most similar documents (default: 5). + +**Example queries:** +- "What is Redis?" → finds docs about Redis even if they don't say "What is Redis" +- "How do I build a chatbot?" → finds docs about "intelligent assistants", "conversational AI" +- "Fast database for caching" → finds Redis docs even without exact keyword match + +### keyword_search (RedisTextSearchTool) +**Best for:** Exact terms, acronyms, technical jargon, API names, error messages. + +**How it works:** BM25 text scoring algorithm - matches exact tokens, weighs by term frequency. + +**Returns:** Top-K documents ranked by keyword relevance. 
+ +**Example queries:** +- "HNSW algorithm" → exact match on "HNSW" acronym +- "BM25 formula" → finds docs containing "BM25" +- "VectorQuery class" → API/class name lookup +- "RRF ranking" → technical term that needs exact match + +### range_search (RedisRangeSearchTool) +**Best for:** Exhaustive retrieval, comprehensive coverage, finding ALL related documents. + +**How it works:** Returns ALL documents within a distance threshold (not just top-K). + +**Returns:** Variable number - every document above the relevance bar. + +**Use when:** +- User wants "everything" about a topic +- Comprehensive research needed +- Quality filtering (only highly relevant docs) +- Clustering/grouping similar content + +**Example queries:** +- "Tell me everything about RAG pipelines" → returns all RAG-related docs +- "All Redis data structures" → comprehensive list +- "Complete guide to embeddings" → exhaustive retrieval + +### hybrid_search (RedisHybridSearchTool) +**Best for:** Queries that benefit from both semantic understanding AND exact keyword matching. + +**How it works:** Combines vector similarity + BM25 text scores using RRF or linear weighting. + +**Requires:** Redis 8.4+ and redis-py 7.1+ + +## Example Queries + +Once running, try asking the agent: + +| Query | Expected Tool | Why | +|-------|---------------|-----| +| "What is Redis?" | semantic_search | Conceptual question | +| "HNSW algorithm details" | keyword_search | Technical acronym | +| "Tell me everything about RAG" | range_search | Exhaustive retrieval | +| "How do I build a chatbot?" 
| semantic_search | Natural language | +| "BM25 formula" | keyword_search | Exact term lookup | +| "All vector search methods" | range_search | Comprehensive coverage | + +## Customization + +### Using a Different Vectorizer + +```python +from redisvl.utils.vectorize import HuggingFaceTextVectorizer + +vectorizer = HuggingFaceTextVectorizer(model="sentence-transformers/all-MiniLM-L6-v2") +``` + +Note: Update `dims` in `schema.yaml` to match your model's embedding dimensions. + +### Adding Filters + +You can add filter expressions to narrow search results: + +```python +from redisvl.query.filter import Tag +from google.adk_community.tools.redis import RedisVectorSearchTool, RedisVectorQueryConfig + +config = RedisVectorQueryConfig(num_results=5) +redis_search = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + return_fields=["title", "content", "url", "category"], + filter_expression=Tag("category") == "redis", # Only search Redis docs +) +``` + +See [RedisVL Filter documentation](https://docs.redisvl.com/api/filter.html) for more filter options. 
+ +### Advanced Query Options + +`RedisVectorSearchTool` uses a `RedisVectorQueryConfig` object for query parameters: + +```python +from google.adk_community.tools.redis import RedisVectorSearchTool, RedisVectorQueryConfig +from redisvl.query.filter import Tag + +# Configure query parameters via config object +config = RedisVectorQueryConfig( + num_results=10, + # Query tuning + dtype="float32", # Vector dtype + return_score=True, # Include similarity score + normalize_vector_distance=True, # Convert to 0-1 similarity + # Hybrid filtering + hybrid_policy="BATCHES", # or "ADHOC_BF" + batch_size=100, # For BATCHES policy + # HNSW tuning + ef_runtime=150, # Higher = better recall, slower + epsilon=0.01, # Range search approximation + # SVS-VAMANA tuning + search_window_size=20, # Search window size + use_search_history="AUTO", # "OFF", "ON", or "AUTO" + search_buffer_capacity=30, # 2-level compression tuning +) + +redis_search = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + return_fields=["title", "content"], + filter_expression=Tag("category") == "redis", +) +``` + +See [RedisVL Query documentation](https://docs.redisvl.com/api/query.html) for details. + +### Connecting to Redis Cloud + +```bash +export REDIS_URL=redis://default:password@your-redis-cloud-host:port +``` + diff --git a/contributing/samples/redis_vl_search/load_data.py b/contributing/samples/redis_vl_search/load_data.py new file mode 100644 index 0000000..c175413 --- /dev/null +++ b/contributing/samples/redis_vl_search/load_data.py @@ -0,0 +1,259 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Load sample data into Redis for the redis_vl_search sample.

Data is designed to demonstrate when different search tools excel:
- Semantic search: conceptual queries, synonyms, paraphrasing
- Keyword search: technical terms, acronyms, exact phrases
- Range search: finding all highly relevant docs above threshold
- Hybrid search: combining both approaches
"""

import os
from pathlib import Path

from redisvl.index import SearchIndex
from redisvl.utils.vectorize import HFTextVectorizer

# Documents designed to showcase different search tool strengths.
# Each entry matches the fields declared in schema.yaml (the "embedding"
# field is filled in at load time by load_data()).
SAMPLE_DOCS = [
    # === SEMANTIC SEARCH DEMOS ===
    # These use varied vocabulary - semantic search finds them even with different words
    {
        "title": "Introduction to Redis",
        "content": (
            "Redis is a lightning-fast in-memory data store. It excels at caching, "
            "session management, and real-time analytics. Think of it as a Swiss Army "
            "knife for data - versatile, quick, and reliable."
        ),
        "url": "https://redis.io/docs/about/",
        "category": "redis",
        "doc_type": "reference",
        "difficulty": "beginner",
    },
    {
        "title": "Understanding Vector Databases",
        "content": (
            "Vector databases store numerical representations of data called embeddings. "
            "These embeddings capture semantic meaning, enabling similarity search. "
            "Applications include recommendation engines, image search, and chatbots."
        ),
        "url": "https://redis.io/docs/vectors/",
        "category": "concepts",
        "doc_type": "reference",
        "difficulty": "intermediate",
    },
    {
        "title": "Building Intelligent Assistants",
        "content": (
            "Modern AI assistants combine language models with external knowledge. "
            "They can search databases, call APIs, and maintain conversation context. "
            "The key is giving them the right tools for each task."
        ),
        "url": "https://google.github.io/adk-docs/agents/",
        "category": "adk",
        "doc_type": "tutorial",
        "difficulty": "intermediate",
    },
    # === KEYWORD SEARCH DEMOS ===
    # These contain specific technical terms that keyword search finds precisely
    {
        "title": "HNSW Algorithm Deep Dive",
        "content": (
            "HNSW (Hierarchical Navigable Small World) is the algorithm Redis uses for "
            "approximate nearest neighbor search. It builds a multi-layer graph where "
            "each layer has exponentially fewer nodes. Search starts at the top layer "
            "and navigates down. Parameters: M (connections per node), EF (search width)."
        ),
        "url": "https://redis.io/docs/hnsw/",
        "category": "redis",
        "doc_type": "reference",
        "difficulty": "advanced",
    },
    {
        "title": "BM25 Scoring Explained",
        "content": (
            "BM25 (Best Matching 25) is a ranking function for full-text search. "
            "It improves on TF-IDF by adding document length normalization and term "
            "frequency saturation. Redis supports BM25STD and BM25 scorers. "
            "Formula: score = IDF * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * dl/avgdl))"
        ),
        "url": "https://redis.io/docs/bm25/",
        "category": "redis",
        "doc_type": "reference",
        "difficulty": "advanced",
    },
    {
        "title": "RRF: Reciprocal Rank Fusion",
        "content": (
            "RRF (Reciprocal Rank Fusion) combines results from multiple search methods. "
            "Score = sum(1 / (k + rank)) across all methods. Default k=60. "
            "Used in hybrid search to merge vector and keyword results. "
            "Alternative: LINEAR combination with alpha weighting."
        ),
        "url": "https://redis.io/docs/hybrid/",
        "category": "redis",
        "doc_type": "reference",
        "difficulty": "advanced",
    },
    {
        "title": "RedisVL API Reference",
        "content": (
            "Key classes: SearchIndex, AsyncSearchIndex, VectorQuery, TextQuery, "
            "HybridQuery, VectorRangeQuery. Vectorizers: HFTextVectorizer (recommended, "
            "uses redis/langcache-embed-v2 model, 768 dimensions). Methods: index.query(), "
            "index.search(), index.load(), index.create(overwrite=True)."
        ),
        "url": "https://redis.io/docs/redisvl/api/",
        "category": "redis",
        "doc_type": "api",
        "difficulty": "intermediate",
    },
    # === RANGE SEARCH DEMOS ===
    # Highly related docs - range search finds all above threshold
    {
        "title": "RAG Architecture Overview",
        "content": (
            "Retrieval-Augmented Generation (RAG) enhances LLMs with external knowledge. "
            "Step 1: Embed the user query. Step 2: Search vector database for relevant "
            "documents. Step 3: Include retrieved context in LLM prompt. "
            "Step 4: Generate grounded response."
        ),
        "url": "https://redis.io/solutions/rag/",
        "category": "concepts",
        "doc_type": "tutorial",
        "difficulty": "intermediate",
    },
    {
        "title": "RAG Best Practices",
        "content": (
            "Tips for effective RAG: chunk documents appropriately (512-1024 tokens), "
            "use hybrid search for better recall, rerank results before prompting, "
            "include metadata for filtering, monitor retrieval quality metrics."
        ),
        "url": "https://redis.io/solutions/rag/best-practices/",
        "category": "concepts",
        "doc_type": "tutorial",
        "difficulty": "intermediate",
    },
    {
        "title": "RAG vs Fine-tuning",
        "content": (
            "RAG retrieves knowledge at query time; fine-tuning bakes it into weights. "
            "RAG pros: up-to-date info, traceable sources, no training cost. "
            "Fine-tuning pros: faster inference, specialized behavior. "
            "Often combined: fine-tune for style, RAG for facts."
        ),
        "url": "https://redis.io/solutions/rag/comparison/",
        "category": "concepts",
        "doc_type": "reference",
        "difficulty": "advanced",
    },
    # === ADK SPECIFIC ===
    {
        "title": "ADK Tool Development",
        "content": (
            "Create custom tools by subclassing BaseTool. Implement run_async() for "
            "the tool logic and _get_declaration() for the schema. Tools receive "
            "args dict and ToolContext. Return results as dict or string."
        ),
        "url": "https://google.github.io/adk-docs/tools/custom/",
        "category": "adk",
        "doc_type": "tutorial",
        "difficulty": "intermediate",
    },
    {
        "title": "ADK Agent Orchestration",
        "content": (
            "Agents can delegate to sub-agents using transfer_to_agent(). "
            "Define agent hierarchy with parent/child relationships. "
            "Use instruction prompts to guide when delegation occurs. "
            "Sub-agents inherit tools from parent unless overridden."
        ),
        "url": "https://google.github.io/adk-docs/agents/orchestration/",
        "category": "adk",
        "doc_type": "tutorial",
        "difficulty": "advanced",
    },
    # === FAQ STYLE (good for both search types) ===
    {
        "title": "FAQ: Redis Connection Issues",
        "content": (
            "Q: Connection refused error? A: Check Redis is running on the correct "
            "port (default 6379). Q: Authentication failed? A: Set REDIS_PASSWORD "
            "env var or pass password to connect(). Q: Timeout errors? A: Increase "
            "socket_timeout parameter or check network latency."
        ),
        "url": "https://redis.io/docs/faq/connection/",
        "category": "redis",
        "doc_type": "faq",
        "difficulty": "beginner",
    },
    {
        "title": "FAQ: Embedding Dimensions Mismatch",
        "content": (
            "Q: Dimension mismatch error? A: Ensure query embeddings match index "
            "dimensions. Common dimensions: OpenAI ada-002 (1536), langcache-embed-v2 "
            "(768), sentence-transformers (384-768). Check schema.yaml dims field."
        ),
        "url": "https://redis.io/docs/faq/vectors/",
        "category": "redis",
        "doc_type": "faq",
        "difficulty": "beginner",
    },
]


def load_data() -> None:
    """Load sample documents into Redis with embeddings.

    Reads the index schema from schema.yaml (next to this script), connects to
    the Redis instance named by the REDIS_URL environment variable (default:
    redis://localhost:6379), recreates the index, embeds every document's
    "content" field with redis/langcache-embed-v2 (768 dims, matching the
    schema's vector field), and loads the documents.

    Requires a Redis server with the Search capability (e.g. Redis Stack).
    """
    schema_path = Path(__file__).parent / "schema.yaml"
    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")

    print(f"Connecting to Redis at {redis_url}")
    index = SearchIndex.from_yaml(str(schema_path))
    index.connect(redis_url)

    # overwrite=True makes the script idempotent: re-running rebuilds the index.
    print("Creating index (will overwrite if exists)...")
    index.create(overwrite=True)

    print("Generating embeddings...")
    vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2")

    docs_with_embeddings = []
    for doc in SAMPLE_DOCS:
        # as_buffer=True returns raw bytes, the format the vector field expects.
        embedding = vectorizer.embed(doc["content"], as_buffer=True)
        docs_with_embeddings.append({**doc, "embedding": embedding})
        print(f"  [{doc['doc_type']:9}] {doc['title']}")

    print("\nLoading documents into Redis...")
    index.load(docs_with_embeddings)

    print(f"""
Loaded {len(SAMPLE_DOCS)} documents.

Try these queries to see different tools in action:
  - "What is Redis?" → semantic_search (conceptual)
  - "HNSW algorithm" → keyword_search (exact term)
  - "Tell me everything about RAG" → range_search (exhaustive retrieval)
  - "BM25 formula" → keyword_search (technical)
  - "building chatbots" → semantic_search (synonym for assistants)

Run: adk web
""")


if __name__ == "__main__":
    load_data()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redis search agent sample.""" + +from .agent import root_agent + +__all__ = ["root_agent"] + diff --git a/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py b/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py new file mode 100644 index 0000000..d678e4f --- /dev/null +++ b/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py @@ -0,0 +1,164 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

"""Redis search agent demonstrating vector, text, range, and hybrid search tools."""

import os
from pathlib import Path

from dotenv import load_dotenv
from google.adk import Agent
from google.adk_community.tools.redis import (
    RedisRangeQueryConfig,
    RedisRangeSearchTool,
    RedisTextQueryConfig,
    RedisTextSearchTool,
    RedisVectorQueryConfig,
    RedisVectorSearchTool,
)
from redisvl.index import SearchIndex
from redisvl.utils.vectorize import HFTextVectorizer

# Schema lives one directory above this package (the sample root).
SCHEMA_PATH = Path(__file__).parent.parent / "schema.yaml"
# Document fields every search tool returns to the LLM.
RETURN_FIELDS = ["title", "content", "url", "category", "doc_type", "difficulty"]

INSTRUCTION = """You are a helpful assistant with a technical knowledge base.

You have 3 search tools, each optimized for different query types:

## semantic_search (Vector KNN)
Best for: Conceptual questions, natural language queries, finding similar content.
How it works: Converts query to vector embedding, finds K nearest neighbors by cosine similarity.
Returns: Top-K most similar documents (default: 5).
Example queries:
- "What is Redis?" → finds docs about Redis even if they don't say "What is Redis"
- "How do I build a chatbot?" → finds docs about "intelligent assistants", "conversational AI"
- "Fast database for caching" → finds Redis docs even without exact keyword match
- "Machine learning embeddings" → finds semantically related vector/AI content

## keyword_search (BM25 Full-Text)
Best for: Exact terms, acronyms, technical jargon, API names, error messages.
How it works: BM25 text scoring algorithm - matches exact tokens, weighs by term frequency.
Returns: Top-K documents ranked by keyword relevance.
Example queries:
- "HNSW algorithm" → exact match on "HNSW" acronym
- "BM25 formula" → finds docs containing "BM25"
- "VectorQuery class" → API/class name lookup
- "RRF ranking" → technical term that needs exact match
- "ConnectionError redis" → error message lookup

## range_search (Vector Range/Threshold)
Best for: Exhaustive retrieval, comprehensive coverage, finding ALL related documents.
How it works: Returns ALL documents within a distance threshold (not just top-K).
Returns: Variable number - every document above the relevance bar.
Use when:
- User wants "everything" about a topic: "Tell me everything about RAG"
- Comprehensive research: "All vector search algorithms"
- Quality filtering: Only docs highly relevant to the query
- Clustering/grouping: Find all docs in a semantic neighborhood
- Deduplication: Find near-duplicate content
Example queries:
- "Tell me everything about RAG pipelines" → returns all RAG-related docs
- "All Redis data structures" → comprehensive list
- "Complete guide to embeddings" → exhaustive retrieval
- "Every caching strategy" → all related documents, not just top 5

## Strategy
1. Start with the most appropriate tool based on query type
2. If results are poor or incomplete, try another tool
3. For broad topics, consider range_search first, then refine with others
4. For technical terms/acronyms, prefer keyword_search
5. For natural questions, prefer semantic_search

Always cite sources with title and difficulty level."""


def get_index(schema_path: Path, redis_url: str) -> SearchIndex:
    """Create and connect to Redis search index.

    Args:
        schema_path: Path to the RedisVL YAML schema file.
        redis_url: Redis connection URL (e.g. redis://localhost:6379).

    Returns:
        A SearchIndex connected to the given Redis instance. Note: the index
        is only connected here, not created — load_data.py creates it.
    """
    index = SearchIndex.from_yaml(str(schema_path))
    index.connect(redis_url)
    return index


def get_search_tools(index: SearchIndex, vectorizer: HFTextVectorizer) -> list:
    """Create search tools for the agent.

    Builds three tools over the same index: KNN vector search (semantic),
    BM25 full-text search (keyword), and vector range search (threshold).

    Args:
        index: Connected SearchIndex to query.
        vectorizer: Text vectorizer used by the two vector-based tools; its
            embedding dims must match the schema's vector field (768 here).

    Returns:
        List of configured Redis search tools for the agent.
    """
    # Config objects group query-specific parameters
    vector_config = RedisVectorQueryConfig(num_results=5)
    text_config = RedisTextQueryConfig(
        text_field_name="content",
        num_results=5,
        text_scorer="BM25STD",
    )
    # Cosine distance threshold: only docs within 0.5 of the query are returned.
    range_config = RedisRangeQueryConfig(distance_threshold=0.5)

    return [
        RedisVectorSearchTool(
            name="semantic_search",
            description="Semantic similarity search for conceptual queries.",
            index=index,
            vectorizer=vectorizer,
            config=vector_config,
            return_fields=RETURN_FIELDS,
        ),
        RedisTextSearchTool(
            name="keyword_search",
            description="Keyword search for exact terms and phrases.",
            index=index,
            config=text_config,
            return_fields=RETURN_FIELDS,
        ),
        RedisRangeSearchTool(
            name="range_search",
            description=(
                "Vector range search - returns ALL documents within a semantic distance "
                "threshold (not just top-K). Use for: exhaustive topic retrieval, "
                "comprehensive coverage, quality-filtered results, finding all related "
                "documents, clustering similar content, or when you need everything "
                "above a relevance bar rather than a fixed number of results."
            ),
            index=index,
            vectorizer=vectorizer,
            config=range_config,
            return_fields=RETURN_FIELDS,
        ),
    ]


def create_agent() -> Agent:
    """Create the Redis search agent.

    Loads environment variables (.env supported), connects to Redis at
    REDIS_URL (default redis://localhost:6379), and wires the three search
    tools into a Gemini-backed agent.
    """
    load_dotenv()

    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
    index = get_index(SCHEMA_PATH, redis_url)
    vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2")
    tools = get_search_tools(index, vectorizer)

    return Agent(
        model="gemini-2.5-flash",
        name="root_agent",
        instruction=INSTRUCTION,
        tools=tools,
    )


# Required for `adk web` to discover the agent
root_agent = create_agent()


if __name__ == "__main__":
    # For local testing
    print(f"Agent '{root_agent.name}' created with {len(root_agent.tools)} tools")
    for tool in root_agent.tools:
        print(f"  - {tool.name}")
+ +version: "0.1.0" + +index: + name: adk_knowledge_base + prefix: doc + +fields: + - name: title + type: text + + - name: content + type: text + + - name: url + type: tag + + - name: category + type: tag # redis, adk, concepts, tutorials + + - name: doc_type + type: tag # reference, tutorial, faq, api + + - name: difficulty + type: tag # beginner, intermediate, advanced + + - name: embedding + type: vector + attrs: + algorithm: hnsw + dims: 768 + distance_metric: cosine + datatype: float32 + diff --git a/github_issue_redis_search_tools.md b/github_issue_redis_search_tools.md new file mode 100644 index 0000000..aafe801 --- /dev/null +++ b/github_issue_redis_search_tools.md @@ -0,0 +1,92 @@ +# [Proposal] Add RedisVL Search Tools for Knowledge Base Retrieval + +## Is your feature request related to a problem? Please describe. + +ADK has retrieval tools for Google Cloud (Vertex AI Search, Discovery Engine), but **no self-hosted option** for developers who need full control over their infrastructure. RAG is a fundamental pattern for building useful agents, and Redis is already running in most organizations, these tools let developers add vector search, keyword search, and hybrid retrieval to their existing Redis infrastructure. [RedisVL](https://github.com/redis/redis-vl-python) is the official Redis Vector Search library that's also well adopted. + +Moreover, the repo has `RedisSessionService` for session persistence, but **there are no tools for agents to search Redis-based knowledge bases** without building it themselves. + +## Describe the solution you'd like + +Add `redisvl` as an optional dependency with four search tools that wrap RedisVL's query capabilities as ADK `BaseTool` implementations. 
+ +**pyproject.toml change:** +```toml +[project.optional-dependencies] +redis-vl = [ + "redisvl>=0.13.2", + "nltk>=3.8.0", + "sentence-transformers>=2.2.0", +] +``` + +**Installation:** +```bash +# Existing functionality unchanged +pip install google-adk-community + +# Opt-in to vector search capabilities +pip install google-adk-community[redis-vl] +``` + +This aligns with the community repository's stated philosophy: + +> "This approach allows the **core ADK to remain stable and lightweight**, while giving the community the freedom to build and share powerful extensions." + +### Tools Provided + +| Tool | Search Type | Use Case | +|------|-------------|----------| +| `RedisVectorSearchTool` | KNN vector similarity | Semantic/conceptual queries | +| `RedisTextSearchTool` | BM25 full-text | Exact terms, acronyms, API names | +| `RedisHybridSearchTool` | Vector + BM25 combined | Best of both worlds | +| `RedisRangeSearchTool` | Distance threshold | Exhaustive retrieval, quality filtering | + +### Developer Experience + +```python +from google.adk import Agent +from google.adk_community.tools.redis import RedisVectorSearchTool +from redisvl.index import SearchIndex +from redisvl.utils.vectorize import HFTextVectorizer + +index = SearchIndex.from_yaml("schema.yaml") +index.connect("redis://localhost:6379") +vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + +tool = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + num_results=5, + return_fields=["title", "content", "url"], +) + +agent = Agent(model="gemini-2.5-flash", tools=[tool]) +``` + +### Common Features Across All Tools + +- **Filtering**: Tag, numeric, and geo filters via `filter_expression` +- **Field selection**: Control returned fields via `return_fields` +- **Async support**: Works with both `SearchIndex` and `AsyncSearchIndex` +- **Score normalization**: Convert distances to 0-1 similarity via `normalize_vector_distance=True` +- **Full parameter exposure**: All RedisVL query 
parameters are configurable + +## Describe alternatives you've considered + +1. **Implement vector search with raw Redis commands** — Would duplicate existing, maintained code in RedisVL. Users would get a degraded experience compared to using RedisVL directly. + +2. **Require users to install RedisVL separately** — Creates friction and doesn't provide ADK-native abstractions like `BaseTool` wrappers with proper function declarations for LLMs. + +3. **Use a different vector database** — Redis is already widely deployed in many organizations. Adding vector search to existing Redis infrastructure is lower friction than adopting a new database. + + + +## Why RedisVL? + +[RedisVL](https://github.com/redis/redis-vl-python) is the official Redis vector library (~50MB footprint). It provides: +- Schema-driven index management +- Multiple query types (vector, text, hybrid, range) +- Built-in vectorizers (HuggingFace, OpenAI, Cohere, etc.) +- Both sync and async APIs + diff --git a/pyproject.toml b/pyproject.toml index 11afcd8..d4f70ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK "google-adk", # Google ADK "httpx>=0.27.0, <1.0.0", # For OpenMemory service - "redis>=5.0.0, <6.0.0", # Redis for session storage + "redis>=5.0.0", # Redis for session storage # go/keep-sorted end "orjson>=3.11.3", ] @@ -45,6 +45,11 @@ test = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", ] +redis-vl = [ + "redisvl>=0.13.2", + "nltk>=3.8.0", + "sentence-transformers>=2.2.0", +] [tool.pyink] diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..98c0e33 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -15,4 +15,5 @@ from . import memory from . import sessions from . 
import version + __version__ = version.__version__ diff --git a/src/google/adk_community/memory/__init__.py b/src/google/adk_community/memory/__init__.py index 1f3442c..fbcf66b 100644 --- a/src/google/adk_community/memory/__init__.py +++ b/src/google/adk_community/memory/__init__.py @@ -21,4 +21,3 @@ "OpenMemoryService", "OpenMemoryServiceConfig", ] - diff --git a/src/google/adk_community/memory/open_memory_service.py b/src/google/adk_community/memory/open_memory_service.py index 92c1ae6..4e949cf 100644 --- a/src/google/adk_community/memory/open_memory_service.py +++ b/src/google/adk_community/memory/open_memory_service.py @@ -19,27 +19,27 @@ from typing import Optional from typing import TYPE_CHECKING -import httpx -from google.genai import types -from pydantic import BaseModel -from pydantic import Field -from typing_extensions import override - from google.adk.memory import _utils from google.adk.memory.base_memory_service import BaseMemoryService from google.adk.memory.base_memory_service import SearchMemoryResponse from google.adk.memory.memory_entry import MemoryEntry +from google.genai import types +import httpx +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override from .utils import extract_text_from_event if TYPE_CHECKING: from google.adk.sessions.session import Session -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) + class OpenMemoryService(BaseMemoryService): """Memory service implementation using OpenMemory. - + See https://openmemory.cavira.app/ for more information. """ @@ -55,7 +55,7 @@ def __init__( base_url: Base URL of the OpenMemory instance (default: http://localhost:3000). api_key: API key for authentication. **Required** - must be provided. config: OpenMemoryServiceConfig instance. If None, uses defaults. - + Raises: ValueError: If api_key is not provided or is empty. 
""" @@ -64,7 +64,7 @@ def __init__( "api_key is required for OpenMemory. " "Provide an API key when initializing OpenMemoryService." ) - self._base_url = base_url.rstrip('/') + self._base_url = base_url.rstrip("/") self._api_key = api_key self._config = config or OpenMemoryServiceConfig() @@ -81,14 +81,12 @@ def _determine_salience(self, author: Optional[str]) -> float: else: return self._config.default_salience - def _prepare_memory_data( - self, event, content_text: str, session - ) -> dict: + def _prepare_memory_data(self, event, content_text: str, session) -> dict: """Prepare memory data structure for OpenMemory API.""" timestamp_str = None if event.timestamp: timestamp_str = _utils.format_timestamp(event.timestamp) - + # Embed author and timestamp in content for search retrieval # Format: [Author: user, Time: 2025-11-04T10:32:01] Content text enriched_content = content_text @@ -97,11 +95,11 @@ def _prepare_memory_data( metadata_parts.append(f"Author: {event.author}") if timestamp_str: metadata_parts.append(f"Time: {timestamp_str}") - + if metadata_parts: metadata_prefix = "[" + ", ".join(metadata_parts) + "] " enriched_content = metadata_prefix + content_text - + metadata = { "app_name": session.app_name, "user_id": session.user_id, @@ -110,13 +108,13 @@ def _prepare_memory_data( "invocation_id": event.invocation_id, "author": event.author, "timestamp": event.timestamp, - "source": "adk_session" + "source": "adk_session", } - + memory_data = { "content": enriched_content, "metadata": metadata, - "salience": self._determine_salience(event.author) + "salience": self._determine_salience(event.author), } if self._config.enable_metadata_tags: @@ -138,7 +136,7 @@ async def add_session_to_memory(self, session: Session): async with httpx.AsyncClient(timeout=self._config.timeout) as http_client: headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}" + "Authorization": f"Bearer {self._api_key}", } for event in session.events: @@ 
-155,16 +153,14 @@ async def add_session_to_memory(self, session: Session): "tags": memory_data.get("tags", []), "metadata": memory_data.get("metadata", {}), "salience": memory_data.get("salience", 0.5), - "user_id": session.user_id + "user_id": session.user_id, } - + response = await http_client.post( - f"{self._base_url}/memory/add", - json=payload, - headers=headers + f"{self._base_url}/memory/add", json=payload, headers=headers ) response.raise_for_status() - + memories_added += 1 logger.debug("Added memory for event %s", event.id) except httpx.HTTPStatusError as e: @@ -176,24 +172,24 @@ async def add_session_to_memory(self, session: Session): ) except httpx.RequestError as e: logger.error( - "Failed to add memory for event %s due to request error: %s", event.id, e + "Failed to add memory for event %s due to request error: %s", + event.id, + e, ) except Exception as e: - logger.error("Failed to add memory for event %s due to unexpected error: %s", event.id, e) + logger.error( + "Failed to add memory for event %s due to unexpected error: %s", + event.id, + e, + ) - logger.info( - "Added %d memories from session %s", memories_added, session.id - ) + logger.info("Added %d memories from session %s", memories_added, session.id) def _build_search_payload( self, app_name: str, user_id: str, query: str ) -> dict: """Build search payload for OpenMemory query API.""" - payload = { - "query": query, - "k": self._config.search_top_k, - "filter": {} - } + payload = {"query": query, "k": self._config.search_top_k, "filter": {}} payload["filter"]["user_id"] = user_id @@ -204,7 +200,7 @@ def _build_search_payload( def _convert_to_memory_entry(self, result: dict) -> Optional[MemoryEntry]: """Convert OpenMemory result to MemoryEntry. 
- + Extracts author and timestamp from enriched content format: [Author: user, Time: 2025-11-04T10:32:01] Content text """ @@ -213,28 +209,24 @@ def _convert_to_memory_entry(self, result: dict) -> Optional[MemoryEntry]: author = None timestamp = None clean_content = raw_content - + # Parse enriched content format to extract metadata - match = re.match(r'^\[([^\]]+)\]\s+(.*)', raw_content, re.DOTALL) + match = re.match(r"^\[([^\]]+)\]\s+(.*)", raw_content, re.DOTALL) if match: metadata_str = match.group(1) clean_content = match.group(2) - - author_match = re.search(r'Author:\s*([^,\]]+)', metadata_str) + + author_match = re.search(r"Author:\s*([^,\]]+)", metadata_str) if author_match: author = author_match.group(1).strip() - - time_match = re.search(r'Time:\s*([^,\]]+)', metadata_str) + + time_match = re.search(r"Time:\s*([^,\]]+)", metadata_str) if time_match: timestamp = time_match.group(1).strip() - + content = types.Content(parts=[types.Part(text=clean_content)]) - return MemoryEntry( - content=content, - author=author, - timestamp=timestamp - ) + return MemoryEntry(content=content, author=author, timestamp=timestamp) except (KeyError, ValueError) as e: logger.debug("Failed to convert result to MemoryEntry: %s", e) return None @@ -247,25 +239,27 @@ async def search_memory( try: search_payload = self._build_search_payload(app_name, user_id, query) memories = [] - + async with httpx.AsyncClient(timeout=self._config.timeout) as http_client: headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}" + "Authorization": f"Bearer {self._api_key}", } - + logger.debug("Query payload: %s", search_payload) - + response = await http_client.post( f"{self._base_url}/memory/query", json=search_payload, - headers=headers + headers=headers, ) response.raise_for_status() result = response.json() - - logger.debug("Query returned %d matches", len(result.get("matches", []))) - + + logger.debug( + "Query returned %d matches", 
len(result.get("matches", [])) + ) + for match in result.get("matches", []): memory_entry = self._convert_to_memory_entry(match) if memory_entry: diff --git a/src/google/adk_community/memory/utils.py b/src/google/adk_community/memory/utils.py index 0b78206..48058ce 100644 --- a/src/google/adk_community/memory/utils.py +++ b/src/google/adk_community/memory/utils.py @@ -33,9 +33,8 @@ def extract_text_from_event(event) -> str: # Filter out thought parts and only extract text # This prevents metadata like thoughtSignature from being stored text_parts = [ - part.text - for part in event.content.parts + part.text + for part in event.content.parts if part.text and not part.thought ] return ' '.join(text_parts) - diff --git a/src/google/adk_community/sessions/redis_session_service.py b/src/google/adk_community/sessions/redis_session_service.py index bd8e289..075b342 100644 --- a/src/google/adk_community/sessions/redis_session_service.py +++ b/src/google/adk_community/sessions/redis_session_service.py @@ -17,23 +17,21 @@ import bisect import logging import time +from typing import Any +from typing import Optional import uuid -from typing import Any, Optional +from google.adk.events.event import Event +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.base_session_service import ListSessionsResponse +from google.adk.sessions.session import Session +from google.adk.sessions.state import State import orjson import redis.asyncio as redis from redis.crc import key_slot from typing_extensions import override -from google.adk.events.event import Event -from google.adk.sessions.base_session_service import ( - BaseSessionService, - GetSessionConfig, - ListSessionsResponse, -) -from google.adk.sessions.session import Session -from google.adk.sessions.state import State - from .utils import _json_serializer logger = logging.getLogger("google_adk." 
+ __name__) @@ -42,256 +40,261 @@ def _session_serializer(obj: Session) -> bytes: - """Serialize ADK Session to JSON bytes.""" - return orjson.dumps(obj.model_dump(), default=_json_serializer) + """Serialize ADK Session to JSON bytes.""" + return orjson.dumps(obj.model_dump(), default=_json_serializer) class RedisKeys: - """Helper to generate Redis keys consistently.""" + """Helper to generate Redis keys consistently.""" - @staticmethod - def session(session_id: str) -> str: - return f"session:{session_id}" + @staticmethod + def session(session_id: str) -> str: + return f"session:{session_id}" - @staticmethod - def user_sessions(app_name: str, user_id: str) -> str: - return f"{State.APP_PREFIX}:{app_name}:{user_id}" + @staticmethod + def user_sessions(app_name: str, user_id: str) -> str: + return f"{State.APP_PREFIX}:{app_name}:{user_id}" - @staticmethod - def app_state(app_name: str) -> str: - return f"{State.APP_PREFIX}{app_name}" + @staticmethod + def app_state(app_name: str) -> str: + return f"{State.APP_PREFIX}{app_name}" - @staticmethod - def user_state(app_name: str, user_id: str) -> str: - return f"{State.USER_PREFIX}{app_name}:{user_id}" + @staticmethod + def user_state(app_name: str, user_id: str) -> str: + return f"{State.USER_PREFIX}{app_name}:{user_id}" class RedisSessionService(BaseSessionService): - """A Redis-backed implementation of the session service.""" - - def __init__( - self, - host="localhost", - port=6379, - db=0, - uri=None, - cluster_uri=None, - expire=DEFAULT_EXPIRATION, - **kwargs, - ): - self.expire = expire - - if cluster_uri: - self.cache = redis.RedisCluster.from_url(cluster_uri, **kwargs) - elif uri: - self.cache = redis.Redis.from_url(uri, **kwargs) - else: - self.cache = redis.Redis(host=host, port=port, db=db, **kwargs) - - async def health_check(self) -> bool: - try: - await self.cache.ping() - return True - except redis.RedisError: - return False - - @override - async def create_session( - self, - *, - app_name: str, - 
user_id: str, - state: Optional[dict[str, Any]] = None, - session_id: Optional[str] = None, - ) -> Session: - session_id = ( - session_id.strip() - if session_id and session_id.strip() - else str(uuid.uuid4()) - ) - session = Session( - app_name=app_name, - user_id=user_id, - id=session_id, - state=state or {}, - last_update_time=time.time(), - ) - - user_sessions_key = RedisKeys.user_sessions(app_name, user_id) - session_key = RedisKeys.session(session_id) - - async with self.cache.pipeline(transaction=False) as pipe: - pipe.sadd(user_sessions_key, session_id) - pipe.expire(user_sessions_key, self.expire) - pipe.set( - session_key, - _session_serializer(session), - ex=self.expire, + """A Redis-backed implementation of the session service.""" + + def __init__( + self, + host="localhost", + port=6379, + db=0, + uri=None, + cluster_uri=None, + expire=DEFAULT_EXPIRATION, + **kwargs, + ): + self.expire = expire + + if cluster_uri: + self.cache = redis.RedisCluster.from_url(cluster_uri, **kwargs) + elif uri: + self.cache = redis.Redis.from_url(uri, **kwargs) + else: + self.cache = redis.Redis(host=host, port=port, db=db, **kwargs) + + async def health_check(self) -> bool: + try: + await self.cache.ping() + return True + except redis.RedisError: + return False + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else str(uuid.uuid4()) + ) + session = Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=state or {}, + last_update_time=time.time(), + ) + + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + session_key = RedisKeys.session(session_id) + + async with self.cache.pipeline(transaction=False) as pipe: + pipe.sadd(user_sessions_key, session_id) + pipe.expire(user_sessions_key, self.expire) + pipe.set( + session_key, + 
_session_serializer(session), + ex=self.expire, + ) + await pipe.execute() + + return await self._merge_state(app_name, user_id, session) + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + session_key = RedisKeys.session(session_id) + raw_session = await self.cache.get(session_key) + if not raw_session: + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + await self.cache.srem(user_sessions_key, session_id) + return None + + try: + session_dict = orjson.loads(raw_session) + session = Session.model_validate(session_dict) + except (orjson.JSONDecodeError, Exception) as e: + logger.error(f"Error decoding session {session_id}: {e}") + return None + + if config: + if config.num_recent_events: + session.events = session.events[-config.num_recent_events :] + if config.after_timestamp: + timestamps = [e.timestamp for e in session.events] + start_index = bisect.bisect_left(timestamps, config.after_timestamp) + session.events = session.events[start_index:] + + return await self._merge_state(app_name, user_id, session) + + @override + async def list_sessions( + self, *, app_name: str, user_id: str + ) -> ListSessionsResponse: + sessions = await self._load_sessions(app_name, user_id) + sessions_without_events = [] + + for session_data in sessions.values(): + session = Session.model_validate(session_data) + session.events = [] + session.state = {} + sessions_without_events.append(session) + + return ListSessionsResponse(sessions=sessions_without_events) + + @override + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + session_key = RedisKeys.session(session_id) + + async with self.cache.pipeline(transaction=False) as pipe: + pipe.srem(user_sessions_key, session_id) + pipe.delete(session_key) + await pipe.execute() + + @override + 
async def append_event(self, session: Session, event: Event) -> Event: + await super().append_event(session=session, event=event) + session.last_update_time = event.timestamp + + async with self.cache.pipeline(transaction=False) as pipe: + user_sessions_key = RedisKeys.user_sessions( + session.app_name, session.user_id + ) + pipe.expire(user_sessions_key, self.expire) + + if event.actions and event.actions.state_delta: + for key, value in event.actions.state_delta.items(): + if key.startswith(State.APP_PREFIX): + pipe.hset( + RedisKeys.app_state(session.app_name), + key.removeprefix(State.APP_PREFIX), + orjson.dumps(value), ) - await pipe.execute() - - return await self._merge_state(app_name, user_id, session) - - @override - async def get_session( - self, - *, - app_name: str, - user_id: str, - session_id: str, - config: Optional[GetSessionConfig] = None, - ) -> Optional[Session]: - session_key = RedisKeys.session(session_id) - raw_session = await self.cache.get(session_key) - if not raw_session: - user_sessions_key = RedisKeys.user_sessions(app_name, user_id) - await self.cache.srem(user_sessions_key, session_id) - return None - - try: - session_dict = orjson.loads(raw_session) - session = Session.model_validate(session_dict) - except (orjson.JSONDecodeError, Exception) as e: - logger.error(f"Error decoding session {session_id}: {e}") - return None - - if config: - if config.num_recent_events: - session.events = session.events[-config.num_recent_events :] - if config.after_timestamp: - timestamps = [e.timestamp for e in session.events] - start_index = bisect.bisect_left(timestamps, config.after_timestamp) - session.events = session.events[start_index:] - - return await self._merge_state(app_name, user_id, session) - - @override - async def list_sessions( - self, *, app_name: str, user_id: str - ) -> ListSessionsResponse: - sessions = await self._load_sessions(app_name, user_id) - sessions_without_events = [] - - for session_data in sessions.values(): - session = 
Session.model_validate(session_data) - session.events = [] - session.state = {} - sessions_without_events.append(session) - - return ListSessionsResponse(sessions=sessions_without_events) - - @override - async def delete_session( - self, *, app_name: str, user_id: str, session_id: str - ) -> None: - user_sessions_key = RedisKeys.user_sessions(app_name, user_id) - session_key = RedisKeys.session(session_id) - - async with self.cache.pipeline(transaction=False) as pipe: - pipe.srem(user_sessions_key, session_id) - pipe.delete(session_key) - await pipe.execute() - - @override - async def append_event(self, session: Session, event: Event) -> Event: - await super().append_event(session=session, event=event) - session.last_update_time = event.timestamp - - async with self.cache.pipeline(transaction=False) as pipe: - user_sessions_key = RedisKeys.user_sessions( - session.app_name, session.user_id - ) - pipe.expire(user_sessions_key, self.expire) - - if event.actions and event.actions.state_delta: - for key, value in event.actions.state_delta.items(): - if key.startswith(State.APP_PREFIX): - pipe.hset( - RedisKeys.app_state(session.app_name), - key.removeprefix(State.APP_PREFIX), - orjson.dumps(value), - ) - if key.startswith(State.USER_PREFIX): - pipe.hset( - RedisKeys.user_state(session.app_name, session.user_id), - key.removeprefix(State.USER_PREFIX), - orjson.dumps(value), - ) - - pipe.set( - RedisKeys.session(session.id), - _session_serializer(session), - ex=self.expire, - ) - await pipe.execute() - - return event - - async def _merge_state( - self, app_name: str, user_id: str, session: Session - ) -> Session: - app_state = await self.cache.hgetall(RedisKeys.app_state(app_name)) - for k, v in app_state.items(): - session.state[State.APP_PREFIX + k.decode()] = orjson.loads(v) - - user_state = await self.cache.hgetall(RedisKeys.user_state(app_name, user_id)) - for k, v in user_state.items(): - session.state[State.USER_PREFIX + k.decode()] = orjson.loads(v) - - return 
session - - async def _load_sessions(self, app_name: str, user_id: str) -> dict[str, dict]: - key = RedisKeys.user_sessions(app_name, user_id) - try: - session_ids_bytes = await self.cache.smembers(key) - if not session_ids_bytes: - return {} - - session_ids = [s.decode() for s in session_ids_bytes] - session_keys = [RedisKeys.session(sid) for sid in session_ids] - - # Group by slot for Redis Cluster - slot_groups: dict[int, list[str]] = {} - for k in session_keys: - slot = key_slot(k.encode()) - slot_groups.setdefault(slot, []).append(k) - - async def fetch_group(keys: list[str]): - async with self.cache.pipeline(transaction=False) as pipe: - for k in keys: - pipe.get(k) - return await pipe.execute() - - results_per_group = await asyncio.gather( - *(fetch_group(keys) for keys in slot_groups.values()) + if key.startswith(State.USER_PREFIX): + pipe.hset( + RedisKeys.user_state(session.app_name, session.user_id), + key.removeprefix(State.USER_PREFIX), + orjson.dumps(value), ) - raw_sessions = [] - for group_keys, group_results in zip( - slot_groups.values(), results_per_group - ): - raw_sessions.extend(zip(group_keys, group_results)) - - sessions = {} - sessions_to_cleanup = [] - for key_name, raw_session in raw_sessions: - session_id = key_name.split(":", 1)[1] - if raw_session: - try: - sessions[session_id] = orjson.loads(raw_session) - except orjson.JSONDecodeError as e: - logger.error(f"Error decoding session {session_id}: {e}") - else: - logger.warning( - "Session ID %s found in user set but session data is missing. 
Cleaning up.", - session_id, - ) - sessions_to_cleanup.append(session_id) - - if sessions_to_cleanup: - await self.cache.srem(key, *sessions_to_cleanup) - - return sessions - except redis.RedisError as e: - logger.error(f"Error loading sessions for {user_id}: {e}") - return {} + pipe.set( + RedisKeys.session(session.id), + _session_serializer(session), + ex=self.expire, + ) + await pipe.execute() + + return event + + async def _merge_state( + self, app_name: str, user_id: str, session: Session + ) -> Session: + app_state = await self.cache.hgetall(RedisKeys.app_state(app_name)) + for k, v in app_state.items(): + session.state[State.APP_PREFIX + k.decode()] = orjson.loads(v) + + user_state = await self.cache.hgetall( + RedisKeys.user_state(app_name, user_id) + ) + for k, v in user_state.items(): + session.state[State.USER_PREFIX + k.decode()] = orjson.loads(v) + + return session + + async def _load_sessions( + self, app_name: str, user_id: str + ) -> dict[str, dict]: + key = RedisKeys.user_sessions(app_name, user_id) + try: + session_ids_bytes = await self.cache.smembers(key) + if not session_ids_bytes: + return {} + + session_ids = [s.decode() for s in session_ids_bytes] + session_keys = [RedisKeys.session(sid) for sid in session_ids] + + # Group by slot for Redis Cluster + slot_groups: dict[int, list[str]] = {} + for k in session_keys: + slot = key_slot(k.encode()) + slot_groups.setdefault(slot, []).append(k) + + async def fetch_group(keys: list[str]): + async with self.cache.pipeline(transaction=False) as pipe: + for k in keys: + pipe.get(k) + return await pipe.execute() + + results_per_group = await asyncio.gather( + *(fetch_group(keys) for keys in slot_groups.values()) + ) + + raw_sessions = [] + for group_keys, group_results in zip( + slot_groups.values(), results_per_group + ): + raw_sessions.extend(zip(group_keys, group_results)) + + sessions = {} + sessions_to_cleanup = [] + for key_name, raw_session in raw_sessions: + session_id = key_name.split(":", 1)[1] 
+ if raw_session: + try: + sessions[session_id] = orjson.loads(raw_session) + except orjson.JSONDecodeError as e: + logger.error(f"Error decoding session {session_id}: {e}") + else: + logger.warning( + "Session ID %s found in user set but session data is missing." + " Cleaning up.", + session_id, + ) + sessions_to_cleanup.append(session_id) + + if sessions_to_cleanup: + await self.cache.srem(key, *sessions_to_cleanup) + + return sessions + except redis.RedisError as e: + logger.error(f"Error loading sessions for {user_id}: {e}") + return {} diff --git a/src/google/adk_community/sessions/utils.py b/src/google/adk_community/sessions/utils.py index bc53d2b..132c773 100644 --- a/src/google/adk_community/sessions/utils.py +++ b/src/google/adk_community/sessions/utils.py @@ -20,18 +20,18 @@ def _json_serializer(obj): - """Fallback serializer to handle non-JSON-compatible types.""" - if isinstance(obj, set): - return list(obj) - if isinstance(obj, bytes): - try: - return base64.b64encode(obj).decode("ascii") - except Exception: - return repr(obj) - if isinstance(obj, (datetime.datetime, datetime.date)): - return obj.isoformat() - if isinstance(obj, uuid.UUID): - return str(obj) - if isinstance(obj, Decimal): - return float(obj) + """Fallback serializer to handle non-JSON-compatible types.""" + if isinstance(obj, set): + return list(obj) + if isinstance(obj, bytes): + try: + return base64.b64encode(obj).decode("ascii") + except Exception: + return repr(obj) + if isinstance(obj, (datetime.datetime, datetime.date)): + return obj.isoformat() + if isinstance(obj, uuid.UUID): return str(obj) + if isinstance(obj, Decimal): + return float(obj) + return str(obj) diff --git a/src/google/adk_community/tools/__init__.py b/src/google/adk_community/tools/__init__.py new file mode 100644 index 0000000..85ce859 --- /dev/null +++ b/src/google/adk_community/tools/__init__.py @@ -0,0 +1,95 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); 
+# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools module for ADK Community integrations. + +This module uses lazy loading to provide helpful error messages when +optional dependencies are not installed. +""" + +__all__ = [ + "BaseRedisSearchTool", + "VectorizedSearchTool", + "RedisVectorSearchTool", + "RedisHybridSearchTool", + "RedisRangeSearchTool", + "RedisTextSearchTool", + "RedisVectorQueryConfig", + "RedisHybridQueryConfig", + "RedisAggregatedHybridQueryConfig", + "RedisRangeQueryConfig", + "RedisTextQueryConfig", +] + +# Redis tool names for lazy loading +_REDIS_TOOLS = { + "BaseRedisSearchTool", + "VectorizedSearchTool", + "RedisVectorSearchTool", + "RedisHybridSearchTool", + "RedisRangeSearchTool", + "RedisTextSearchTool", +} + +# Redis config names for lazy loading +_REDIS_CONFIGS = { + "RedisVectorQueryConfig", + "RedisHybridQueryConfig", + "RedisAggregatedHybridQueryConfig", + "RedisRangeQueryConfig", + "RedisTextQueryConfig", +} + + +def __getattr__(name: str): + """Lazy load tools to provide helpful error messages.""" + if name in _REDIS_TOOLS: + try: + from .redis import BaseRedisSearchTool + from .redis import RedisHybridSearchTool + from .redis import RedisRangeSearchTool + from .redis import RedisTextSearchTool + from .redis import RedisVectorSearchTool + from .redis import VectorizedSearchTool + + globals().update({ + "BaseRedisSearchTool": BaseRedisSearchTool, + "VectorizedSearchTool": VectorizedSearchTool, + "RedisVectorSearchTool": RedisVectorSearchTool, + "RedisHybridSearchTool": RedisHybridSearchTool, + 
"RedisRangeSearchTool": RedisRangeSearchTool, + "RedisTextSearchTool": RedisTextSearchTool, + }) + return globals()[name] + except ImportError as e: + raise ImportError( + f"{name} requires redisvl. " + "Install with: pip install google-adk-community[redis-vl]" + ) from e + if name in _REDIS_CONFIGS: + from .redis import RedisAggregatedHybridQueryConfig + from .redis import RedisHybridQueryConfig + from .redis import RedisRangeQueryConfig + from .redis import RedisTextQueryConfig + from .redis import RedisVectorQueryConfig + + globals().update({ + "RedisVectorQueryConfig": RedisVectorQueryConfig, + "RedisHybridQueryConfig": RedisHybridQueryConfig, + "RedisAggregatedHybridQueryConfig": RedisAggregatedHybridQueryConfig, + "RedisRangeQueryConfig": RedisRangeQueryConfig, + "RedisTextQueryConfig": RedisTextQueryConfig, + }) + return globals()[name] + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/src/google/adk_community/tools/redis/__init__.py b/src/google/adk_community/tools/redis/__init__.py new file mode 100644 index 0000000..9071f14 --- /dev/null +++ b/src/google/adk_community/tools/redis/__init__.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redis tools for ADK Community using RedisVL. 
+ +This module provides tools for Redis-based search operations: + +- `RedisVectorSearchTool`: KNN vector similarity search +- `RedisHybridSearchTool`: Combined vector + BM25 text search +- `RedisRangeSearchTool`: Distance threshold-based vector search +- `RedisTextSearchTool`: Full-text BM25 keyword search + +Configuration classes for query parameters: + +- `RedisVectorQueryConfig`: Configuration for vector search queries +- `RedisHybridQueryConfig`: Configuration for native hybrid search (RedisVL >= 0.13.0) +- `RedisAggregatedHybridQueryConfig`: Configuration for client-side hybrid (older versions) +- `RedisRangeQueryConfig`: Configuration for range search queries +- `RedisTextQueryConfig`: Configuration for text search queries + +Example: + ```python + from redisvl.index import SearchIndex + from redisvl.utils.vectorize import HFTextVectorizer + from google.adk_community.tools.redis import ( + RedisVectorSearchTool, + RedisVectorQueryConfig, + ) + + index = SearchIndex.from_yaml("schema.yaml") + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + + config = RedisVectorQueryConfig(num_results=5, ef_runtime=100) + tool = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + ) + ``` +""" + +try: + from .base_search_tool import BaseRedisSearchTool + from .base_search_tool import VectorizedSearchTool + from .config import RedisAggregatedHybridQueryConfig + from .config import RedisHybridQueryConfig + from .config import RedisRangeQueryConfig + from .config import RedisTextQueryConfig + from .config import RedisVectorQueryConfig + from .hybrid_search_tool import RedisHybridSearchTool + from .range_search_tool import RedisRangeSearchTool + from .text_search_tool import RedisTextSearchTool + from .vector_search_tool import RedisVectorSearchTool +except ImportError as e: + raise ImportError( + "Redis tools require redisvl. 
" + "Install with: pip install google-adk-community[redis-vl]" + ) from e + +__all__ = [ + "BaseRedisSearchTool", + "VectorizedSearchTool", + "RedisVectorSearchTool", + "RedisHybridSearchTool", + "RedisRangeSearchTool", + "RedisTextSearchTool", + "RedisVectorQueryConfig", + "RedisHybridQueryConfig", + "RedisAggregatedHybridQueryConfig", + "RedisRangeQueryConfig", + "RedisTextQueryConfig", +] diff --git a/src/google/adk_community/tools/redis/base_search_tool.py b/src/google/adk_community/tools/redis/base_search_tool.py new file mode 100644 index 0000000..bdb2170 --- /dev/null +++ b/src/google/adk_community/tools/redis/base_search_tool.py @@ -0,0 +1,223 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base classes for Redis search tools using RedisVL.""" + +from __future__ import annotations + +from abc import abstractmethod +import asyncio +from typing import Any +from typing import Callable +from typing import Coroutine +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from google.adk.tools import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.utils.vectorize import BaseVectorizer + + +class BaseRedisSearchTool(BaseTool): + """Base class for ALL Redis search tools using RedisVL. 
+ + This class provides common functionality shared by all Redis search tools: + - Index management (sync and async) + - Query execution with proper async handling + - Standard response formatting + - Error handling + + Subclasses should use `_run_search` to execute queries with consistent + error handling and response formatting. + """ + + def __init__( + self, + *, + name: str, + description: str, + index: Union[SearchIndex, AsyncSearchIndex], + return_fields: Optional[List[str]] = None, + ): + """Initialize the base Redis search tool. + + Args: + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + return_fields: Optional list of fields to return in results. + """ + super().__init__(name=name, description=description) + self._index = index + self._return_fields = return_fields + self._is_async_index = isinstance(index, AsyncSearchIndex) + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM. + + Returns a simple interface with just a query parameter. + Subclasses can override to add additional parameters. + """ + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + }, + required=["query"], + ), + ) + + async def _execute_query(self, query: Any) -> List[Dict[str, Any]]: + """Execute a RedisVL query and return formatted results. + + Args: + query: A RedisVL query object (VectorQuery, TextQuery, etc.) + + Returns: + List of result dictionaries. 
+ """ + if self._is_async_index: + results = await self._index.query(query) + else: + # Run sync query in thread pool to avoid blocking + results = await asyncio.to_thread(self._index.query, query) + + return [dict(r) for r in results] if results else [] + + async def _run_search( + self, + args: Dict[str, Any], + build_query_fn: Callable[[str, Dict[str, Any]], Coroutine[Any, Any, Any]], + ) -> Dict[str, Any]: + """Execute a search with consistent validation, execution, and formatting. + + This is a template method that handles: + - Query text validation + - Query building (via the provided async function) + - Query execution + - Response formatting + - Error handling + + Args: + args: Arguments from the LLM, must include 'query'. + build_query_fn: Async function that takes (query_text, args) and + returns a RedisVL query object. + + Returns: + A dictionary with status, count, and results (or error). + """ + query_text = args.get("query", "") + + if not query_text: + return {"status": "error", "error": "Query text is required."} + + try: + # Build the query using the provided function + redisvl_query = await build_query_fn(query_text, args) + + # Execute and format results + results = await self._execute_query(redisvl_query) + + return { + "status": "success", + "count": len(results), + "results": results, + } + + except Exception as e: + return {"status": "error", "error": str(e)} + + +class VectorizedSearchTool(BaseRedisSearchTool): + """Base class for Redis search tools that require vector embeddings. + + This class extends BaseRedisSearchTool with: + - Required vectorizer for embedding queries + - Abstract _build_query method for subclasses to implement + + Use this as the base class for vector-based search tools like + VectorSearchTool, HybridSearchTool, and RangeSearchTool. 
+ """ + + def __init__( + self, + *, + name: str, + description: str, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: BaseVectorizer, + return_fields: Optional[List[str]] = None, + ): + """Initialize the vectorized search tool. + + Args: + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + vectorizer: The vectorizer for embedding queries (required). + return_fields: Optional list of fields to return in results. + """ + super().__init__( + name=name, + description=description, + index=index, + return_fields=return_fields, + ) + self._vectorizer = vectorizer + + @abstractmethod + def _build_query( + self, query_text: str, embedding: List[float], **kwargs: Any + ) -> Any: + """Build the RedisVL query object. + + Args: + query_text: The original query text from the user. + embedding: The vector embedding of the query text. + **kwargs: Additional parameters from the LLM call. + + Returns: + A RedisVL query object (VectorQuery, HybridQuery, etc.) + """ + pass + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> Dict[str, Any]: + """Execute the vector-based search query. + + Args: + args: Arguments from the LLM, must include 'query'. + tool_context: The tool execution context. + + Returns: + A dictionary with status, count, and results. 
+ """ + + async def build_query_fn(query_text: str, args: Dict[str, Any]) -> Any: + embedding = await self._vectorizer.aembed(query_text) + return self._build_query(query_text, embedding, **args) + + return await self._run_search(args, build_query_fn) diff --git a/src/google/adk_community/tools/redis/config.py b/src/google/adk_community/tools/redis/config.py new file mode 100644 index 0000000..465b9da --- /dev/null +++ b/src/google/adk_community/tools/redis/config.py @@ -0,0 +1,470 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration classes for Redis search tools. + +This module provides Pydantic configuration classes for Redis search tools, +following the pattern established by BigQueryToolConfig in the upstream +google/adk-python repository and OpenMemoryServiceConfig in this repository. +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field + +# Type alias for sort specification +SortSpec = Optional[ + Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] +] + + +class RedisVectorQueryConfig(BaseModel): + """Configuration for Redis vector similarity search queries. 
+ + This config groups all query-specific parameters for VectorQuery, + separating them from tool-level concerns like index and vectorizer. + + Attributes: + vector_field_name: Name of the vector field in the index. + num_results: Number of results to return (default: 10). + dtype: Data type of the vector (default: "float32"). + return_score: Whether to return vector distance scores. + dialect: RediSearch query dialect version. + sort_by: Field(s) to order results by. + in_order: Require query terms in same order as document. + normalize_vector_distance: Convert distance to 0-1 similarity score. + hybrid_policy: Filter application policy - "BATCHES" or "ADHOC_BF". + batch_size: Batch size when hybrid_policy is "BATCHES". + ef_runtime: HNSW exploration factor at query time. + epsilon: Range search approximation factor for HNSW/SVS-VAMANA. + search_window_size: SVS-VAMANA search window size. + use_search_history: SVS-VAMANA history mode - "OFF", "ON", or "AUTO". + search_buffer_capacity: SVS-VAMANA 2-level compression tuning. 
+ """ + + model_config = ConfigDict(extra="forbid") + + # Core query parameters + vector_field_name: str = Field(default="embedding") + num_results: int = Field(default=10, ge=1) + dtype: str = Field(default="float32") + return_score: bool = Field(default=True) + dialect: int = Field(default=2, ge=1) + sort_by: SortSpec = Field(default=None) + in_order: bool = Field(default=False) + normalize_vector_distance: bool = Field(default=False) + + # Version-dependent parameters (RedisVL 0.13.2+) + # These are excluded from query kwargs when None for backward compatibility + hybrid_policy: Optional[str] = Field(default=None) + batch_size: Optional[int] = Field(default=None, ge=1) + ef_runtime: Optional[int] = Field(default=None, ge=1) + epsilon: Optional[float] = Field(default=None, ge=0.0) + search_window_size: Optional[int] = Field(default=None, ge=1) + use_search_history: Optional[str] = Field(default=None) + search_buffer_capacity: Optional[int] = Field(default=None, ge=1) + + def to_query_kwargs( + self, vector: List[float], filter_expression: Optional[Any] = None + ) -> dict[str, Any]: + """Convert config to VectorQuery kwargs, excluding None version-dependent params. + + Args: + vector: The query vector embedding. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for VectorQuery constructor. 
+ """ + # Core parameters always included + kwargs: dict[str, Any] = { + "vector": vector, + "vector_field_name": self.vector_field_name, + "num_results": self.num_results, + "dtype": self.dtype, + "return_score": self.return_score, + "dialect": self.dialect, + "sort_by": self.sort_by, + "in_order": self.in_order, + "normalize_vector_distance": self.normalize_vector_distance, + "filter_expression": filter_expression, + } + + # Version-dependent parameters: only include if not None + version_dependent = { + "hybrid_policy": self.hybrid_policy, + "batch_size": self.batch_size, + "ef_runtime": self.ef_runtime, + "epsilon": self.epsilon, + "search_window_size": self.search_window_size, + "use_search_history": self.use_search_history, + "search_buffer_capacity": self.search_buffer_capacity, + } + for key, value in version_dependent.items(): + if value is not None: + kwargs[key] = value + + return kwargs + + +class RedisRangeQueryConfig(BaseModel): + """Configuration for Redis vector range search queries. + + Range search finds all documents within a specified distance threshold + from the query vector, rather than returning a fixed number of results. + + Attributes: + vector_field_name: Name of the vector field in the index. + distance_threshold: Maximum distance for results (default: 0.2). + num_results: Maximum number of results to return. + dtype: Data type of the vector (default: "float32"). + return_score: Whether to return vector distance scores. + dialect: RediSearch query dialect version. + sort_by: Field(s) to order results by. + in_order: Require query terms in same order as document. + normalize_vector_distance: Convert distance to 0-1 similarity score. + epsilon: Range search approximation factor for HNSW/SVS-VAMANA. 
+ """ + + model_config = ConfigDict(extra="forbid") + + vector_field_name: str = Field(default="embedding") + distance_threshold: float = Field(default=0.2, ge=0.0) + num_results: int = Field(default=10, ge=1) + dtype: str = Field(default="float32") + return_score: bool = Field(default=True) + dialect: int = Field(default=2, ge=1) + sort_by: SortSpec = Field(default=None) + in_order: bool = Field(default=False) + normalize_vector_distance: bool = Field(default=False) + + # Version-dependent parameter + epsilon: Optional[float] = Field(default=None, ge=0.0) + + def to_query_kwargs( + self, vector: List[float], filter_expression: Optional[Any] = None + ) -> dict[str, Any]: + """Convert config to VectorRangeQuery kwargs. + + Args: + vector: The query vector embedding. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for VectorRangeQuery constructor. + """ + kwargs: dict[str, Any] = { + "vector": vector, + "vector_field_name": self.vector_field_name, + "distance_threshold": self.distance_threshold, + "num_results": self.num_results, + "dtype": self.dtype, + "return_score": self.return_score, + "dialect": self.dialect, + "sort_by": self.sort_by, + "in_order": self.in_order, + "normalize_vector_distance": self.normalize_vector_distance, + "filter_expression": filter_expression, + } + + # Version-dependent: only include if not None + if self.epsilon is not None: + kwargs["epsilon"] = self.epsilon + + return kwargs + + +class RedisTextQueryConfig(BaseModel): + """Configuration for Redis full-text search queries. + + This config groups all query-specific parameters for TextQuery, + using BM25 scoring for keyword-based search. + + Attributes: + text_field_name: Name of the text field to search. + text_scorer: Text scoring algorithm (default: "BM25STD"). + num_results: Number of results to return (default: 10). + return_score: Whether to return the text score. + dialect: RediSearch query dialect version. 
+ sort_by: Field(s) to order results by. + in_order: Require query terms in same order as document. + stopwords: Stopwords to remove from query (default: "english"). + """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + text_field_name: str = Field(default="content") + text_scorer: str = Field(default="BM25STD") + num_results: int = Field(default=10, ge=1) + return_score: bool = Field(default=True) + dialect: int = Field(default=2, ge=1) + sort_by: SortSpec = Field(default=None) + in_order: bool = Field(default=False) + stopwords: Optional[Union[str, Set[str]]] = Field(default="english") + + def to_query_kwargs( + self, + text: str, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + ) -> dict[str, Any]: + """Convert config to TextQuery kwargs. + + Args: + text: The query text for BM25 matching. + return_fields: Optional list of fields to return. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for TextQuery constructor. + """ + return { + "text": text, + "text_field_name": self.text_field_name, + "text_scorer": self.text_scorer, + "num_results": self.num_results, + "return_score": self.return_score, + "dialect": self.dialect, + "sort_by": self.sort_by, + "in_order": self.in_order, + "stopwords": self.stopwords, + "return_fields": return_fields, + "filter_expression": filter_expression, + } + + +class RedisHybridQueryConfig(BaseModel): + """Configuration for native Redis hybrid search queries. + + Uses Redis's native FT.HYBRID command for server-side hybrid search + combining semantic vector similarity with keyword-based BM25 text matching. + + Requirements: + - RedisVL >= 0.13.0 + - Redis >= 8.4.0 + - redis-py >= 7.1.0 + + For older Redis/RedisVL versions, use RedisAggregatedHybridQueryConfig instead. 
+ + Example: + ```python + from google.adk_community.tools.redis import ( + RedisHybridSearchTool, + RedisHybridQueryConfig, + ) + + config = RedisHybridQueryConfig( + text_field_name="content", + combination_method="LINEAR", + linear_alpha=0.7, # 70% text, 30% vector + ) + tool = RedisHybridSearchTool(index=index, vectorizer=vectorizer, config=config) + ``` + + Attributes: + text_field_name: Name of the text field for BM25 search. + vector_field_name: Name of the vector field for similarity search. + vector_param_name: Name of the parameter substitution for vector blob. + text_scorer: Text scoring algorithm (default: "BM25STD"). + yield_text_score_as: Field name to yield the text score as. + vector_search_method: Vector search method - "KNN" or "RANGE". + knn_ef_runtime: Exploration factor for HNSW when using KNN. + range_radius: Search radius when using RANGE vector search. + range_epsilon: Epsilon for RANGE search accuracy. + yield_vsim_score_as: Field name to yield vector similarity score as. + combination_method: Score combination method - "RRF" or "LINEAR". + linear_alpha: Weight of text score when using LINEAR (0.0-1.0). + rrf_window: Window size for RRF combination. + rrf_constant: Constant for RRF combination. + yield_combined_score_as: Field name to yield combined score as. + num_results: Number of results to return. + dtype: Data type of the vector. + stopwords: Stopwords to remove from query. + text_weights: Optional field weights for text scoring. 
+ """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + # Text search parameters + text_field_name: str = Field(default="content") + text_scorer: str = Field(default="BM25STD") + yield_text_score_as: Optional[str] = Field(default=None) + stopwords: Optional[Union[str, Set[str]]] = Field(default="english") + text_weights: Optional[Dict[str, float]] = Field(default=None) + + # Vector search parameters + vector_field_name: str = Field(default="embedding") + vector_param_name: str = Field(default="vector") + vector_search_method: Optional[str] = Field(default=None) + knn_ef_runtime: int = Field(default=10, ge=1) + range_radius: Optional[float] = Field(default=None) + range_epsilon: float = Field(default=0.01, ge=0.0) + yield_vsim_score_as: Optional[str] = Field(default=None) + dtype: str = Field(default="float32") + + # Score combination parameters + combination_method: Optional[str] = Field(default=None) + linear_alpha: float = Field(default=0.3, ge=0.0, le=1.0) + rrf_window: int = Field(default=20, ge=1) + rrf_constant: int = Field(default=60, ge=1) + yield_combined_score_as: Optional[str] = Field(default=None) + + # Result parameters + num_results: int = Field(default=10, ge=1) + + def to_query_kwargs( + self, + text: str, + vector: List[float], + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + ) -> dict[str, Any]: + """Convert config to native HybridQuery kwargs. + + Args: + text: The query text for BM25 matching. + vector: The query vector embedding. + return_fields: Optional list of fields to return. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for HybridQuery constructor. 
+ """ + return { + "text": text, + "text_field_name": self.text_field_name, + "vector": vector, + "vector_field_name": self.vector_field_name, + "vector_param_name": self.vector_param_name, + "text_scorer": self.text_scorer, + "yield_text_score_as": self.yield_text_score_as, + "vector_search_method": self.vector_search_method, + "knn_ef_runtime": self.knn_ef_runtime, + "range_radius": self.range_radius, + "range_epsilon": self.range_epsilon, + "yield_vsim_score_as": self.yield_vsim_score_as, + "filter_expression": filter_expression, + "combination_method": self.combination_method, + "rrf_window": self.rrf_window, + "rrf_constant": self.rrf_constant, + "linear_alpha": self.linear_alpha, + "yield_combined_score_as": self.yield_combined_score_as, + "dtype": self.dtype, + "num_results": self.num_results, + "return_fields": return_fields, + "stopwords": self.stopwords, + "text_weights": self.text_weights, + } + + +class RedisAggregatedHybridQueryConfig(BaseModel): + """Configuration for aggregated (client-side) Redis hybrid search queries. + + .. deprecated:: + This config is for older Redis/RedisVL versions. For newer setups + (RedisVL >= 0.13.0, Redis >= 8.4.0), prefer RedisHybridQueryConfig + which uses native server-side hybrid search for better performance. + + This config uses AggregateHybridQuery which performs client-side hybrid + search using FT.AGGREGATE with weighted score combination. It works with + any Redis version that has RediSearch installed. 
+ + Recommended for: + - Redis < 8.4.0 + - RedisVL < 0.13.0 + - Environments where native FT.HYBRID is not available + + Example: + ```python + from google.adk_community.tools.redis import ( + RedisHybridSearchTool, + RedisAggregatedHybridQueryConfig, + ) + + config = RedisAggregatedHybridQueryConfig( + text_field_name="content", + alpha=0.7, # 70% text, 30% vector + ) + tool = RedisHybridSearchTool(index=index, vectorizer=vectorizer, config=config) + ``` + + Attributes: + text_field_name: Name of the text field for BM25 search. + vector_field_name: Name of the vector field for similarity search. + text_scorer: Text scoring algorithm (default: "BM25STD"). + alpha: Weight for text score (default: 0.7). Higher values favor + text matching over vector similarity. Combined score is: + alpha * text_score + (1 - alpha) * vector_score + num_results: Number of results to return. + dtype: Data type of the vector. + stopwords: Stopwords to remove from query. + dialect: RediSearch query dialect version. + text_weights: Optional field weights for text scoring. + """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + text_field_name: str = Field(default="content") + vector_field_name: str = Field(default="embedding") + text_scorer: str = Field(default="BM25STD") + alpha: float = Field(default=0.7, ge=0.0, le=1.0) + num_results: int = Field(default=10, ge=1) + dtype: str = Field(default="float32") + stopwords: Optional[Union[str, Set[str]]] = Field(default="english") + dialect: int = Field(default=2, ge=1) + text_weights: Optional[Dict[str, float]] = Field(default=None) + + def to_query_kwargs( + self, + text: str, + vector: List[float], + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + ) -> dict[str, Any]: + """Convert config to AggregateHybridQuery kwargs. + + Args: + text: The query text for BM25 matching. + vector: The query vector embedding. + return_fields: Optional list of fields to return. 
+ filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for AggregateHybridQuery constructor. + """ + return { + "text": text, + "text_field_name": self.text_field_name, + "vector": vector, + "vector_field_name": self.vector_field_name, + "text_scorer": self.text_scorer, + "alpha": self.alpha, + "dtype": self.dtype, + "num_results": self.num_results, + "return_fields": return_fields, + "stopwords": self.stopwords, + "dialect": self.dialect, + "text_weights": self.text_weights, + "filter_expression": filter_expression, + } diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py new file mode 100644 index 0000000..28ebbf5 --- /dev/null +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -0,0 +1,341 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Redis hybrid search tool combining vector similarity and BM25 text search.""" + +from __future__ import annotations + +import logging +from typing import Any +from typing import List +from typing import Optional +from typing import Union +import warnings + +from google.genai import types +from packaging.version import parse +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.utils.vectorize import BaseVectorizer + +from .base_search_tool import VectorizedSearchTool +from .config import RedisAggregatedHybridQueryConfig +from .config import RedisHybridQueryConfig + +# Minimum RedisVL version required for native FT.HYBRID support +_MIN_NATIVE_HYBRID_VERSION = "0.13.0" +# Minimum Redis server version required for native FT.HYBRID support +_MIN_REDIS_SERVER_VERSION = "8.4.0" + +logger = logging.getLogger(__name__) + + +def _get_redisvl_version() -> str: + """Get the installed RedisVL version string. + + Returns: + Version string (e.g., "0.13.0") or "0.0.0" if not available. + """ + try: + import redisvl + + version = getattr(redisvl, "__version__", None) + if version is None: + logger.debug("redisvl.__version__ not found, assuming 0.0.0") + return "0.0.0" + return version + except ImportError: + # This shouldn't normally happen due to module-level guard in __init__.py + logger.debug("redisvl not importable, assuming version 0.0.0") + return "0.0.0" + + +def _get_redis_server_version( + index: Union[SearchIndex, AsyncSearchIndex], +) -> str: + """Get the Redis server version from the index's client. + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex. + + Returns: + Version string (e.g., "8.4.0") or "0.0.0" if not available. 
+ """ + try: + # For sync index, use _redis_client to trigger lazy connection if needed + if isinstance(index, SearchIndex): + client = index._redis_client + else: + # For async index, we can only use the public property + # which may be None if not yet connected + client = index.client + + if client is None: + logger.warning( + "Redis client not available. For AsyncSearchIndex, ensure the " + "index has been used (e.g., await index.create()) before " + "creating RedisHybridSearchTool, or native hybrid support " + "detection will be skipped." + ) + return "0.0.0" + + info = client.info("server") + return info.get("redis_version", "0.0.0") + except Exception as e: + logger.warning( + "Could not determine Redis server version: %s. " + "Native hybrid search may not be available.", + e, + ) + return "0.0.0" + + +def _supports_native_hybrid( + index: Union[SearchIndex, AsyncSearchIndex], +) -> bool: + """Check if native HybridQuery is supported. + + Native hybrid search requires both: + - RedisVL >= 0.13.0 + - Redis server >= 8.4.0 + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex to check. + + Returns: + True if both version requirements are met, False otherwise. + """ + # Check redisvl version + redisvl_version = _get_redisvl_version() + try: + if parse(redisvl_version) < parse(_MIN_NATIVE_HYBRID_VERSION): + logger.debug( + "Native hybrid not supported: RedisVL %s < %s", + redisvl_version, + _MIN_NATIVE_HYBRID_VERSION, + ) + return False + except Exception as e: + logger.warning( + "Could not parse redisvl version '%s': %s. 
" + "Native hybrid search will be disabled.", + redisvl_version, + e, + ) + return False + + # Check Redis server version + redis_version = _get_redis_server_version(index) + try: + if parse(redis_version) < parse(_MIN_REDIS_SERVER_VERSION): + logger.debug( + "Native hybrid not supported: Redis server %s < %s", + redis_version, + _MIN_REDIS_SERVER_VERSION, + ) + return False + except Exception as e: + logger.warning( + "Could not parse Redis server version '%s': %s. " + "Native hybrid search will be disabled.", + redis_version, + e, + ) + return False + + return True + + +class RedisHybridSearchTool(VectorizedSearchTool): + """Hybrid search tool combining vector similarity and BM25 text search. + + This tool performs a hybrid search that combines semantic vector similarity + with keyword-based BM25 text matching. It automatically detects the installed + RedisVL version and uses the appropriate implementation: + + - **RedisVL >= 0.13.0**: Uses native FT.HYBRID command (server-side fusion) + with RedisHybridQueryConfig. Requires Redis >= 8.4.0. + - **RedisVL < 0.13.0**: Uses AggregateHybridQuery (client-side fusion) + with RedisAggregatedHybridQueryConfig. Works with any Redis version. 
+ + Example (native mode - RedisVL >= 0.13.0): + ```python + from google.adk_community.tools.redis import ( + RedisHybridSearchTool, + RedisHybridQueryConfig, + ) + + config = RedisHybridQueryConfig( + text_field_name="content", + combination_method="LINEAR", + linear_alpha=0.7, # 70% text, 30% vector + ) + tool = RedisHybridSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + ) + ``` + + Example (aggregate mode - older versions): + ```python + from google.adk_community.tools.redis import ( + RedisHybridSearchTool, + RedisAggregatedHybridQueryConfig, + ) + + config = RedisAggregatedHybridQueryConfig( + text_field_name="content", + alpha=0.7, # 70% text, 30% vector + ) + tool = RedisHybridSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + ) + ``` + """ + + def __init__( + self, + *, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: BaseVectorizer, + config: Optional[ + Union[RedisHybridQueryConfig, RedisAggregatedHybridQueryConfig] + ] = None, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + name: str = "redis_hybrid_search", + description: str = "Search using both semantic similarity and keyword matching.", + ): + """Initialize the hybrid search tool. + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + vectorizer: The vectorizer for embedding queries. + config: Configuration for query parameters. Can be either: + - RedisHybridQueryConfig: For native FT.HYBRID (RedisVL >= 0.13.0) + - RedisAggregatedHybridQueryConfig: For client-side hybrid (older) + If None, auto-detects based on installed RedisVL version. + return_fields: Optional list of fields to return in results. + filter_expression: Optional filter expression to narrow results. + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). + + Raises: + ValueError: If RedisHybridQueryConfig is used with RedisVL < 0.13.0. 
+ + Warns: + DeprecationWarning: If RedisAggregatedHybridQueryConfig is used when + native hybrid is available (RedisVL >= 0.13.0). + """ + super().__init__( + name=name, + description=description, + index=index, + vectorizer=vectorizer, + return_fields=return_fields, + ) + + self._supports_native = _supports_native_hybrid(index) + + # Auto-detect config if not provided + if config is None: + if self._supports_native: + config = RedisHybridQueryConfig() + else: + config = RedisAggregatedHybridQueryConfig() + + # Validate config compatibility with installed version + self._use_native = isinstance(config, RedisHybridQueryConfig) + + if self._use_native and not self._supports_native: + raise ValueError( + "RedisHybridQueryConfig requires RedisVL >= 0.13.0 and Redis >= " + f"8.4.0. Installed RedisVL version: {_get_redisvl_version()}, " + f"Redis server version: {_get_redis_server_version(index)}. " + "Use RedisAggregatedHybridQueryConfig for older versions." + ) + + if not self._use_native and self._supports_native: + warnings.warn( + "RedisAggregatedHybridQueryConfig is deprecated for RedisVL >=" + " 0.13.0. Consider using RedisHybridQueryConfig for native FT.HYBRID" + " support with better performance. RedisAggregatedHybridQueryConfig" + " will continue to work but uses client-side score combination.", + DeprecationWarning, + stacklevel=2, + ) + + self._config = config + self._filter_expression = filter_expression + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM.""" + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + "num_results": types.Schema( + type=types.Type.INTEGER, + description=( + "Number of results to return (default:" + f" {self._config.num_results})." 
+ ), + ), + }, + required=["query"], + ), + ) + + def _build_query( + self, query_text: str, embedding: List[float], **kwargs: Any + ) -> Any: + """Build a query for combined vector + text search. + + Args: + query_text: The original query text for BM25 matching. + embedding: The vector embedding of the query text. + **kwargs: Additional parameters (e.g., num_results). + + Returns: + A HybridQuery or AggregateHybridQuery configured for hybrid search. + """ + # Allow runtime override of num_results + num_results = kwargs.get("num_results", self._config.num_results) + + # Get query kwargs from config + query_kwargs = self._config.to_query_kwargs( + text=query_text, + vector=embedding, + return_fields=self._return_fields, + filter_expression=self._filter_expression, + ) + query_kwargs["num_results"] = num_results + + if self._use_native: + from redisvl.query import HybridQuery + + return HybridQuery(**query_kwargs) + else: + from redisvl.query import AggregateHybridQuery + + return AggregateHybridQuery(**query_kwargs) diff --git a/src/google/adk_community/tools/redis/range_search_tool.py b/src/google/adk_community/tools/redis/range_search_tool.py new file mode 100644 index 0000000..dac714e --- /dev/null +++ b/src/google/adk_community/tools/redis/range_search_tool.py @@ -0,0 +1,153 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Redis vector range search tool using distance threshold.""" + +from __future__ import annotations + +from typing import Any +from typing import List +from typing import Optional +from typing import Union + +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.query import VectorRangeQuery +from redisvl.utils.vectorize import BaseVectorizer + +from .base_search_tool import VectorizedSearchTool +from .config import RedisRangeQueryConfig + + +class RedisRangeSearchTool(VectorizedSearchTool): + """Vector range search tool using distance threshold. + + This tool finds all documents within a specified distance threshold + from the query vector. Unlike KNN search which returns a fixed number + of results, range search returns all documents that are "close enough" + based on the threshold. + + Example: + ```python + from redisvl.index import SearchIndex + from redisvl.utils.vectorize import HFTextVectorizer + from google.adk_community.tools.redis import ( + RedisRangeSearchTool, + RedisRangeQueryConfig, + ) + + index = SearchIndex.from_yaml("schema.yaml") + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + + # Using config object (recommended) + config = RedisRangeQueryConfig( + distance_threshold=0.3, # Only return docs within 0.3 distance + ) + tool = RedisRangeSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + return_fields=["title", "content"], + ) + + agent = Agent(model="gemini-2.5-flash", tools=[tool]) + ``` + """ + + def __init__( + self, + *, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: BaseVectorizer, + config: Optional[RedisRangeQueryConfig] = None, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + name: str = "redis_range_search", + description: str = "Find all documents within a similarity threshold.", + ): + """Initialize the range search tool. 
+ + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + vectorizer: The vectorizer for embedding queries. + config: Configuration for query parameters. If None, uses defaults. + See RedisRangeQueryConfig for available options including + distance_threshold, vector_field_name, and epsilon. + return_fields: Optional list of fields to return in results. + filter_expression: Optional filter expression to narrow results. + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). + """ + super().__init__( + name=name, + description=description, + index=index, + vectorizer=vectorizer, + return_fields=return_fields, + ) + self._config = config or RedisRangeQueryConfig() + self._filter_expression = filter_expression + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM.""" + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + "distance_threshold": types.Schema( + type=types.Type.NUMBER, + description=( + "Max distance threshold (default:" + f" {self._config.distance_threshold})." + ), + ), + }, + required=["query"], + ), + ) + + def _build_query( + self, query_text: str, embedding: List[float], **kwargs: Any + ) -> VectorRangeQuery: + """Build a VectorRangeQuery for distance-based search. + + Args: + query_text: The original query text (unused for range search). + embedding: The vector embedding of the query text. + **kwargs: Additional parameters (e.g., distance_threshold). + + Returns: + A VectorRangeQuery configured for range search. 
+ """ + # Allow runtime override of distance_threshold + distance_threshold = kwargs.get( + "distance_threshold", self._config.distance_threshold + ) + + # Get query kwargs from config + query_kwargs = self._config.to_query_kwargs( + vector=embedding, + filter_expression=self._filter_expression, + ) + query_kwargs["return_fields"] = self._return_fields + query_kwargs["distance_threshold"] = distance_threshold + + return VectorRangeQuery(**query_kwargs) diff --git a/src/google/adk_community/tools/redis/text_search_tool.py b/src/google/adk_community/tools/redis/text_search_tool.py new file mode 100644 index 0000000..8d719b6 --- /dev/null +++ b/src/google/adk_community/tools/redis/text_search_tool.py @@ -0,0 +1,149 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redis full-text search tool using BM25.""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.query import TextQuery + +from .base_search_tool import BaseRedisSearchTool +from .config import RedisTextQueryConfig + + +class RedisTextSearchTool(BaseRedisSearchTool): + """Full-text search tool using BM25 scoring. 
+ + This tool performs keyword-based full-text search using BM25 scoring. + Unlike vector search, it doesn't require embeddings - it matches + documents based on keyword relevance. + + Example: + ```python + from redisvl.index import SearchIndex + from google.adk_community.tools.redis import ( + RedisTextSearchTool, + RedisTextQueryConfig, + ) + + index = SearchIndex.from_yaml("schema.yaml") + + # Using config object (recommended) + config = RedisTextQueryConfig( + text_field_name="content", + num_results=10, + text_scorer="BM25STD", + ) + tool = RedisTextSearchTool( + index=index, + config=config, + return_fields=["title", "content"], + ) + + agent = Agent(model="gemini-2.5-flash", tools=[tool]) + ``` + """ + + def __init__( + self, + *, + index: Union[SearchIndex, AsyncSearchIndex], + config: Optional[RedisTextQueryConfig] = None, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + name: str = "redis_text_search", + description: str = "Search for documents using keyword matching.", + ): + """Initialize the text search tool. + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + config: Configuration for text query parameters. If not provided, + defaults will be used. + return_fields: Optional list of fields to return in results. + filter_expression: Optional filter expression to narrow results. + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). 
+ """ + super().__init__( + name=name, + description=description, + index=index, + return_fields=return_fields, + ) + self._config = config or RedisTextQueryConfig() + self._filter_expression = filter_expression + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM.""" + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + "num_results": types.Schema( + type=types.Type.INTEGER, + description=( + "Number of results to return (default:" + f" {self._config.num_results})." + ), + ), + }, + required=["query"], + ), + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> Dict[str, Any]: + """Execute the text search query. + + Args: + args: Arguments from the LLM, must include 'query'. + tool_context: The tool execution context. + + Returns: + A dictionary with status, count, and results. 
+ """ + + async def build_query_fn( + query_text: str, args: Dict[str, Any] + ) -> TextQuery: + # Get query kwargs from config + query_kwargs = self._config.to_query_kwargs( + text=query_text, + return_fields=self._return_fields, + filter_expression=self._filter_expression, + ) + # Allow LLM to override num_results + if "num_results" in args: + query_kwargs["num_results"] = args["num_results"] + return TextQuery(**query_kwargs) + + return await self._run_search(args, build_query_fn) diff --git a/src/google/adk_community/tools/redis/vector_search_tool.py b/src/google/adk_community/tools/redis/vector_search_tool.py new file mode 100644 index 0000000..a000dcf --- /dev/null +++ b/src/google/adk_community/tools/redis/vector_search_tool.py @@ -0,0 +1,155 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redis vector similarity search tool using RedisVL.""" + +from __future__ import annotations + +from typing import Any +from typing import List +from typing import Optional +from typing import Union + +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.query import VectorQuery +from redisvl.utils.vectorize import BaseVectorizer + +from .base_search_tool import VectorizedSearchTool +from .config import RedisVectorQueryConfig + + +class RedisVectorSearchTool(VectorizedSearchTool): + """Vector similarity search tool using RedisVL. 
+ + This tool performs K-nearest neighbor (KNN) vector similarity search + over a Redis index. It embeds the query text using the provided + vectorizer and finds the most similar documents. + + Example: + ```python + from redisvl.index import SearchIndex + from redisvl.utils.vectorize import HFTextVectorizer + from redisvl.query.filter import Tag + from google.adk_community.tools.redis import ( + RedisVectorSearchTool, + RedisVectorQueryConfig, + ) + + index = SearchIndex.from_yaml("schema.yaml") + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + + # Using config object (recommended) + config = RedisVectorQueryConfig( + num_results=5, + ef_runtime=100, # Higher = better recall + ) + tool = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + return_fields=["title", "content", "url"], + filter_expression=Tag("category") == "redis", + ) + + # Use with an agent + agent = Agent(model="gemini-2.5-flash", tools=[tool]) + ``` + """ + + def __init__( + self, + *, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: BaseVectorizer, + config: Optional[RedisVectorQueryConfig] = None, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + name: str = "redis_vector_search", + description: str = "Search for semantically similar documents using vector similarity with Redis.", + ): + """Initialize the vector search tool. + + Args: + index: The RedisVL SearchIndex to query. + vectorizer: The vectorizer for embedding queries. + config: Configuration for query parameters. If None, uses defaults. + See RedisVectorQueryConfig for available options including + num_results, vector_field_name, dtype, and version-dependent + parameters like ef_runtime and hybrid_policy. + return_fields: Optional list of fields to return in results. + filter_expression: Optional RedisVL FilterExpression to narrow results. + name: The name of the tool (exposed to LLM). 
+ description: The description of the tool (exposed to LLM). + """ + super().__init__( + name=name, + description=description, + index=index, + vectorizer=vectorizer, + return_fields=return_fields, + ) + self._config = config or RedisVectorQueryConfig() + self._filter_expression = filter_expression + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM.""" + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + "num_results": types.Schema( + type=types.Type.INTEGER, + description=( + "Number of results to return (default:" + f" {self._config.num_results})." + ), + ), + }, + required=["query"], + ), + ) + + def _build_query( + self, query_text: str, embedding: List[float], **kwargs: Any + ) -> VectorQuery: + """Build a VectorQuery for KNN search. + + Args: + query_text: The original query text (unused for vector search). + embedding: The vector embedding of the query text. + **kwargs: Additional parameters (e.g., num_results). + + Returns: + A VectorQuery configured for KNN search. + """ + # Allow runtime override of num_results + num_results = kwargs.get("num_results", self._config.num_results) + + # Get query kwargs from config, with runtime overrides + query_kwargs = self._config.to_query_kwargs( + vector=embedding, + filter_expression=self._filter_expression, + ) + query_kwargs["return_fields"] = self._return_fields + query_kwargs["num_results"] = num_results + + return VectorQuery(**query_kwargs) diff --git a/tests/unittests/memory/__init__.py b/tests/unittests/memory/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/memory/__init__.py +++ b/tests/unittests/memory/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/memory/test_open_memory_service.py b/tests/unittests/memory/test_open_memory_service.py index 74cb05e..718f6c6 100644 --- a/tests/unittests/memory/test_open_memory_service.py +++ b/tests/unittests/memory/test_open_memory_service.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock +from unittest.mock import MagicMock from unittest.mock import patch from google.adk.events.event import Event -from google.adk_community.memory.open_memory_service import ( - OpenMemoryService, - OpenMemoryServiceConfig, -) from google.adk.sessions.session import Session from google.genai import types import pytest +from google.adk_community.memory.open_memory_service import OpenMemoryService +from google.adk_community.memory.open_memory_service import OpenMemoryServiceConfig + MOCK_APP_NAME = 'test-app' MOCK_USER_ID = 'test-user' MOCK_SESSION_ID = 'session-1' @@ -39,7 +39,9 @@ invocation_id='inv-1', author='user', timestamp=12345, - content=types.Content(parts=[types.Part(text='Hello, I like Python.')]), + content=types.Content( + parts=[types.Part(text='Hello, I like Python.')] + ), ), Event( id='event-2', @@ -47,7 +49,9 @@ author='model', timestamp=12346, content=types.Content( - parts=[types.Part(text='Python is a great programming language.')] + parts=[ + types.Part(text='Python is a great programming language.') + ] ), ), # Empty event, should be ignored @@ -85,10 +89,12 @@ @pytest.fixture def mock_httpx_client(): """Mock httpx.AsyncClient for testing.""" - with patch('google.adk_community.memory.open_memory_service.httpx.AsyncClient') as mock_client_class: + with patch( + 'google.adk_community.memory.open_memory_service.httpx.AsyncClient' + ) as mock_client_class: mock_client = MagicMock() 
mock_response = MagicMock() - mock_response.json.return_value = {"matches": []} + mock_response.json.return_value = {'matches': []} mock_response.raise_for_status = MagicMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) @@ -107,14 +113,10 @@ def memory_service(mock_httpx_client): def memory_service_with_config(mock_httpx_client): """Create OpenMemoryService with custom config.""" config = OpenMemoryServiceConfig( - search_top_k=5, - user_content_salience=0.9, - model_content_salience=0.6 + search_top_k=5, user_content_salience=0.9, model_content_salience=0.6 ) return OpenMemoryService( - base_url='http://localhost:3000', - api_key='test-key', - config=config + base_url='http://localhost:3000', api_key='test-key', config=config ) @@ -139,7 +141,7 @@ def test_custom_config(self): user_content_salience=0.9, model_content_salience=0.75, default_salience=0.5, - enable_metadata_tags=False + enable_metadata_tags=False, ) assert config.search_top_k == 20 assert config.timeout == 10.0 @@ -158,18 +160,20 @@ def test_config_validation_search_top_k(self): def test_api_key_required(self): """Test that API key is required.""" - with pytest.raises(ValueError, match="api_key is required"): - OpenMemoryService(base_url="http://localhost:3000", api_key="") - - with pytest.raises(ValueError, match="api_key is required"): - OpenMemoryService(base_url="http://localhost:3000") + with pytest.raises(ValueError, match='api_key is required'): + OpenMemoryService(base_url='http://localhost:3000', api_key='') + + with pytest.raises(ValueError, match='api_key is required'): + OpenMemoryService(base_url='http://localhost:3000') class TestOpenMemoryService: """Tests for OpenMemoryService.""" @pytest.mark.asyncio - async def test_add_session_to_memory_success(self, memory_service, mock_httpx_client): + async def test_add_session_to_memory_success( + self, memory_service, mock_httpx_client + ): """Test successful addition 
of session memories.""" await memory_service.add_session_to_memory(MOCK_SESSION) @@ -220,9 +224,7 @@ async def test_add_session_uses_config_salience( assert request_data['salience'] == 0.6 # Custom model salience @pytest.mark.asyncio - async def test_add_session_without_metadata_tags( - self, mock_httpx_client - ): + async def test_add_session_without_metadata_tags(self, mock_httpx_client): """Test adding memories without metadata tags.""" config = OpenMemoryServiceConfig(enable_metadata_tags=False) memory_service = OpenMemoryService( @@ -236,7 +238,9 @@ async def test_add_session_without_metadata_tags( assert request_data.get('tags', []) == [] @pytest.mark.asyncio - async def test_add_session_error_handling(self, memory_service, mock_httpx_client): + async def test_add_session_error_handling( + self, memory_service, mock_httpx_client + ): """Test error handling during memory addition.""" mock_httpx_client.post.side_effect = Exception('API Error') @@ -254,20 +258,23 @@ async def test_search_memory_success(self, memory_service, mock_httpx_client): mock_response.json.return_value = { 'matches': [ { - 'content': '[Author: user, Time: 2025-01-01T00:00:00] Python is great', + 'content': ( + '[Author: user, Time: 2025-01-01T00:00:00] Python is great' + ), }, { - 'content': '[Author: model, Time: 2025-01-01T00:01:00] I like programming', - } + 'content': ( + '[Author: model, Time: 2025-01-01T00:01:00] I like' + ' programming' + ), + }, ] } mock_response.raise_for_status = MagicMock() mock_httpx_client.post = AsyncMock(return_value=mock_response) result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='Python programming' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='Python programming' ) # Verify API call @@ -276,7 +283,7 @@ async def test_search_memory_success(self, memory_service, mock_httpx_client): assert request_data['query'] == 'Python programming' assert request_data['k'] == 10 assert 
request_data['filter']['user_id'] == MOCK_USER_ID - assert f"app:{MOCK_APP_NAME}" in request_data['filter']['tags'] + assert f'app:{MOCK_APP_NAME}' in request_data['filter']['tags'] # Verify results (content should be cleaned of metadata prefix) assert len(result.memories) == 2 @@ -293,26 +300,24 @@ async def test_search_memory_applies_filters( # Mock response - server-side filtering ensures only matching results mock_response = MagicMock() mock_response.json.return_value = { - 'matches': [ - { - 'content': '[Author: model, Time: 2025-01-01T00:01:00] I like programming', - } - ] + 'matches': [{ + 'content': ( + '[Author: model, Time: 2025-01-01T00:01:00] I like programming' + ), + }] } mock_response.raise_for_status = MagicMock() mock_httpx_client.post = AsyncMock(return_value=mock_response) result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='test query' ) # Verify filters were passed correctly call_args = mock_httpx_client.post.call_args request_data = call_args.kwargs['json'] assert request_data['filter']['user_id'] == MOCK_USER_ID - assert f"app:{MOCK_APP_NAME}" in request_data['filter']['tags'] + assert f'app:{MOCK_APP_NAME}' in request_data['filter']['tags'] # Should return filtered results assert len(result.memories) == 1 @@ -329,9 +334,7 @@ async def test_search_memory_respects_top_k( mock_httpx_client.post = AsyncMock(return_value=mock_response) await memory_service_with_config.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='test query' ) call_args = mock_httpx_client.post.call_args @@ -346,11 +349,8 @@ async def test_search_memory_error_handling( mock_httpx_client.post.side_effect = Exception('API Error') result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, 
user_id=MOCK_USER_ID, query='test query' ) # Should return empty results on error assert len(result.memories) == 0 - diff --git a/tests/unittests/sessions/test_redis_session_service.py b/tests/unittests/sessions/test_redis_session_service.py index dad0867..2354561 100644 --- a/tests/unittests/sessions/test_redis_session_service.py +++ b/tests/unittests/sessions/test_redis_session_service.py @@ -12,549 +12,568 @@ # See the License for the specific language governing permissions and # limitations under the License. -import orjson -from datetime import datetime, timezone -import pytest -import pytest_asyncio -from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime +from datetime import timezone +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.sessions.base_session_service import GetSessionConfig -from google.adk_community.sessions.redis_session_service import RedisSessionService from google.genai import types +import orjson +import pytest +import pytest_asyncio +from google.adk_community.sessions.redis_session_service import RedisSessionService -class TestRedisSessionService: - """Test cases for RedisSessionService.""" - - @pytest_asyncio.fixture - async def redis_service(self): - """Create a Redis session service for testing.""" - with patch("redis.asyncio.Redis") as mock_redis: - mock_client = AsyncMock() - mock_redis.return_value = mock_client - service = RedisSessionService() - service.cache = mock_client - yield service - - @pytest_asyncio.fixture - async def redis_cluster_service(self): - """Create a Redis cluster session service for testing.""" - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - cluster_uri = "redis://redis-node1:6379" - service = 
RedisSessionService(cluster_uri=cluster_uri) - service.cache = mock_client - yield service - - @pytest_asyncio.fixture - async def redis_cluster_uri_service(self): - """Create a Redis cluster session service using URI for testing.""" - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - cluster_uri = "redis://node1:6379,node2:6379" - service = RedisSessionService(cluster_uri=cluster_uri) - service.cache = mock_client - yield service - - def _setup_redis_mocks(self, redis_service, sessions_data=None): - """Helper to set up Redis mocks for the new storage strategy.""" - if sessions_data is None: - sessions_data = {} - - session_ids = list(sessions_data.keys()) - redis_service.cache.smembers = AsyncMock( - return_value={sid.encode() for sid in session_ids} - ) - - # Mock the new cluster-aware pipeline approach - session_values = [ - orjson.dumps(sessions_data[sid]) if sid in sessions_data else None - for sid in session_ids - ] - # For backward compatibility with mget approach (still used in some tests) - redis_service.cache.mget = AsyncMock(return_value=session_values) - - # Mock pipeline for the new cluster approach - if session_ids: - # Group sessions as the actual implementation does - results_per_group = [] - for i in range(len(session_ids)): - results_per_group.append([session_values[i]]) - - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(side_effect=results_per_group) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - else: - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(return_value=[]) - 
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - - redis_service.cache.srem = AsyncMock() - redis_service.cache.get = AsyncMock(return_value=None) # Default to no session - - # Additional pipeline operations for create/update operations - if not session_ids: - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.set = MagicMock(return_value=mock_pipe) # Allow chaining - mock_pipe.sadd = MagicMock(return_value=mock_pipe) - mock_pipe.expire = MagicMock(return_value=mock_pipe) - mock_pipe.delete = MagicMock(return_value=mock_pipe) - mock_pipe.srem = MagicMock(return_value=mock_pipe) - mock_pipe.hset = MagicMock(return_value=mock_pipe) - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(return_value=[]) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - - redis_service.cache.hgetall = AsyncMock(return_value={}) - redis_service.cache.hset = AsyncMock() - - @pytest.mark.asyncio - async def test_get_empty_session(self, redis_service): - """Test getting a non-existent session.""" - self._setup_redis_mocks(redis_service) - - session = await redis_service.get_session( - app_name="test_app", user_id="test_user", session_id="nonexistent" - ) - - assert session is None - - @pytest.mark.asyncio - async def test_create_get_session(self, redis_service): - """Test session creation and retrieval.""" - app_name = "test_app" - user_id = "test_user" - state = {"key": "value"} - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session( - app_name=app_name, user_id=user_id, state=state - ) - - assert session.app_name == app_name - assert session.user_id == user_id - assert session.id 
is not None - assert session.state == state - - # Allow tiny float/clock rounding differences (~1ms) - assert ( - session.last_update_time - <= datetime.now().astimezone(timezone.utc).timestamp() + 0.001 - ) - - # Mock individual session retrieval - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - got_session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert got_session.app_name == session.app_name - assert got_session.user_id == session.user_id - assert got_session.id == session.id - assert got_session.state == session.state - - @pytest.mark.asyncio - async def test_create_and_list_sessions(self, redis_service): - """Test creating multiple sessions and listing them. - - list_sessions() is expected to return lightweight session summaries, - i.e., with events and state stripped for performance. - """ - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session_ids = ["session" + str(i) for i in range(3)] - sessions_data = {} - - for i, session_id in enumerate(session_ids): - session = await redis_service.create_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - state={"key": "value" + session_id}, - ) - # Add at least one event to ensure list_sessions actually strips them. 
- session.events.append(Event(author="user", timestamp=float(i + 1))) - sessions_data[session_id] = session.model_dump() - - # Now mock Redis to return those sessions (with events present in storage) - self._setup_redis_mocks(redis_service, sessions_data) - - list_sessions_response = await redis_service.list_sessions( - app_name=app_name, user_id=user_id - ) - sessions = list_sessions_response.sessions - - assert len(sessions) == len(session_ids) - returned_session_ids = {s.id for s in sessions} - assert returned_session_ids == set(session_ids) - - for s in sessions: - # list_sessions returns summaries: events and state removed for perf. - assert len(s.events) == 0 - assert s.state == {} - - @pytest.mark.asyncio - async def test_session_state_management(self, redis_service): - """Test session state management with app, user, and temp state.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - state={"initial_key": "initial_value"}, - ) - - event = Event( - invocation_id="invocation", - author="user", - content=types.Content(role="user", parts=[types.Part(text="text")]), - actions=EventActions( - state_delta={ - "app:key": "app_value", - "user:key1": "user_value", - "temp:key": "temp_value", - "initial_key": "updated_value", - } +class TestRedisSessionService: + """Test cases for RedisSessionService.""" + + @pytest_asyncio.fixture + async def redis_service(self): + """Create a Redis session service for testing.""" + with patch("redis.asyncio.Redis") as mock_redis: + mock_client = AsyncMock() + mock_redis.return_value = mock_client + service = RedisSessionService() + service.cache = mock_client + yield service + + @pytest_asyncio.fixture + async def redis_cluster_service(self): + """Create a Redis cluster session service for testing.""" + with 
patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + cluster_uri = "redis://redis-node1:6379" + service = RedisSessionService(cluster_uri=cluster_uri) + service.cache = mock_client + yield service + + @pytest_asyncio.fixture + async def redis_cluster_uri_service(self): + """Create a Redis cluster session service using URI for testing.""" + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + cluster_uri = "redis://node1:6379,node2:6379" + service = RedisSessionService(cluster_uri=cluster_uri) + service.cache = mock_client + yield service + + def _setup_redis_mocks(self, redis_service, sessions_data=None): + """Helper to set up Redis mocks for the new storage strategy.""" + if sessions_data is None: + sessions_data = {} + + session_ids = list(sessions_data.keys()) + redis_service.cache.smembers = AsyncMock( + return_value={sid.encode() for sid in session_ids} + ) + + # Mock the new cluster-aware pipeline approach + session_values = [ + orjson.dumps(sessions_data[sid]) if sid in sessions_data else None + for sid in session_ids + ] + + # For backward compatibility with mget approach (still used in some tests) + redis_service.cache.mget = AsyncMock(return_value=session_values) + + # Mock pipeline for the new cluster approach + if session_ids: + # Group sessions as the actual implementation does + results_per_group = [] + for i in range(len(session_ids)): + results_per_group.append([session_values[i]]) + + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(side_effect=results_per_group) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + 
return_value=mock_context_manager + ) + else: + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(return_value=[]) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + + redis_service.cache.srem = AsyncMock() + redis_service.cache.get = AsyncMock( + return_value=None + ) # Default to no session + + # Additional pipeline operations for create/update operations + if not session_ids: + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.set = MagicMock(return_value=mock_pipe) # Allow chaining + mock_pipe.sadd = MagicMock(return_value=mock_pipe) + mock_pipe.expire = MagicMock(return_value=mock_pipe) + mock_pipe.delete = MagicMock(return_value=mock_pipe) + mock_pipe.srem = MagicMock(return_value=mock_pipe) + mock_pipe.hset = MagicMock(return_value=mock_pipe) + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(return_value=[]) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + + redis_service.cache.hgetall = AsyncMock(return_value={}) + redis_service.cache.hset = AsyncMock() + + @pytest.mark.asyncio + async def test_get_empty_session(self, redis_service): + """Test getting a non-existent session.""" + self._setup_redis_mocks(redis_service) + + session = await redis_service.get_session( + app_name="test_app", user_id="test_user", session_id="nonexistent" + ) + + assert session is None + + @pytest.mark.asyncio + async def test_create_get_session(self, redis_service): + """Test session creation and retrieval.""" + app_name = "test_app" + user_id = "test_user" + state = {"key": "value"} + + 
self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id is not None + assert session.state == state + + # Allow tiny float/clock rounding differences (~1ms) + assert ( + session.last_update_time + <= datetime.now().astimezone(timezone.utc).timestamp() + 0.001 + ) + + # Mock individual session retrieval + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + got_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert got_session.app_name == session.app_name + assert got_session.user_id == session.user_id + assert got_session.id == session.id + assert got_session.state == session.state + + @pytest.mark.asyncio + async def test_create_and_list_sessions(self, redis_service): + """Test creating multiple sessions and listing them. + + list_sessions() is expected to return lightweight session summaries, + i.e., with events and state stripped for performance. + """ + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session_ids = ["session" + str(i) for i in range(3)] + sessions_data = {} + + for i, session_id in enumerate(session_ids): + session = await redis_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={"key": "value" + session_id}, + ) + # Add at least one event to ensure list_sessions actually strips them. 
+ session.events.append(Event(author="user", timestamp=float(i + 1))) + sessions_data[session_id] = session.model_dump() + + # Now mock Redis to return those sessions (with events present in storage) + self._setup_redis_mocks(redis_service, sessions_data) + + list_sessions_response = await redis_service.list_sessions( + app_name=app_name, user_id=user_id + ) + sessions = list_sessions_response.sessions + + assert len(sessions) == len(session_ids) + returned_session_ids = {s.id for s in sessions} + assert returned_session_ids == set(session_ids) + + for s in sessions: + # list_sessions returns summaries: events and state removed for perf. + assert len(s.events) == 0 + assert s.state == {} + + @pytest.mark.asyncio + async def test_session_state_management(self, redis_service): + """Test session state management with app, user, and temp state.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={"initial_key": "initial_value"}, + ) + + event = Event( + invocation_id="invocation", + author="user", + content=types.Content(role="user", parts=[types.Part(text="text")]), + actions=EventActions( + state_delta={ + "app:key": "app_value", + "user:key1": "user_value", + "temp:key": "temp_value", + "initial_key": "updated_value", + } + ), + ) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + await redis_service.append_event(session=session, event=event) + + assert session.state.get("app:key") == "app_value" + assert session.state.get("user:key1") == "user_value" + assert session.state.get("initial_key") == "updated_value" + assert session.state.get("temp:key") is None # Temp state filtered + + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.hset.assert_any_call( + 
"app:test_app", "key", orjson.dumps("app_value") + ) + pipe_mock.hset.assert_any_call( + "user:test_app:test_user", "key1", orjson.dumps("user_value") + ) + + @pytest.mark.asyncio + async def test_append_event_with_bytes(self, redis_service): + """Test appending events with binary content and serialization roundtrip.""" + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id + ) + + test_content = types.Content( + role="user", + parts=[ + types.Part.from_bytes( + data=b"test_image_data", mime_type="image/png" ), - ) - - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - await redis_service.append_event(session=session, event=event) - - assert session.state.get("app:key") == "app_value" - assert session.state.get("user:key1") == "user_value" - assert session.state.get("initial_key") == "updated_value" - assert session.state.get("temp:key") is None # Temp state filtered - - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.hset.assert_any_call("app:test_app", "key", orjson.dumps("app_value")) - pipe_mock.hset.assert_any_call( - "user:test_app:test_user", "key1", orjson.dumps("user_value") - ) - - @pytest.mark.asyncio - async def test_append_event_with_bytes(self, redis_service): - """Test appending events with binary content and serialization roundtrip.""" - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session(app_name=app_name, user_id=user_id) - - test_content = types.Content( - role="user", - parts=[ - types.Part.from_bytes( - data=b"test_image_data", mime_type="image/png" - ), - ], - ) - test_grounding_metadata = types.GroundingMetadata( - search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") - ) - event = Event( - 
invocation_id="invocation", - author="user", - content=test_content, - grounding_metadata=test_grounding_metadata, - ) - - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - await redis_service.append_event(session=session, event=event) - - # Verify the event was appended to in-memory session - assert len(session.events) == 1 - assert session.events[0].content == test_content - assert session.events[0].grounding_metadata == test_grounding_metadata - - # Test serialization/deserialization roundtrip to ensure binary data is preserved - # Simulate what happens when session is stored and retrieved from Redis - serialized_session = session.model_dump_json() - - redis_service.cache.get = AsyncMock(return_value=serialized_session.encode()) - - retrieved_session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert retrieved_session is not None - assert len(retrieved_session.events) == 1 - - # Verify the binary content was preserved through serialization - retrieved_event = retrieved_session.events[0] - assert retrieved_event.content.parts[0].inline_data.data == b"test_image_data" - assert ( - retrieved_event.content.parts[0].inline_data.mime_type - == "image/png" - ) - assert ( - retrieved_event.grounding_metadata.search_entry_point.sdk_blob - == b"test_sdk_blob" - ) - - @pytest.mark.asyncio - async def test_get_session_with_config(self, redis_service): - """Test getting session with configuration filters.""" - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session(app_name=app_name, user_id=user_id) - - # Add multiple events with different timestamps - num_test_events = 5 - for i in range(1, num_test_events + 1): - event = Event(author="user", timestamp=float(i)) - session.events.append(event) - - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - 
# Test num_recent_events filter - config = GetSessionConfig(num_recent_events=3) - filtered_session = await redis_service.get_session( - app_name=app_name, - user_id=user_id, - session_id=session.id, - config=config, - ) - - assert len(filtered_session.events) == 3 - assert filtered_session.events[0].timestamp == 3.0 # Last 3 events - - # Test after_timestamp filter - config = GetSessionConfig(after_timestamp=3.0) - filtered_session = await redis_service.get_session( - app_name=app_name, - user_id=user_id, - session_id=session.id, - config=config, - ) - - assert len(filtered_session.events) == 3 # Events 3, 4, 5 - assert filtered_session.events[0].timestamp == 3.0 - - @pytest.mark.asyncio - async def test_delete_session(self, redis_service): - """Test session deletion.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - self._setup_redis_mocks(redis_service) # Empty sessions - await redis_service.delete_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - ) - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.execute.assert_called() - - redis_service.cache.pipeline.reset_mock() - self._setup_redis_mocks(redis_service) - - await redis_service.delete_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - ) - - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.execute.assert_called() - - @pytest.mark.asyncio - async def test_cluster_health_check(self, redis_cluster_service): - """Test health check for Redis cluster.""" - redis_cluster_service.cache.ping = AsyncMock(return_value=True) - - result = await redis_cluster_service.health_check() - assert result is True - redis_cluster_service.cache.ping.assert_called_once() - - @pytest.mark.asyncio - async def test_cluster_health_check_failure(self, redis_cluster_service): - """Test health check failure for Redis 
cluster.""" - from redis import RedisError - - redis_cluster_service.cache.ping = AsyncMock( - side_effect=RedisError("Connection failed") - ) - - result = await redis_cluster_service.health_check() - assert result is False - - @pytest.mark.asyncio - async def test_cluster_create_and_get_session(self, redis_cluster_service): - """Test session creation and retrieval in cluster mode.""" - app_name = "cluster_test_app" - user_id = "cluster_test_user" - state = {"cluster_key": "cluster_value"} - - self._setup_redis_mocks(redis_cluster_service) - - session = await redis_cluster_service.create_session( - app_name=app_name, user_id=user_id, state=state - ) - - assert session.app_name == app_name - assert session.user_id == user_id - assert session.id is not None - assert session.state == state - - # Mock individual session retrieval - redis_cluster_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - got_session = await redis_cluster_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert got_session.app_name == session.app_name - assert got_session.user_id == session.user_id - assert got_session.id == session.id - assert got_session.state == session.state - - @pytest.mark.asyncio - async def test_cluster_uri_initialization(self, redis_cluster_uri_service): - """Test Redis cluster initialization with URI.""" - assert redis_cluster_uri_service.cache is not None - - @pytest.mark.asyncio - async def test_cluster_error_handling(self, redis_cluster_service): - """Test error handling in cluster operations.""" - from redis import RedisError - - app_name = "test_app" - user_id = "test_user" - - # Mock Redis error during session loading - redis_cluster_service.cache.smembers = AsyncMock( - side_effect=RedisError("Cluster error") - ) - - sessions_response = await redis_cluster_service.list_sessions( - app_name=app_name, user_id=user_id - ) - - assert len(sessions_response.sessions) == 0 - - 
@pytest.mark.asyncio - async def test_cluster_connection_validation(self): - """Test cluster connection validation during initialization.""" - cluster_uri = "redis://redis-node1:6379" - - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - - service = RedisSessionService(cluster_uri=cluster_uri) - assert service.cache is not None - mock_redis_cluster.assert_called_once() - - @pytest.mark.asyncio - async def test_cluster_session_cleanup_on_error(self, redis_cluster_service): - """Test session cleanup when corrupted data is found in cluster.""" - app_name = "test_app" - user_id = "test_user" - - # Setup mock with corrupted session data - valid_session_data = { - "app_name": "test_app", - "user_id": "test_user", - "id": "session1", - "state": {}, - "events": [], - "last_update_time": 1234567890, - } - redis_cluster_service.cache.smembers = AsyncMock( - return_value={b"session1", b"session2"} - ) - - # Mock the pipeline for cluster approach - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock( - side_effect=[ - [orjson.dumps(valid_session_data)], # session1 result - [None], # session2 result (missing) - ] - ) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_cluster_service.cache.pipeline = MagicMock( - return_value=mock_context_manager - ) - redis_cluster_service.cache.srem = AsyncMock() - redis_cluster_service.cache.hgetall = AsyncMock(return_value={}) - - sessions_response = await redis_cluster_service.list_sessions( - app_name=app_name, user_id=user_id - ) - - redis_cluster_service.cache.srem.assert_called() - assert len(sessions_response.sessions) == 1 - - @pytest.mark.asyncio - async def test_decode_responses_handling(self, redis_service): - """Test proper handling of 
decode_responses setting.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - # Test with bytes response (decode_responses=False) - session_data = ( - '{"app_name": "test_app", "user_id": "test_user", "id": "test_session", ' - '"state": {}, "events": [], "last_update_time": 1234567890}' - ) - redis_service.cache.get = AsyncMock(return_value=session_data.encode()) - redis_service.cache.hgetall = AsyncMock(return_value={}) - - session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - - assert session is not None - assert session.app_name == app_name - assert session.user_id == user_id + ], + ) + test_grounding_metadata = types.GroundingMetadata( + search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") + ) + event = Event( + invocation_id="invocation", + author="user", + content=test_content, + grounding_metadata=test_grounding_metadata, + ) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + await redis_service.append_event(session=session, event=event) + + # Verify the event was appended to in-memory session + assert len(session.events) == 1 + assert session.events[0].content == test_content + assert session.events[0].grounding_metadata == test_grounding_metadata + + # Test serialization/deserialization roundtrip to ensure binary data is preserved + # Simulate what happens when session is stored and retrieved from Redis + serialized_session = session.model_dump_json() + + redis_service.cache.get = AsyncMock( + return_value=serialized_session.encode() + ) + + retrieved_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) == 1 + + # Verify the binary content was preserved through serialization + retrieved_event = retrieved_session.events[0] + assert ( + 
retrieved_event.content.parts[0].inline_data.data == b"test_image_data" + ) + assert retrieved_event.content.parts[0].inline_data.mime_type == "image/png" + assert ( + retrieved_event.grounding_metadata.search_entry_point.sdk_blob + == b"test_sdk_blob" + ) + + @pytest.mark.asyncio + async def test_get_session_with_config(self, redis_service): + """Test getting session with configuration filters.""" + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id + ) + + # Add multiple events with different timestamps + num_test_events = 5 + for i in range(1, num_test_events + 1): + event = Event(author="user", timestamp=float(i)) + session.events.append(event) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + # Test num_recent_events filter + config = GetSessionConfig(num_recent_events=3) + filtered_session = await redis_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=session.id, + config=config, + ) + + assert len(filtered_session.events) == 3 + assert filtered_session.events[0].timestamp == 3.0 # Last 3 events + + # Test after_timestamp filter + config = GetSessionConfig(after_timestamp=3.0) + filtered_session = await redis_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=session.id, + config=config, + ) + + assert len(filtered_session.events) == 3 # Events 3, 4, 5 + assert filtered_session.events[0].timestamp == 3.0 + + @pytest.mark.asyncio + async def test_delete_session(self, redis_service): + """Test session deletion.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + self._setup_redis_mocks(redis_service) # Empty sessions + await redis_service.delete_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await 
pipeline_mock.__aenter__() + pipe_mock.execute.assert_called() + + redis_service.cache.pipeline.reset_mock() + self._setup_redis_mocks(redis_service) + + await redis_service.delete_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.execute.assert_called() + + @pytest.mark.asyncio + async def test_cluster_health_check(self, redis_cluster_service): + """Test health check for Redis cluster.""" + redis_cluster_service.cache.ping = AsyncMock(return_value=True) + + result = await redis_cluster_service.health_check() + assert result is True + redis_cluster_service.cache.ping.assert_called_once() + + @pytest.mark.asyncio + async def test_cluster_health_check_failure(self, redis_cluster_service): + """Test health check failure for Redis cluster.""" + from redis import RedisError + + redis_cluster_service.cache.ping = AsyncMock( + side_effect=RedisError("Connection failed") + ) + + result = await redis_cluster_service.health_check() + assert result is False + + @pytest.mark.asyncio + async def test_cluster_create_and_get_session(self, redis_cluster_service): + """Test session creation and retrieval in cluster mode.""" + app_name = "cluster_test_app" + user_id = "cluster_test_user" + state = {"cluster_key": "cluster_value"} + + self._setup_redis_mocks(redis_cluster_service) + + session = await redis_cluster_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id is not None + assert session.state == state + + # Mock individual session retrieval + redis_cluster_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + got_session = await redis_cluster_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert got_session.app_name == session.app_name 
+ assert got_session.user_id == session.user_id + assert got_session.id == session.id + assert got_session.state == session.state + + @pytest.mark.asyncio + async def test_cluster_uri_initialization(self, redis_cluster_uri_service): + """Test Redis cluster initialization with URI.""" + assert redis_cluster_uri_service.cache is not None + + @pytest.mark.asyncio + async def test_cluster_error_handling(self, redis_cluster_service): + """Test error handling in cluster operations.""" + from redis import RedisError + + app_name = "test_app" + user_id = "test_user" + + # Mock Redis error during session loading + redis_cluster_service.cache.smembers = AsyncMock( + side_effect=RedisError("Cluster error") + ) + + sessions_response = await redis_cluster_service.list_sessions( + app_name=app_name, user_id=user_id + ) + + assert len(sessions_response.sessions) == 0 + + @pytest.mark.asyncio + async def test_cluster_connection_validation(self): + """Test cluster connection validation during initialization.""" + cluster_uri = "redis://redis-node1:6379" + + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + + service = RedisSessionService(cluster_uri=cluster_uri) + assert service.cache is not None + mock_redis_cluster.assert_called_once() + + @pytest.mark.asyncio + async def test_cluster_session_cleanup_on_error(self, redis_cluster_service): + """Test session cleanup when corrupted data is found in cluster.""" + app_name = "test_app" + user_id = "test_user" + + # Setup mock with corrupted session data + valid_session_data = { + "app_name": "test_app", + "user_id": "test_user", + "id": "session1", + "state": {}, + "events": [], + "last_update_time": 1234567890, + } + redis_cluster_service.cache.smembers = AsyncMock( + return_value={b"session1", b"session2"} + ) + + # Mock the pipeline for cluster approach + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = 
MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock( + side_effect=[ + [orjson.dumps(valid_session_data)], # session1 result + [None], # session2 result (missing) + ] + ) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_cluster_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + redis_cluster_service.cache.srem = AsyncMock() + redis_cluster_service.cache.hgetall = AsyncMock(return_value={}) + + sessions_response = await redis_cluster_service.list_sessions( + app_name=app_name, user_id=user_id + ) + + redis_cluster_service.cache.srem.assert_called() + assert len(sessions_response.sessions) == 1 + + @pytest.mark.asyncio + async def test_decode_responses_handling(self, redis_service): + """Test proper handling of decode_responses setting.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Test with bytes response (decode_responses=False) + session_data = ( + '{"app_name": "test_app", "user_id": "test_user", "id": "test_session",' + ' "state": {}, "events": [], "last_update_time": 1234567890}' + ) + redis_service.cache.get = AsyncMock(return_value=session_data.encode()) + redis_service.cache.hgetall = AsyncMock(return_value={}) + + session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is not None + assert session.app_name == app_name + assert session.user_id == user_id diff --git a/tests/unittests/tools/__init__.py b/tests/unittests/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/tools/redis/__init__.py b/tests/unittests/tools/redis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/tools/redis/test_hybrid_search_tool.py b/tests/unittests/tools/redis/test_hybrid_search_tool.py new file mode 100644 index 0000000..606b162 --- /dev/null +++ 
b/tests/unittests/tools/redis/test_hybrid_search_tool.py @@ -0,0 +1,367 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RedisHybridSearchTool.""" + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch +import warnings + +import pytest + +# Skip all tests if redisvl is not installed +pytest.importorskip("redisvl") + +from redisvl.index import SearchIndex +from redisvl.utils.vectorize import BaseVectorizer + +from google.adk_community.tools.redis import RedisAggregatedHybridQueryConfig +from google.adk_community.tools.redis import RedisHybridQueryConfig +from google.adk_community.tools.redis import RedisHybridSearchTool +from google.adk_community.tools.redis.hybrid_search_tool import _get_redis_server_version +from google.adk_community.tools.redis.hybrid_search_tool import _get_redisvl_version +from google.adk_community.tools.redis.hybrid_search_tool import _supports_native_hybrid + + +@pytest.fixture +def mock_vectorizer(): + """Mock RedisVL vectorizer.""" + vectorizer = MagicMock(spec=BaseVectorizer) + vectorizer.embed = MagicMock(return_value=[0.1] * 384) + vectorizer.aembed = AsyncMock(return_value=[0.1] * 384) + return vectorizer + + +@pytest.fixture +def mock_redis_client(): + """Mock Redis client with info method.""" + client = MagicMock() + client.info = MagicMock(return_value={"redis_version": "8.4.0"}) + return client + + +@pytest.fixture +def 
mock_index(mock_redis_client): + """Mock RedisVL SearchIndex with Redis client.""" + index = MagicMock(spec=SearchIndex) + index.query = MagicMock( + return_value=[ + {"title": "Test Doc", "content": "Test content", "score": 0.9} + ] + ) + # Mock _redis_client property for sync index + type(index)._redis_client = property(lambda self: mock_redis_client) + return index + + +def _native_hybrid_available(): + """Check if native HybridQuery dependencies are available.""" + try: + from redis.commands.search.hybrid_query import CombineResultsMethod + from redis.commands.search.hybrid_query import HybridPostProcessingConfig + + return True + except (ImportError, ModuleNotFoundError): + return False + + +class TestVersionDetection: + """Tests for version detection functions.""" + + def test_get_redisvl_version(self): + """Test version string retrieval.""" + version = _get_redisvl_version() + assert isinstance(version, str) + # Should be a valid version string like "0.13.0" or "0.0.0" + assert len(version.split(".")) >= 2 + + def test_get_redis_server_version(self, mock_index): + """Test Redis server version retrieval.""" + version = _get_redis_server_version(mock_index) + assert version == "8.4.0" + + def test_get_redis_server_version_client_none(self): + """Test Redis server version when client is None.""" + index = MagicMock(spec=SearchIndex) + type(index)._redis_client = property(lambda self: None) + version = _get_redis_server_version(index) + assert version == "0.0.0" + + def test_get_redis_server_version_exception(self, mock_index): + """Test Redis server version when info() raises exception.""" + mock_index._redis_client.info.side_effect = Exception("Connection error") + version = _get_redis_server_version(mock_index) + assert version == "0.0.0" + + def test_supports_native_hybrid_both_versions_ok(self, mock_index): + """Test native hybrid support when both versions meet requirements.""" + with patch( + 
"google.adk_community.tools.redis.hybrid_search_tool._get_redisvl_version", + return_value="0.13.0", + ): + result = _supports_native_hybrid(mock_index) + assert result is True + + def test_supports_native_hybrid_redisvl_too_old(self, mock_index): + """Test native hybrid not supported when redisvl version is too old.""" + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._get_redisvl_version", + return_value="0.12.0", + ): + result = _supports_native_hybrid(mock_index) + assert result is False + + def test_supports_native_hybrid_redis_server_too_old(self): + """Test native hybrid not supported when Redis server version is too old.""" + # Create index with old Redis server version + mock_client = MagicMock() + mock_client.info = MagicMock(return_value={"redis_version": "7.2.0"}) + index = MagicMock(spec=SearchIndex) + type(index)._redis_client = property(lambda self: mock_client) + + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._get_redisvl_version", + return_value="0.13.0", + ): + result = _supports_native_hybrid(index) + assert result is False + + def test_supports_native_hybrid_both_versions_too_old(self): + """Test native hybrid not supported when both versions are too old.""" + # Create index with old Redis server version + mock_client = MagicMock() + mock_client.info = MagicMock(return_value={"redis_version": "7.2.0"}) + index = MagicMock(spec=SearchIndex) + type(index)._redis_client = property(lambda self: mock_client) + + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._get_redisvl_version", + return_value="0.12.0", + ): + result = _supports_native_hybrid(index) + assert result is False + + +class TestRedisHybridQueryConfig: + """Tests for RedisHybridQueryConfig (native mode).""" + + def test_default_values(self): + """Test default config values.""" + config = RedisHybridQueryConfig() + assert config.text_field_name == "content" + assert config.vector_field_name == "embedding" + assert 
# NOTE(review): this chunk opens mid-method; the asserts directly below are the
# tail of TestRedisHybridQueryConfig.test_default_values, whose `def` line is in
# the previous chunk.
    assert config.text_scorer == "BM25STD"
    assert config.combination_method is None
    assert config.linear_alpha == 0.3
    assert config.rrf_window == 20
    assert config.rrf_constant == 60
    assert config.num_results == 10
    assert config.dtype == "float32"
    assert config.stopwords == "english"

  def test_to_query_kwargs(self):
    """Test conversion to query kwargs."""
    config = RedisHybridQueryConfig(
        text_field_name="content",
        combination_method="LINEAR",
        linear_alpha=0.7,
    )
    kwargs = config.to_query_kwargs(
        text="test query",
        vector=[0.1] * 384,
        return_fields=["title"],
    )
    # Explicitly-set fields must round-trip into the kwargs dict unchanged.
    assert kwargs["text"] == "test query"
    assert kwargs["text_field_name"] == "content"
    assert kwargs["combination_method"] == "LINEAR"
    assert kwargs["linear_alpha"] == 0.7
    assert kwargs["return_fields"] == ["title"]


class TestRedisAggregatedHybridQueryConfig:
  """Tests for RedisAggregatedHybridQueryConfig (legacy mode)."""

  def test_default_values(self):
    """Test default config values."""
    config = RedisAggregatedHybridQueryConfig()
    assert config.text_field_name == "content"
    assert config.vector_field_name == "embedding"
    assert config.text_scorer == "BM25STD"
    assert config.alpha == 0.7
    assert config.num_results == 10
    assert config.dtype == "float32"
    assert config.dialect == 2

  def test_to_query_kwargs(self):
    """Test conversion to query kwargs."""
    config = RedisAggregatedHybridQueryConfig(
        text_field_name="content",
        alpha=0.8,
    )
    kwargs = config.to_query_kwargs(
        text="test query",
        vector=[0.1] * 384,
        return_fields=["title"],
    )
    assert kwargs["text"] == "test query"
    assert kwargs["text_field_name"] == "content"
    assert kwargs["alpha"] == 0.8
    assert kwargs["return_fields"] == ["title"]
    # The legacy (aggregate) config must not leak native-mode-only kwargs.
    assert "linear_alpha" not in kwargs
    assert "combination_method" not in kwargs


class TestRedisHybridSearchToolInit:
  """Tests for RedisHybridSearchTool initialization."""

  def test_auto_detect_config_native(self, mock_index, mock_vectorizer):
    """Test auto-detection uses native config when supported."""
    # Patch the capability probe where the tool module looks it up, so the
    # test is independent of the installed redis-py / Redis server versions.
    with patch(
        "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid",
        return_value=True,
    ):
      tool = RedisHybridSearchTool(
          index=mock_index,
          vectorizer=mock_vectorizer,
      )
      assert isinstance(tool._config, RedisHybridQueryConfig)
      assert tool._use_native is True

  def test_auto_detect_config_aggregate(self, mock_index, mock_vectorizer):
    """Test auto-detection uses aggregate config when native not supported."""
    with patch(
        "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid",
        return_value=False,
    ):
      tool = RedisHybridSearchTool(
          index=mock_index,
          vectorizer=mock_vectorizer,
      )
      assert isinstance(tool._config, RedisAggregatedHybridQueryConfig)
      assert tool._use_native is False

  def test_native_config_on_old_version_raises(
      self, mock_index, mock_vectorizer
  ):
    """Test that using native config on old version raises ValueError."""
    with patch(
        "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid",
        return_value=False,
    ):
      with pytest.raises(ValueError, match="RedisHybridQueryConfig requires"):
        RedisHybridSearchTool(
            index=mock_index,
            vectorizer=mock_vectorizer,
            config=RedisHybridQueryConfig(),
        )

  def test_aggregate_config_on_new_version_warns(
      self, mock_index, mock_vectorizer
  ):
    """Test that using aggregate config on new version emits deprecation warning."""
    with patch(
        "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid",
        return_value=True,
    ):
      # record=True captures emitted warnings; "always" disables the
      # once-per-location filter so the warning is observable every run.
      with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        tool = RedisHybridSearchTool(
            index=mock_index,
            vectorizer=mock_vectorizer,
            config=RedisAggregatedHybridQueryConfig(),
        )
        assert len(w) == 1
        assert issubclass(w[0].category, DeprecationWarning)
        assert "deprecated" in str(w[0].message).lower()
        # The legacy config forces the aggregate (non-native) code path.
        assert tool._use_native is False

  def test_custom_name_and_description(self, mock_index, mock_vectorizer):
    """Test custom tool name and description."""
    with patch(
        "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid",
        return_value=True,
    ):
      tool = RedisHybridSearchTool(
          index=mock_index,
          vectorizer=mock_vectorizer,
          name="custom_hybrid",
          description="Custom hybrid search",
      )
      assert tool.name == "custom_hybrid"
      assert tool.description == "Custom hybrid search"


class TestRedisHybridSearchToolBuildQuery:
  """Tests for _build_query method."""

  @pytest.mark.skipif(
      not _native_hybrid_available(),
      reason="HybridQuery requires redis-py>=7.1.0 and Redis>=8.4.0",
  )
  def test_build_query_native(self, mock_index, mock_vectorizer):
    """Test query building with native HybridQuery."""
    # Imported lazily so the module import does not fail on old redisvl.
    from redisvl.query import HybridQuery

    with patch(
        "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid",
        return_value=True,
    ):
      tool = RedisHybridSearchTool(
          index=mock_index,
          vectorizer=mock_vectorizer,
          config=RedisHybridQueryConfig(num_results=5),
      )
      embedding = [0.1] * 384
      query = tool._build_query("test query", embedding)
      assert isinstance(query, HybridQuery)

  def test_build_query_aggregate(self, mock_index, mock_vectorizer):
    """Test query building with AggregateHybridQuery."""
    from redisvl.query import AggregateHybridQuery

    with patch(
        "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid",
        return_value=False,
    ):
      tool = RedisHybridSearchTool(
          index=mock_index,
          vectorizer=mock_vectorizer,
          config=RedisAggregatedHybridQueryConfig(num_results=5),
      )
      embedding = [0.1] * 384
      query = tool._build_query("test query", embedding)
      assert isinstance(query, AggregateHybridQuery)


class TestRedisHybridSearchToolDeclaration:
  """Tests for _get_declaration method."""

  def test_get_declaration(self, mock_index, mock_vectorizer):
    """Test function declaration generation."""
    with patch(
        "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid",
        return_value=True,
    ):
      tool = RedisHybridSearchTool(
          index=mock_index,
          vectorizer=mock_vectorizer,
          config=RedisHybridQueryConfig(num_results=5),
      )
      declaration = tool._get_declaration()

      assert declaration.name == "redis_hybrid_search"
      assert "query" in declaration.parameters.properties
      assert "num_results" in declaration.parameters.properties
      assert "query" in declaration.parameters.required
diff --git a/tests/unittests/tools/redis/test_range_search_tool.py b/tests/unittests/tools/redis/test_range_search_tool.py
new file mode 100644
index 0000000..a70f683
--- /dev/null
+++ b/tests/unittests/tools/redis/test_range_search_tool.py
@@ -0,0 +1,166 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RedisRangeSearchTool."""

from unittest.mock import AsyncMock
from unittest.mock import MagicMock

import pytest

# Skip all tests if redisvl is not installed
pytest.importorskip("redisvl")

from redisvl.index import SearchIndex
from redisvl.query import VectorRangeQuery
from redisvl.utils.vectorize import BaseVectorizer

from google.adk_community.tools.redis import RedisRangeQueryConfig
from google.adk_community.tools.redis import RedisRangeSearchTool


@pytest.fixture
def mock_vectorizer():
  """Mock RedisVL vectorizer."""
  # spec=BaseVectorizer restricts the mock to the real vectorizer interface.
  vectorizer = MagicMock(spec=BaseVectorizer)
  vectorizer.embed = MagicMock(return_value=[0.1] * 384)
  vectorizer.aembed = AsyncMock(return_value=[0.1] * 384)
  return vectorizer


@pytest.fixture
def mock_index():
  """Mock RedisVL SearchIndex."""
  index = MagicMock(spec=SearchIndex)
  index.query = MagicMock(
      return_value=[{
          "title": "Test Doc",
          "content": "Test content",
          "vector_distance": 0.1,
      }]
  )
  return index


@pytest.fixture
def range_search_tool(mock_index, mock_vectorizer):
  """Create RedisRangeSearchTool instance for testing."""
  config = RedisRangeQueryConfig(
      distance_threshold=0.3,
      num_results=5,
  )
  return RedisRangeSearchTool(
      index=mock_index,
      vectorizer=mock_vectorizer,
      config=config,
      return_fields=["title", "content"],
  )


class TestRedisRangeSearchToolInit:
  """Tests for RedisRangeSearchTool initialization."""

  def test_default_parameters(self, mock_index, mock_vectorizer):
    """Test default parameter values with default config."""
    tool = RedisRangeSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
    )
    # Config defaults
    assert tool._config.vector_field_name == "embedding"
    assert tool._config.distance_threshold == 0.2
    assert tool._config.num_results == 10
    assert tool._config.dtype == "float32"
    assert tool._config.return_score is True
    assert tool._config.dialect == 2
    assert tool._config.in_order is False
    assert tool._config.normalize_vector_distance is False
    assert tool._config.sort_by is None
    assert tool._config.epsilon is None
    # Tool-level defaults
    assert tool._filter_expression is None

  def test_custom_parameters_via_config(self, mock_index, mock_vectorizer):
    """Test custom parameter values via config object."""
    config = RedisRangeQueryConfig(
        vector_field_name="vec",
        distance_threshold=0.5,
        num_results=20,
        dtype="float64",
        return_score=False,
        dialect=3,
        in_order=True,
        normalize_vector_distance=True,
        epsilon=0.01,
    )
    tool = RedisRangeSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
        config=config,
        return_fields=["title", "url"],
    )
    # Every non-default config value should be stored verbatim on the tool.
    assert tool._config.vector_field_name == "vec"
    assert tool._config.distance_threshold == 0.5
    assert tool._config.num_results == 20
    assert tool._return_fields == ["title", "url"]
    assert tool._config.dtype == "float64"
    assert tool._config.return_score is False
    assert tool._config.dialect == 3
    assert tool._config.in_order is True
    assert tool._config.normalize_vector_distance is True
    assert tool._config.epsilon == 0.01

  def test_custom_name_and_description(self, mock_index, mock_vectorizer):
    """Test custom tool name and description."""
    tool = RedisRangeSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
        name="custom_range",
        description="Custom range search",
    )
    assert tool.name == "custom_range"
    assert tool.description == "Custom range search"


class TestRedisRangeSearchToolBuildQuery:
  """Tests for _build_query method."""

  def test_build_query_basic(self, range_search_tool):
    """Test basic query building."""
    embedding = [0.1] * 384
    query = range_search_tool._build_query("test query", embedding)

    assert isinstance(query, VectorRangeQuery)

  def test_build_query_with_threshold_override(self, range_search_tool):
    """Test query building with distance_threshold override."""
    embedding = [0.1] * 384
    query = range_search_tool._build_query(
        "test query", embedding, distance_threshold=0.8
    )

    # Per-call override must win over the fixture config's 0.3 threshold.
    # (VectorRangeQuery exposes the threshold as a private attribute.)
    assert query._distance_threshold == 0.8


class TestRedisRangeSearchToolDeclaration:
  """Tests for _get_declaration method."""

  def test_get_declaration(self, range_search_tool):
    """Test function declaration generation."""
    declaration = range_search_tool._get_declaration()

    assert declaration.name == "redis_range_search"
    assert "query" in declaration.parameters.properties
    assert "distance_threshold" in declaration.parameters.properties
    assert "query" in declaration.parameters.required
diff --git a/tests/unittests/tools/redis/test_text_search_tool.py b/tests/unittests/tools/redis/test_text_search_tool.py
new file mode 100644
index 0000000..1e44c12
--- /dev/null
+++ b/tests/unittests/tools/redis/test_text_search_tool.py
@@ -0,0 +1,181 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RedisTextSearchTool."""

from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

# Skip all tests if redisvl is not installed
pytest.importorskip("redisvl")

from redisvl.index import SearchIndex

from google.adk_community.tools.redis import RedisTextQueryConfig
from google.adk_community.tools.redis import RedisTextSearchTool


@pytest.fixture
def mock_index():
  """Mock RedisVL SearchIndex."""
  index = MagicMock(spec=SearchIndex)
  index.query = MagicMock(
      return_value=[
          {"title": "Test Doc", "content": "Test content", "score": 0.9}
      ]
  )
  return index


@pytest.fixture
def text_search_tool(mock_index):
  """Create RedisTextSearchTool instance for testing."""
  config = RedisTextQueryConfig(
      text_field_name="content",
      num_results=5,
      stopwords=None,  # Avoid nltk dependency in tests
  )
  return RedisTextSearchTool(
      index=mock_index,
      config=config,
      return_fields=["title", "content"],
  )


class TestRedisTextSearchToolInit:
  """Tests for RedisTextSearchTool initialization."""

  def test_default_parameters(self, mock_index):
    """Test default parameter values via config."""
    tool = RedisTextSearchTool(index=mock_index)
    assert tool._config.text_field_name == "content"
    assert tool._config.text_scorer == "BM25STD"
    assert tool._config.num_results == 10
    assert tool._config.return_score is True
    assert tool._config.dialect == 2
    assert tool._config.in_order is False
    assert tool._config.stopwords == "english"
    assert tool._filter_expression is None
    assert tool._config.sort_by is None
    assert tool._return_fields is None

  def test_custom_parameters_via_config(self, mock_index):
    """Test custom parameter values via config object."""
    config = RedisTextQueryConfig(
        text_field_name="description",
        text_scorer="TFIDF",
        num_results=20,
        return_score=False,
        dialect=3,
        in_order=True,
        stopwords={"the", "a", "an"},
    )
    tool = RedisTextSearchTool(
        index=mock_index,
        config=config,
        return_fields=["title", "url"],
    )
    assert tool._config.text_field_name == "description"
    assert tool._config.text_scorer == "TFIDF"
    assert tool._config.num_results == 20
    assert tool._return_fields == ["title", "url"]
    assert tool._config.return_score is False
    assert tool._config.dialect == 3
    assert tool._config.in_order is True
    assert tool._config.stopwords == {"the", "a", "an"}

  def test_custom_name_and_description(self, mock_index):
    """Test custom tool name and description."""
    tool = RedisTextSearchTool(
        index=mock_index,
        name="custom_text",
        description="Custom text search",
    )
    assert tool.name == "custom_text"
    assert tool.description == "Custom text search"

  def test_no_vectorizer_required(self, mock_index):
    """Test that TextSearchTool does not require a vectorizer."""
    # This should work without any vectorizer
    tool = RedisTextSearchTool(index=mock_index)
    assert not hasattr(tool, "_vectorizer")


class TestRedisTextSearchToolDeclaration:
  """Tests for _get_declaration method."""

  def test_get_declaration(self, text_search_tool):
    """Test function declaration generation."""
    declaration = text_search_tool._get_declaration()

    assert declaration.name == "redis_text_search"
    assert "query" in declaration.parameters.properties
    assert "num_results" in declaration.parameters.properties
    assert "query" in declaration.parameters.required


class TestRedisTextSearchToolRunAsync:
  """Tests for run_async method."""

  @pytest.mark.asyncio
  @patch("google.adk_community.tools.redis.text_search_tool.asyncio.to_thread")
  async def test_run_async_success(
      self, mock_to_thread, text_search_tool, mock_index
  ):
    """Test successful search execution."""
    # The tool offloads the blocking Redis query via asyncio.to_thread;
    # patching it keeps the test loop-free and deterministic.
    mock_to_thread.return_value = [
        {"title": "Test Doc", "content": "Test content", "score": 0.9}
    ]
    mock_context = MagicMock()
    result = await text_search_tool.run_async(
        args={"query": "test query"},
        tool_context=mock_context,
    )

    assert result["status"] == "success"
    assert result["count"] == 1
    assert len(result["results"]) == 1
    mock_to_thread.assert_called_once()

  @pytest.mark.asyncio
  async def test_run_async_empty_query(self, text_search_tool):
    """Test error handling for empty query."""
    mock_context = MagicMock()
    result = await text_search_tool.run_async(
        args={"query": ""},
        tool_context=mock_context,
    )

    assert result["status"] == "error"
    assert "required" in result["error"].lower()

  @pytest.mark.asyncio
  @patch("google.adk_community.tools.redis.text_search_tool.asyncio.to_thread")
  async def test_run_async_with_num_results(
      self, mock_to_thread, text_search_tool, mock_index
  ):
    """Test search with custom num_results."""
    mock_to_thread.return_value = [
        {"title": "Test Doc", "content": "Test content", "score": 0.9}
    ]
    mock_context = MagicMock()
    await text_search_tool.run_async(
        args={"query": "test", "num_results": 15},
        tool_context=mock_context,
    )

    # Verify to_thread was called
    # NOTE(review): this only checks the call happened; it does not verify
    # that num_results=15 actually reached the built query.
    mock_to_thread.assert_called_once()
diff --git a/tests/unittests/tools/redis/test_vector_search_tool.py b/tests/unittests/tools/redis/test_vector_search_tool.py
new file mode 100644
index 0000000..a71e9fd
--- /dev/null
+++ b/tests/unittests/tools/redis/test_vector_search_tool.py
@@ -0,0 +1,289 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RedisVectorSearchTool."""

from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

# Skip all tests if redisvl is not installed
pytest.importorskip("redisvl")

from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.utils.vectorize import BaseVectorizer

from google.adk_community.tools.redis import RedisVectorQueryConfig
from google.adk_community.tools.redis import RedisVectorSearchTool


@pytest.fixture
def mock_vectorizer():
  """Mock RedisVL vectorizer."""
  vectorizer = MagicMock(spec=BaseVectorizer)
  vectorizer.embed = MagicMock(return_value=[0.1] * 384)
  vectorizer.aembed = AsyncMock(return_value=[0.1] * 384)
  return vectorizer


@pytest.fixture
def mock_index():
  """Mock RedisVL SearchIndex."""
  index = MagicMock(spec=SearchIndex)
  # NOTE(review): this fixture stubs `index.search`, whereas the range-search
  # test fixture stubs `index.query`. No test in this file actually executes a
  # search, so the mismatch is currently harmless — confirm which method the
  # tool really calls before adding execution tests here.
  index.search = MagicMock(
      return_value=[{
          "title": "Test Doc",
          "content": "Test content",
          "vector_distance": 0.1,
      }]
  )
  return index


@pytest.fixture
def vector_search_tool(mock_index, mock_vectorizer):
  """Create RedisVectorSearchTool instance for testing."""
  config = RedisVectorQueryConfig(num_results=5)
  return RedisVectorSearchTool(
      index=mock_index,
      vectorizer=mock_vectorizer,
      config=config,
      return_fields=["title", "content"],
  )


class TestRedisVectorSearchToolInit:
  """Tests for RedisVectorSearchTool initialization."""

  def test_default_parameters(self, mock_index, mock_vectorizer):
    """Test default parameter values with default config."""
    tool = RedisVectorSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
    )
    # Config defaults
    assert tool._config.vector_field_name == "embedding"
    assert tool._config.num_results == 10
    assert tool._config.dtype == "float32"
    assert tool._config.return_score is True
    assert tool._config.dialect == 2
    assert tool._config.in_order is False
    assert tool._config.normalize_vector_distance is False
    assert tool._config.sort_by is None
    assert tool._config.hybrid_policy is None
    assert tool._config.batch_size is None
    assert tool._config.ef_runtime is None
    assert tool._config.epsilon is None
    assert tool._config.search_window_size is None
    assert tool._config.use_search_history is None
    assert tool._config.search_buffer_capacity is None
    # Tool-level defaults
    assert tool._filter_expression is None

  def test_custom_parameters_via_config(self, mock_index, mock_vectorizer):
    """Test custom parameter values via config object."""
    config = RedisVectorQueryConfig(
        vector_field_name="custom_embedding",
        num_results=20,
        dtype="float64",
        return_score=False,
        dialect=3,
        in_order=True,
        normalize_vector_distance=True,
        hybrid_policy="BATCHES",
        batch_size=100,
        ef_runtime=200,
        epsilon=0.01,
        search_window_size=50,
        use_search_history="ON",
        search_buffer_capacity=1000,
    )
    tool = RedisVectorSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
        config=config,
        return_fields=["title", "content", "url"],
    )
    # Every non-default value should be stored verbatim on the tool config.
    assert tool._config.vector_field_name == "custom_embedding"
    assert tool._config.num_results == 20
    assert tool._return_fields == ["title", "content", "url"]
    assert tool._config.dtype == "float64"
    assert tool._config.return_score is False
    assert tool._config.dialect == 3
    assert tool._config.in_order is True
    assert tool._config.normalize_vector_distance is True
    assert tool._config.hybrid_policy == "BATCHES"
    assert tool._config.batch_size == 100
    assert tool._config.ef_runtime == 200
    assert tool._config.epsilon == 0.01
    assert tool._config.search_window_size == 50
    assert tool._config.use_search_history == "ON"
    assert tool._config.search_buffer_capacity == 1000

  def test_custom_name_and_description(self, mock_index, mock_vectorizer):
    """Test custom tool name and description."""
    tool = RedisVectorSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
        name="custom_search",
        description="Custom search description",
    )
    assert tool.name == "custom_search"
    assert tool.description == "Custom search description"


class TestRedisVectorSearchToolBuildQuery:
  """Tests for _build_query method."""

  def test_build_query_basic(self, vector_search_tool):
    """Test basic query building."""
    embedding = [0.1] * 384
    query = vector_search_tool._build_query("test query", embedding)

    assert isinstance(query, VectorQuery)
    # VectorQuery uses private attributes
    assert query._vector == embedding
    assert query._vector_field_name == "embedding"
    assert query._num_results == 5

  def test_build_query_with_num_results_override(self, vector_search_tool):
    """Test query building with num_results override."""
    embedding = [0.1] * 384
    query = vector_search_tool._build_query(
        "test query", embedding, num_results=15
    )

    # Per-call override must win over the fixture config's num_results=5.
    assert query._num_results == 15

  def test_build_query_passes_core_parameters(
      self, mock_index, mock_vectorizer
  ):
    """Test that core VectorQuery parameters are passed through."""
    # Use a string filter expression (valid type for VectorQuery)
    filter_str = "@category:{redis}"

    config = RedisVectorQueryConfig(
        vector_field_name="vec",
        num_results=10,
        dtype="float16",
        return_score=False,
        dialect=3,
        in_order=True,
        normalize_vector_distance=True,
    )
    tool = RedisVectorSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
        config=config,
        return_fields=["title"],
        filter_expression=filter_str,
    )

    embedding = [0.1] * 384
    query = tool._build_query("test", embedding)

    # Check core attributes that are exposed on VectorQuery
    assert query._vector_field_name == "vec"
    assert query._num_results == 10
    assert query._return_fields == ["title"]
    assert query._dtype == "float16"
    assert query._dialect == 3
    assert query._in_order is True

  def test_stores_optional_parameters_in_config(
      self, mock_index, mock_vectorizer
  ):
    """Test that optional parameters are stored correctly in config."""
    config = RedisVectorQueryConfig(
        hybrid_policy="ADHOC_BF",
        batch_size=50,
        ef_runtime=100,
        epsilon=0.05,
        search_window_size=25,
        use_search_history="AUTO",
        search_buffer_capacity=500,
    )
    tool = RedisVectorSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
        config=config,
    )

    # Verify parameters are stored in the config
    assert tool._config.hybrid_policy == "ADHOC_BF"
    assert tool._config.batch_size == 50
    assert tool._config.ef_runtime == 100
    assert tool._config.epsilon == 0.05
    assert tool._config.search_window_size == 25
    assert tool._config.use_search_history == "AUTO"
    assert tool._config.search_buffer_capacity == 500


class TestRedisVectorSearchToolDeclaration:
  """Tests for _get_declaration method."""

  def test_get_declaration(self, vector_search_tool):
    """Test function declaration generation."""
    declaration = vector_search_tool._get_declaration()

    assert declaration.name == "redis_vector_search"
    assert "query" in declaration.parameters.properties
    assert "num_results" in declaration.parameters.properties
    assert "query" in declaration.parameters.required


class TestRedisVectorSearchToolImportError:
  """Tests for import error handling."""

  def test_import_error_message(self):
    """Test that helpful error message is shown when redisvl not installed."""
    # NOTE(review): this test body is a no-op and asserts nothing — it always
    # passes. Either exercise the import-error path for real or remove it.
    with patch.dict("sys.modules", {"redisvl": None}):
      # This test verifies the error message format
      # The actual import error is tested by the try/except in __init__.py
      pass


class TestRedisVectorSearchToolFilterExpression:
  """Tests for filter expression handling."""

  def test_filter_expression_stored_on_tool(self, mock_index, mock_vectorizer):
    """Test that filter expression is stored on the tool."""
    filter_str = "@category:{redis}"

    tool = RedisVectorSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
        filter_expression=filter_str,
    )

    assert tool._filter_expression == filter_str

  def test_filter_expression_passed_to_query(self, mock_index, mock_vectorizer):
    """Test that filter expression is correctly passed to VectorQuery."""
    filter_str = "@category:{redis}"

    tool = RedisVectorSearchTool(
        index=mock_index,
        vectorizer=mock_vectorizer,
        filter_expression=filter_str,
    )

    embedding = [0.1] * 384
    query = tool._build_query("test", embedding)

    # VectorQuery wraps string filters in FilterExpression
    assert str(query._filter_expression) == filter_str