From 8b6ad10450da11c6d671fcae3bfa351e90a8df60 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Mon, 22 Dec 2025 15:22:26 -0500 Subject: [PATCH 01/16] feat: add RedisVL search tools - Add RedisVectorSearchTool for KNN semantic search - Add RedisTextSearchTool for BM25 keyword search - Add RedisRangeSearchTool for threshold-based search - Add RedisHybridSearchTool for combined vector + text search - Add redis-vl optional dependency (redisvl, nltk, sentence-transformers) - Add sample agent demonstrating all search tools - Add unit tests for all tools --- .../samples/redis_vl_search/README.md | 205 +++++++++++++ .../samples/redis_vl_search/load_data.py | 259 +++++++++++++++++ .../redis_vl_search_agent/__init__.py | 20 ++ .../redis_vl_search_agent/agent.py | 154 ++++++++++ .../samples/redis_vl_search/schema.yaml | 50 ++++ pyproject.toml | 5 + src/google/adk_community/tools/__init__.py | 62 ++++ .../adk_community/tools/redis/__init__.py | 59 ++++ .../tools/redis/base_search_tool.py | 155 ++++++++++ .../tools/redis/hybrid_search_tool.py | 182 ++++++++++++ .../tools/redis/range_search_tool.py | 185 ++++++++++++ .../tools/redis/text_search_tool.py | 187 ++++++++++++ .../tools/redis/vector_search_tool.py | 213 ++++++++++++++ tests/unittests/tools/__init__.py | 0 tests/unittests/tools/redis/__init__.py | 0 .../tools/redis/test_hybrid_search_tool.py | 176 +++++++++++ .../tools/redis/test_range_search_tool.py | 157 ++++++++++ .../tools/redis/test_text_search_tool.py | 174 +++++++++++ .../tools/redis/test_vector_search_tool.py | 274 ++++++++++++++++++ 19 files changed, 2517 insertions(+) create mode 100644 contributing/samples/redis_vl_search/README.md create mode 100644 contributing/samples/redis_vl_search/load_data.py create mode 100644 contributing/samples/redis_vl_search/redis_vl_search_agent/__init__.py create mode 100644 contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py create mode 100644 contributing/samples/redis_vl_search/schema.yaml create mode 
100644 src/google/adk_community/tools/__init__.py create mode 100644 src/google/adk_community/tools/redis/__init__.py create mode 100644 src/google/adk_community/tools/redis/base_search_tool.py create mode 100644 src/google/adk_community/tools/redis/hybrid_search_tool.py create mode 100644 src/google/adk_community/tools/redis/range_search_tool.py create mode 100644 src/google/adk_community/tools/redis/text_search_tool.py create mode 100644 src/google/adk_community/tools/redis/vector_search_tool.py create mode 100644 tests/unittests/tools/__init__.py create mode 100644 tests/unittests/tools/redis/__init__.py create mode 100644 tests/unittests/tools/redis/test_hybrid_search_tool.py create mode 100644 tests/unittests/tools/redis/test_range_search_tool.py create mode 100644 tests/unittests/tools/redis/test_text_search_tool.py create mode 100644 tests/unittests/tools/redis/test_vector_search_tool.py diff --git a/contributing/samples/redis_vl_search/README.md b/contributing/samples/redis_vl_search/README.md new file mode 100644 index 0000000..4b44178 --- /dev/null +++ b/contributing/samples/redis_vl_search/README.md @@ -0,0 +1,205 @@ +# RedisVL Search Agent + +This sample demonstrates using Redis search tools to give an ADK agent +access to a Redis-based knowledge base with multiple search capabilities. + +## What This Sample Shows + +- Setting up a Redis vector index with a schema +- Using 3 Redis search tools in one agent (4th requires Redis 8.4+): + - **RedisVectorSearchTool**: Semantic similarity search (KNN) - finds conceptually similar content + - **RedisTextSearchTool**: Full-text keyword search (BM25) - matches exact terms and phrases + - **RedisRangeSearchTool**: Distance threshold search - returns ALL docs within a relevance radius + - **RedisHybridSearchTool**: Combined vector + text search (requires Redis 8.4+ and redis-py 7.1+) +- Integrating RedisVL with an ADK agent + +## Prerequisites + +1. 
**Redis Stack** running locally (or Redis Cloud with Search capability) + ```bash + # Using Docker + docker run -d --name redis-stack -p 6379:6379 redis/redis-stack:latest + ``` + +2. **No API keys needed for embeddings** - uses Redis' open-source `redis/langcache-embed-v2` model (768 dimensions) + +## Setup + +1. Install dependencies: + ```bash + pip install "google-adk-community[redis-vl]" + ``` + +2. Download NLTK stopwords (required for keyword search): + ```bash + python -c "import nltk; nltk.download('stopwords')" + ``` + +3. Set environment variables (or create a `.env` file): + ```bash + export REDIS_URL=redis://localhost:6379 + export GOOGLE_API_KEY=your-google-api-key # For Gemini LLM + ``` + +4. Load sample data into Redis: + ```bash + cd contributing/samples/redis_vl_search + python load_data.py + ``` + +5. Run the agent: + ```bash + cd contributing/samples/redis_vl_search + adk web + ``` + +## Files + +| File | Description | +|------|-------------| +| `schema.yaml` | Redis index schema defining document structure | +| `load_data.py` | Script to populate Redis with sample documents | +| `redis_vl_search_agent/agent.py` | Agent definition with all Redis search tools | + +## How It Works + +1. **Schema Definition** (`schema.yaml`): Defines the index structure with fields + for title, content, URL, category, and a vector embedding field. + +2. **Data Loading** (`load_data.py`): Populates Redis with sample documents about + Redis and ADK, embedding the content using Redis' langcache-embed-v2 model. + +3. **Agent** (`redis_vl_search_agent/agent.py`): Creates an agent with access to + multiple search tools for different use cases. + +## Search Tools + +### semantic_search (RedisVectorSearchTool) +**Best for:** Conceptual questions, natural language queries, finding similar content. + +**How it works:** Converts query to vector embedding, finds K nearest neighbors by cosine similarity. + +**Returns:** Top-K most similar documents (default: 5). 
+ +**Example queries:** +- "What is Redis?" → finds docs about Redis even if they don't say "What is Redis" +- "How do I build a chatbot?" → finds docs about "intelligent assistants", "conversational AI" +- "Fast database for caching" → finds Redis docs even without exact keyword match + +### keyword_search (RedisTextSearchTool) +**Best for:** Exact terms, acronyms, technical jargon, API names, error messages. + +**How it works:** BM25 text scoring algorithm - matches exact tokens, weighs by term frequency. + +**Returns:** Top-K documents ranked by keyword relevance. + +**Example queries:** +- "HNSW algorithm" → exact match on "HNSW" acronym +- "BM25 formula" → finds docs containing "BM25" +- "VectorQuery class" → API/class name lookup +- "RRF ranking" → technical term that needs exact match + +### range_search (RedisRangeSearchTool) +**Best for:** Exhaustive retrieval, comprehensive coverage, finding ALL related documents. + +**How it works:** Returns ALL documents within a distance threshold (not just top-K). + +**Returns:** Variable number - every document above the relevance bar. + +**Use when:** +- User wants "everything" about a topic +- Comprehensive research needed +- Quality filtering (only highly relevant docs) +- Clustering/grouping similar content + +**Example queries:** +- "Tell me everything about RAG pipelines" → returns all RAG-related docs +- "All Redis data structures" → comprehensive list +- "Complete guide to embeddings" → exhaustive retrieval + +### hybrid_search (RedisHybridSearchTool) +**Best for:** Queries that benefit from both semantic understanding AND exact keyword matching. + +**How it works:** Combines vector similarity + BM25 text scores using RRF or linear weighting. + +**Requires:** Redis 8.4+ and redis-py 7.1+ + +## Example Queries + +Once running, try asking the agent: + +| Query | Expected Tool | Why | +|-------|---------------|-----| +| "What is Redis?" 
| semantic_search | Conceptual question | +| "HNSW algorithm details" | keyword_search | Technical acronym | +| "Tell me everything about RAG" | range_search | Exhaustive retrieval | +| "How do I build a chatbot?" | semantic_search | Natural language | +| "BM25 formula" | keyword_search | Exact term lookup | +| "All vector search methods" | range_search | Comprehensive coverage | + +## Customization + +### Using a Different Vectorizer + +```python +from redisvl.utils.vectorize import HFTextVectorizer + +vectorizer = HFTextVectorizer(model="sentence-transformers/all-MiniLM-L6-v2") +``` + +Note: Update `dims` in `schema.yaml` to match your model's embedding dimensions (384 for `all-MiniLM-L6-v2`). + +### Adding Filters + +You can add filter expressions to narrow search results: + +```python +from redisvl.query.filter import Tag + +redis_search = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + num_results=5, + return_fields=["title", "content", "url", "category"], + filter_expression=Tag("category") == "redis", # Only search Redis docs +) +``` + +See [RedisVL Filter documentation](https://docs.redisvl.com/api/filter.html) for more filter options. 
+ +### Advanced Query Options + +`RedisVectorSearchTool` exposes all VectorQuery parameters: + +```python +redis_search = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + num_results=10, + return_fields=["title", "content"], + # Query tuning + dtype="float32", # Vector dtype + return_score=True, # Include similarity score + normalize_vector_distance=True, # Convert to 0-1 similarity + # Hybrid filtering + filter_expression=Tag("category") == "redis", + hybrid_policy="BATCHES", # or "ADHOC_BF" + batch_size=100, # For BATCHES policy + # HNSW tuning + ef_runtime=150, # Higher = better recall, slower + epsilon=0.01, # Range search approximation + # SVS-VAMANA tuning + search_window_size=20, # Search window size + use_search_history="AUTO", # "OFF", "ON", or "AUTO" + search_buffer_capacity=30, # 2-level compression tuning +) +``` + +See [RedisVL Query documentation](https://docs.redisvl.com/api/query.html) for details. + +### Connecting to Redis Cloud + +```bash +export REDIS_URL=redis://default:password@your-redis-cloud-host:port +``` + diff --git a/contributing/samples/redis_vl_search/load_data.py b/contributing/samples/redis_vl_search/load_data.py new file mode 100644 index 0000000..c175413 --- /dev/null +++ b/contributing/samples/redis_vl_search/load_data.py @@ -0,0 +1,259 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Load sample data into Redis for the redis_search_agent demo. 
+ +Data is designed to demonstrate when different search tools excel: +- Semantic search: conceptual queries, synonyms, paraphrasing +- Keyword search: technical terms, acronyms, exact phrases +- Range search: finding all highly relevant docs above threshold +- Hybrid search: combining both approaches +""" + +import os +from pathlib import Path + +from redisvl.index import SearchIndex +from redisvl.utils.vectorize import HFTextVectorizer + +# Documents designed to showcase different search tool strengths +SAMPLE_DOCS = [ + # === SEMANTIC SEARCH DEMOS === + # These use varied vocabulary - semantic search finds them even with different words + { + "title": "Introduction to Redis", + "content": ( + "Redis is a lightning-fast in-memory data store. It excels at caching, " + "session management, and real-time analytics. Think of it as a Swiss Army " + "knife for data - versatile, quick, and reliable." + ), + "url": "https://redis.io/docs/about/", + "category": "redis", + "doc_type": "reference", + "difficulty": "beginner", + }, + { + "title": "Understanding Vector Databases", + "content": ( + "Vector databases store numerical representations of data called embeddings. " + "These embeddings capture semantic meaning, enabling similarity search. " + "Applications include recommendation engines, image search, and chatbots." + ), + "url": "https://redis.io/docs/vectors/", + "category": "concepts", + "doc_type": "reference", + "difficulty": "intermediate", + }, + { + "title": "Building Intelligent Assistants", + "content": ( + "Modern AI assistants combine language models with external knowledge. " + "They can search databases, call APIs, and maintain conversation context. " + "The key is giving them the right tools for each task." 
+ ), + "url": "https://google.github.io/adk-docs/agents/", + "category": "adk", + "doc_type": "tutorial", + "difficulty": "intermediate", + }, + # === KEYWORD SEARCH DEMOS === + # These contain specific technical terms that keyword search finds precisely + { + "title": "HNSW Algorithm Deep Dive", + "content": ( + "HNSW (Hierarchical Navigable Small World) is the algorithm Redis uses for " + "approximate nearest neighbor search. It builds a multi-layer graph where " + "each layer has exponentially fewer nodes. Search starts at the top layer " + "and navigates down. Parameters: M (connections per node), EF (search width)." + ), + "url": "https://redis.io/docs/hnsw/", + "category": "redis", + "doc_type": "reference", + "difficulty": "advanced", + }, + { + "title": "BM25 Scoring Explained", + "content": ( + "BM25 (Best Matching 25) is a ranking function for full-text search. " + "It improves on TF-IDF by adding document length normalization and term " + "frequency saturation. Redis supports BM25STD and BM25 scorers. " + "Formula: score = IDF * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * dl/avgdl))" + ), + "url": "https://redis.io/docs/bm25/", + "category": "redis", + "doc_type": "reference", + "difficulty": "advanced", + }, + { + "title": "RRF: Reciprocal Rank Fusion", + "content": ( + "RRF (Reciprocal Rank Fusion) combines results from multiple search methods. " + "Score = sum(1 / (k + rank)) across all methods. Default k=60. " + "Used in hybrid search to merge vector and keyword results. " + "Alternative: LINEAR combination with alpha weighting." + ), + "url": "https://redis.io/docs/hybrid/", + "category": "redis", + "doc_type": "reference", + "difficulty": "advanced", + }, + { + "title": "RedisVL API Reference", + "content": ( + "Key classes: SearchIndex, AsyncSearchIndex, VectorQuery, TextQuery, " + "HybridQuery, VectorRangeQuery. Vectorizers: HFTextVectorizer (recommended, " + "uses redis/langcache-embed-v2 model, 768 dimensions). 
Methods: index.query(), " + "index.search(), index.load(), index.create(overwrite=True)." + ), + "url": "https://redis.io/docs/redisvl/api/", + "category": "redis", + "doc_type": "api", + "difficulty": "intermediate", + }, + # === RANGE SEARCH DEMOS === + # Highly related docs - range search finds all above threshold + { + "title": "RAG Architecture Overview", + "content": ( + "Retrieval-Augmented Generation (RAG) enhances LLMs with external knowledge. " + "Step 1: Embed the user query. Step 2: Search vector database for relevant " + "documents. Step 3: Include retrieved context in LLM prompt. " + "Step 4: Generate grounded response." + ), + "url": "https://redis.io/solutions/rag/", + "category": "concepts", + "doc_type": "tutorial", + "difficulty": "intermediate", + }, + { + "title": "RAG Best Practices", + "content": ( + "Tips for effective RAG: chunk documents appropriately (512-1024 tokens), " + "use hybrid search for better recall, rerank results before prompting, " + "include metadata for filtering, monitor retrieval quality metrics." + ), + "url": "https://redis.io/solutions/rag/best-practices/", + "category": "concepts", + "doc_type": "tutorial", + "difficulty": "intermediate", + }, + { + "title": "RAG vs Fine-tuning", + "content": ( + "RAG retrieves knowledge at query time; fine-tuning bakes it into weights. " + "RAG pros: up-to-date info, traceable sources, no training cost. " + "Fine-tuning pros: faster inference, specialized behavior. " + "Often combined: fine-tune for style, RAG for facts." + ), + "url": "https://redis.io/solutions/rag/comparison/", + "category": "concepts", + "doc_type": "reference", + "difficulty": "advanced", + }, + # === ADK SPECIFIC === + { + "title": "ADK Tool Development", + "content": ( + "Create custom tools by subclassing BaseTool. Implement run_async() for " + "the tool logic and _get_declaration() for the schema. Tools receive " + "args dict and ToolContext. Return results as dict or string." 
+ ), + "url": "https://google.github.io/adk-docs/tools/custom/", + "category": "adk", + "doc_type": "tutorial", + "difficulty": "intermediate", + }, + { + "title": "ADK Agent Orchestration", + "content": ( + "Agents can delegate to sub-agents using transfer_to_agent(). " + "Define agent hierarchy with parent/child relationships. " + "Use instruction prompts to guide when delegation occurs. " + "Sub-agents inherit tools from parent unless overridden." + ), + "url": "https://google.github.io/adk-docs/agents/orchestration/", + "category": "adk", + "doc_type": "tutorial", + "difficulty": "advanced", + }, + # === FAQ STYLE (good for both search types) === + { + "title": "FAQ: Redis Connection Issues", + "content": ( + "Q: Connection refused error? A: Check Redis is running on the correct " + "port (default 6379). Q: Authentication failed? A: Set REDIS_PASSWORD " + "env var or pass password to connect(). Q: Timeout errors? A: Increase " + "socket_timeout parameter or check network latency." + ), + "url": "https://redis.io/docs/faq/connection/", + "category": "redis", + "doc_type": "faq", + "difficulty": "beginner", + }, + { + "title": "FAQ: Embedding Dimensions Mismatch", + "content": ( + "Q: Dimension mismatch error? A: Ensure query embeddings match index " + "dimensions. Common dimensions: OpenAI ada-002 (1536), langcache-embed-v2 " + "(768), sentence-transformers (384-768). Check schema.yaml dims field." 
+ ), + "url": "https://redis.io/docs/faq/vectors/", + "category": "redis", + "doc_type": "faq", + "difficulty": "beginner", + }, +] + + +def load_data() -> None: + """Load sample documents into Redis with embeddings.""" + schema_path = Path(__file__).parent / "schema.yaml" + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") + + print(f"Connecting to Redis at {redis_url}") + index = SearchIndex.from_yaml(str(schema_path)) + index.connect(redis_url) + + print("Creating index (will overwrite if exists)...") + index.create(overwrite=True) + + print("Generating embeddings...") + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + + docs_with_embeddings = [] + for doc in SAMPLE_DOCS: + embedding = vectorizer.embed(doc["content"], as_buffer=True) + docs_with_embeddings.append({**doc, "embedding": embedding}) + print(f" [{doc['doc_type']:9}] {doc['title']}") + + print("\nLoading documents into Redis...") + index.load(docs_with_embeddings) + + print(f""" +Loaded {len(SAMPLE_DOCS)} documents. + +Try these queries to see different tools in action: + - "What is Redis?" → semantic_search (conceptual) + - "HNSW algorithm" → keyword_search (exact term) + - "Tell me everything about RAG" → range_search (exhaustive retrieval) + - "BM25 formula" → keyword_search (technical) + - "building chatbots" → semantic_search (synonym for assistants) + +Run: adk web +""") + + +if __name__ == "__main__": + load_data() + diff --git a/contributing/samples/redis_vl_search/redis_vl_search_agent/__init__.py b/contributing/samples/redis_vl_search/redis_vl_search_agent/__init__.py new file mode 100644 index 0000000..ca65aef --- /dev/null +++ b/contributing/samples/redis_vl_search/redis_vl_search_agent/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redis search agent sample.""" + +from .agent import root_agent + +__all__ = ["root_agent"] + diff --git a/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py b/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py new file mode 100644 index 0000000..dc508c2 --- /dev/null +++ b/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py @@ -0,0 +1,154 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Redis search agent demonstrating vector, text, and range search tools.""" + +import os +from pathlib import Path + +from dotenv import load_dotenv +from google.adk import Agent +from google.adk_community.tools.redis import ( + RedisRangeSearchTool, + RedisTextSearchTool, + RedisVectorSearchTool, +) +from redisvl.index import SearchIndex +from redisvl.utils.vectorize import HFTextVectorizer + +SCHEMA_PATH = Path(__file__).parent.parent / "schema.yaml" +RETURN_FIELDS = ["title", "content", "url", "category", "doc_type", "difficulty"] + +INSTRUCTION = """You are a helpful assistant with a technical knowledge base. + +You have 3 search tools, each optimized for different query types: + +## semantic_search (Vector KNN) +Best for: Conceptual questions, natural language queries, finding similar content. +How it works: Converts query to vector embedding, finds K nearest neighbors by cosine similarity. +Returns: Top-K most similar documents (default: 5). +Example queries: + - "What is Redis?" → finds docs about Redis even if they don't say "What is Redis" + - "How do I build a chatbot?" → finds docs about "intelligent assistants", "conversational AI" + - "Fast database for caching" → finds Redis docs even without exact keyword match + - "Machine learning embeddings" → finds semantically related vector/AI content + +## keyword_search (BM25 Full-Text) +Best for: Exact terms, acronyms, technical jargon, API names, error messages. +How it works: BM25 text scoring algorithm - matches exact tokens, weighs by term frequency. +Returns: Top-K documents ranked by keyword relevance. 
+Example queries: + - "HNSW algorithm" → exact match on "HNSW" acronym + - "BM25 formula" → finds docs containing "BM25" + - "VectorQuery class" → API/class name lookup + - "RRF ranking" → technical term that needs exact match + - "ConnectionError redis" → error message lookup + +## range_search (Vector Range/Threshold) +Best for: Exhaustive retrieval, comprehensive coverage, finding ALL related documents. +How it works: Returns ALL documents within a distance threshold (not just top-K). +Returns: Variable number - every document above the relevance bar. +Use when: + - User wants "everything" about a topic: "Tell me everything about RAG" + - Comprehensive research: "All vector search algorithms" + - Quality filtering: Only docs highly relevant to the query + - Clustering/grouping: Find all docs in a semantic neighborhood + - Deduplication: Find near-duplicate content +Example queries: + - "Tell me everything about RAG pipelines" → returns all RAG-related docs + - "All Redis data structures" → comprehensive list + - "Complete guide to embeddings" → exhaustive retrieval + - "Every caching strategy" → all related documents, not just top 5 + +## Strategy +1. Start with the most appropriate tool based on query type +2. If results are poor or incomplete, try another tool +3. For broad topics, consider range_search first, then refine with others +4. For technical terms/acronyms, prefer keyword_search +5. 
For natural questions, prefer semantic_search + +Always cite sources with title and difficulty level.""" + + +def get_index(schema_path: Path, redis_url: str) -> SearchIndex: + """Create and connect to Redis search index.""" + index = SearchIndex.from_yaml(str(schema_path)) + index.connect(redis_url) + return index + + +def get_search_tools(index: SearchIndex, vectorizer: HFTextVectorizer) -> list: + """Create search tools for the agent.""" + return [ + RedisVectorSearchTool( + name="semantic_search", + description="Semantic similarity search for conceptual queries.", + index=index, + vectorizer=vectorizer, + num_results=5, + return_fields=RETURN_FIELDS, + ), + RedisTextSearchTool( + name="keyword_search", + description="Keyword search for exact terms and phrases.", + index=index, + text_field_name="content", + num_results=5, + return_fields=RETURN_FIELDS, + text_scorer="BM25STD", + ), + RedisRangeSearchTool( + name="range_search", + description=( + "Vector range search - returns ALL documents within a semantic distance " + "threshold (not just top-K). Use for: exhaustive topic retrieval, " + "comprehensive coverage, quality-filtered results, finding all related " + "documents, clustering similar content, or when you need everything " + "above a relevance bar rather than a fixed number of results." 
+ ), + index=index, + vectorizer=vectorizer, + distance_threshold=0.5, + return_fields=RETURN_FIELDS, + ), + ] + + +def create_agent() -> Agent: + """Create the Redis search agent.""" + load_dotenv() + + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") + index = get_index(SCHEMA_PATH, redis_url) + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + tools = get_search_tools(index, vectorizer) + + return Agent( + model="gemini-2.0-flash-001", + name="root_agent", + instruction=INSTRUCTION, + tools=tools, + ) + + +# Required for `adk web` to discover the agent +root_agent = create_agent() + + +if __name__ == "__main__": + # For local testing + print(f"Agent '{root_agent.name}' created with {len(root_agent.tools)} tools") + for tool in root_agent.tools: + print(f" - {tool.name}") + diff --git a/contributing/samples/redis_vl_search/schema.yaml b/contributing/samples/redis_vl_search/schema.yaml new file mode 100644 index 0000000..4049345 --- /dev/null +++ b/contributing/samples/redis_vl_search/schema.yaml @@ -0,0 +1,50 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Redis index schema for the knowledge base sample. +# Supports vector, text, range, and hybrid search demonstrations. 
+ +version: "0.1.0" + +index: + name: adk_knowledge_base + prefix: doc + +fields: + - name: title + type: text + + - name: content + type: text + + - name: url + type: tag + + - name: category + type: tag # redis, adk, concepts, tutorials + + - name: doc_type + type: tag # reference, tutorial, faq, api + + - name: difficulty + type: tag # beginner, intermediate, advanced + + - name: embedding + type: vector + attrs: + algorithm: hnsw + dims: 768 + distance_metric: cosine + datatype: float32 + diff --git a/pyproject.toml b/pyproject.toml index 11afcd8..fb43714 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,11 @@ test = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", ] +redis-vl = [ + "redisvl>=0.4.0", + "nltk>=3.8.0", + "sentence-transformers>=2.2.0", +] [tool.pyink] diff --git a/src/google/adk_community/tools/__init__.py b/src/google/adk_community/tools/__init__.py new file mode 100644 index 0000000..f737b7a --- /dev/null +++ b/src/google/adk_community/tools/__init__.py @@ -0,0 +1,62 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools module for ADK Community integrations. + +This module uses lazy loading to provide helpful error messages when +optional dependencies are not installed. 
+""" + +__all__ = [ + "BaseRedisSearchTool", + "RedisVectorSearchTool", + "RedisHybridSearchTool", + "RedisRangeSearchTool", + "RedisTextSearchTool", +] + +# Redis tool names for lazy loading +_REDIS_TOOLS = { + "BaseRedisSearchTool", + "RedisVectorSearchTool", + "RedisHybridSearchTool", + "RedisRangeSearchTool", + "RedisTextSearchTool", +} + + +def __getattr__(name: str): + """Lazy load tools to provide helpful error messages.""" + if name in _REDIS_TOOLS: + try: + from .redis import BaseRedisSearchTool + from .redis import RedisHybridSearchTool + from .redis import RedisRangeSearchTool + from .redis import RedisTextSearchTool + from .redis import RedisVectorSearchTool + + globals().update({ + "BaseRedisSearchTool": BaseRedisSearchTool, + "RedisVectorSearchTool": RedisVectorSearchTool, + "RedisHybridSearchTool": RedisHybridSearchTool, + "RedisRangeSearchTool": RedisRangeSearchTool, + "RedisTextSearchTool": RedisTextSearchTool, + }) + return globals()[name] + except ImportError as e: + raise ImportError( + f"{name} requires redisvl. " + "Install with: pip install google-adk-community[redis-vl]" + ) from e + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/src/google/adk_community/tools/redis/__init__.py b/src/google/adk_community/tools/redis/__init__.py new file mode 100644 index 0000000..027592f --- /dev/null +++ b/src/google/adk_community/tools/redis/__init__.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Redis tools for ADK Community using RedisVL. + +This module provides tools for Redis-based search operations: + +- `RedisVectorSearchTool`: KNN vector similarity search +- `RedisHybridSearchTool`: Combined vector + BM25 text search +- `RedisRangeSearchTool`: Distance threshold-based vector search +- `RedisTextSearchTool`: Full-text BM25 keyword search + +Example: + ```python + from redisvl.index import SearchIndex + from redisvl.utils.vectorize import HFTextVectorizer + from google.adk_community.tools.redis import RedisVectorSearchTool + + index = SearchIndex.from_yaml("schema.yaml") + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + + tool = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + num_results=5, + ) + ``` +""" + +try: + from .base_search_tool import BaseRedisSearchTool + from .hybrid_search_tool import RedisHybridSearchTool + from .range_search_tool import RedisRangeSearchTool + from .text_search_tool import RedisTextSearchTool + from .vector_search_tool import RedisVectorSearchTool +except ImportError as e: + raise ImportError( + "Redis tools require redisvl. " + "Install with: pip install google-adk-community[redis-vl]" + ) from e + +__all__ = [ + "BaseRedisSearchTool", + "RedisVectorSearchTool", + "RedisHybridSearchTool", + "RedisRangeSearchTool", + "RedisTextSearchTool", +] diff --git a/src/google/adk_community/tools/redis/base_search_tool.py b/src/google/adk_community/tools/redis/base_search_tool.py new file mode 100644 index 0000000..24d303f --- /dev/null +++ b/src/google/adk_community/tools/redis/base_search_tool.py @@ -0,0 +1,155 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base class for Redis search tools using RedisVL.""" + +from __future__ import annotations + +from abc import abstractmethod +import asyncio +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from google.adk.tools import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.utils.vectorize import BaseVectorizer + + +class BaseRedisSearchTool(BaseTool): + """Base class for Redis search tools using RedisVL. + + This class provides common functionality for all Redis search tools: + - Index and vectorizer management + - Common error handling + - Standard response format + + Subclasses must implement `_build_query()` to create the appropriate + RedisVL query object for their search type. + """ + + def __init__( + self, + *, + name: str, + description: str, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: Optional[BaseVectorizer] = None, + return_fields: Optional[List[str]] = None, + ): + """Initialize the base Redis search tool. + + Args: + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + vectorizer: Optional vectorizer for embedding queries. + return_fields: Optional list of fields to return in results. 
+ """ + super().__init__(name=name, description=description) + self._index = index + self._vectorizer = vectorizer + self._return_fields = return_fields + self._is_async_index = isinstance(index, AsyncSearchIndex) + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM. + + Returns a simple interface with just a query parameter. + Subclasses can override to add additional parameters. + """ + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + }, + required=["query"], + ), + ) + + @abstractmethod + def _build_query( + self, query_text: str, embedding: List[float], **kwargs: Any + ) -> Any: + """Build the RedisVL query object. + + Args: + query_text: The original query text from the user. + embedding: The vector embedding of the query text. + **kwargs: Additional parameters from the LLM call. + + Returns: + A RedisVL query object (VectorQuery, HybridQuery, etc.) + """ + pass + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> Dict[str, Any]: + """Execute the search query. + + Args: + args: Arguments from the LLM, must include 'query'. + tool_context: The tool execution context. + + Returns: + A dictionary with status, count, and results. 
+ """ + query_text = args.get("query", "") + + if not query_text: + return {"status": "error", "error": "Query text is required."} + + try: + # Embed the query text + if self._vectorizer is None: + return { + "status": "error", + "error": "Vectorizer is required for this search type.", + } + + embedding = await self._vectorizer.aembed(query_text) + + # Build the query (subclass-specific) + redisvl_query = self._build_query(query_text, embedding, **args) + + # Execute the query - handle both sync and async indexes + if self._is_async_index: + results = await self._index.query(redisvl_query) + else: + # Run sync query in thread pool to avoid blocking + results = await asyncio.to_thread(self._index.query, redisvl_query) + + # Format results + formatted_results = [dict(r) for r in results] if results else [] + + return { + "status": "success", + "count": len(formatted_results), + "results": formatted_results, + } + + except Exception as e: + return {"status": "error", "error": str(e)} diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py new file mode 100644 index 0000000..50492ff --- /dev/null +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -0,0 +1,182 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Redis hybrid search tool combining vector similarity and BM25 text search.""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Union + +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.query import HybridQuery +from redisvl.utils.vectorize import BaseVectorizer + +from .base_search_tool import BaseRedisSearchTool + + +class RedisHybridSearchTool(BaseRedisSearchTool): + """Hybrid search tool combining vector similarity and BM25 text search. + + This tool performs a hybrid search that combines semantic vector similarity + with keyword-based BM25 text matching. This is useful when you want to + leverage both the semantic understanding of embeddings and the precision + of keyword matching. + + Example: + ```python + from redisvl.index import SearchIndex + from redisvl.utils.vectorize import HFTextVectorizer + from google.adk_community.tools.redis import RedisHybridSearchTool + + index = SearchIndex.from_yaml("schema.yaml") + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + + tool = RedisHybridSearchTool( + index=index, + vectorizer=vectorizer, + text_field_name="content", + linear_alpha=0.7, # 70% text, 30% vector + num_results=10, + return_fields=["title", "content"], + ) + + agent = Agent(model="gemini-2.0-flash", tools=[tool]) + ``` + """ + + def __init__( + self, + *, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: BaseVectorizer, + text_field_name: str = "content", + vector_field_name: str = "embedding", + text_scorer: str = "BM25STD", + combination_method: Optional[str] = None, + linear_alpha: float = 0.3, + rrf_window: int = 20, + rrf_constant: int = 60, + num_results: int = 10, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + dtype: str = "float32", + stopwords: 
Optional[Union[str, Set[str]]] = "english", + text_weights: Optional[Dict[str, float]] = None, + name: str = "redis_hybrid_search", + description: str = "Search using both semantic similarity and keyword matching.", + ): + """Initialize the hybrid search tool. + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + vectorizer: The vectorizer for embedding queries. + text_field_name: The name of the text field for BM25 search. + vector_field_name: The name of the vector field for similarity search. + text_scorer: The text scoring algorithm (default: "BM25STD"). + combination_method: Score combination method - "RRF" or "LINEAR". + linear_alpha: Weight of text score when using LINEAR (default: 0.3). + rrf_window: Window size for RRF combination (default: 20). + rrf_constant: Constant for RRF combination (default: 60). + num_results: Default number of results to return (default: 10). + return_fields: Optional list of fields to return in results. + filter_expression: Optional filter expression to narrow results. + dtype: The dtype of the vector (default: "float32"). + stopwords: Stopwords to remove from query (default: "english"). + text_weights: Optional field weights for text scoring. + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). 
+ """ + super().__init__( + name=name, + description=description, + index=index, + vectorizer=vectorizer, + return_fields=return_fields, + ) + self._text_field_name = text_field_name + self._vector_field_name = vector_field_name + self._text_scorer = text_scorer + self._combination_method = combination_method + self._linear_alpha = linear_alpha + self._rrf_window = rrf_window + self._rrf_constant = rrf_constant + self._num_results = num_results + self._filter_expression = filter_expression + self._dtype = dtype + self._stopwords = stopwords + self._text_weights = text_weights + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM.""" + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + "num_results": types.Schema( + type=types.Type.INTEGER, + description=( + "Number of results to return (default:" + f" {self._num_results})." + ), + ), + }, + required=["query"], + ), + ) + + def _build_query( + self, query_text: str, embedding: List[float], **kwargs: Any + ) -> HybridQuery: + """Build a HybridQuery for combined vector + text search. + + Args: + query_text: The original query text for BM25 matching. + embedding: The vector embedding of the query text. + **kwargs: Additional parameters (e.g., num_results). + + Returns: + A HybridQuery configured for hybrid search. 
+ """ + num_results = kwargs.get("num_results", self._num_results) + + return HybridQuery( + text=query_text, + text_field_name=self._text_field_name, + vector=embedding, + vector_field_name=self._vector_field_name, + text_scorer=self._text_scorer, + combination_method=self._combination_method, + linear_alpha=self._linear_alpha, + rrf_window=self._rrf_window, + rrf_constant=self._rrf_constant, + filter_expression=self._filter_expression, + dtype=self._dtype, + num_results=num_results, + return_fields=self._return_fields, + stopwords=self._stopwords, + text_weights=self._text_weights, + ) diff --git a/src/google/adk_community/tools/redis/range_search_tool.py b/src/google/adk_community/tools/redis/range_search_tool.py new file mode 100644 index 0000000..cebe512 --- /dev/null +++ b/src/google/adk_community/tools/redis/range_search_tool.py @@ -0,0 +1,185 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Redis vector range search tool using distance threshold.""" + +from __future__ import annotations + +from typing import Any +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.query import VectorRangeQuery +from redisvl.utils.vectorize import BaseVectorizer + +from .base_search_tool import BaseRedisSearchTool + +# Type alias for sort specification +SortSpec = Optional[ + Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] +] + + +class RedisRangeSearchTool(BaseRedisSearchTool): + """Vector range search tool using distance threshold. + + This tool finds all documents within a specified distance threshold + from the query vector. Unlike KNN search which returns a fixed number + of results, range search returns all documents that are "close enough" + based on the threshold. + + Example: + ```python + from redisvl.index import SearchIndex + from redisvl.utils.vectorize import HFTextVectorizer + from google.adk_community.tools.redis import RedisRangeSearchTool + + index = SearchIndex.from_yaml("schema.yaml") + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + + tool = RedisRangeSearchTool( + index=index, + vectorizer=vectorizer, + distance_threshold=0.3, # Only return docs within 0.3 distance + return_fields=["title", "content"], + ) + + agent = Agent(model="gemini-2.0-flash", tools=[tool]) + ``` + """ + + def __init__( + self, + *, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: BaseVectorizer, + vector_field_name: str = "embedding", + distance_threshold: float = 0.2, + num_results: int = 10, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + dtype: str = "float32", + return_score: bool = True, + dialect: int = 2, + sort_by: SortSpec = None, + in_order: bool = False, + epsilon: Optional[float] = None, + 
normalize_vector_distance: bool = False, + name: str = "redis_range_search", + description: str = "Find all documents within a similarity threshold.", + ): + """Initialize the range search tool. + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + vectorizer: The vectorizer for embedding queries. + vector_field_name: The name of the vector field in the index. + distance_threshold: Maximum distance for results (default: 0.2). + num_results: Maximum number of results to return (default: 10). + return_fields: Optional list of fields to return in results. + filter_expression: Optional filter expression to narrow results. + dtype: The dtype of the vector (default: "float32"). + return_score: Whether to return the vector distance (default: True). + dialect: The RediSearch query dialect (default: 2). + sort_by: Field(s) to order results by. + in_order: Require query terms in same order (default: False). + epsilon: Range search approximation factor for HNSW/SVS-VAMANA. + normalize_vector_distance: Convert distance to 0-1 similarity. + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). 
+ """ + super().__init__( + name=name, + description=description, + index=index, + vectorizer=vectorizer, + return_fields=return_fields, + ) + self._vector_field_name = vector_field_name + self._distance_threshold = distance_threshold + self._num_results = num_results + self._filter_expression = filter_expression + self._dtype = dtype + self._return_score = return_score + self._dialect = dialect + self._sort_by = sort_by + self._in_order = in_order + self._epsilon = epsilon + self._normalize_vector_distance = normalize_vector_distance + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM.""" + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + "distance_threshold": types.Schema( + type=types.Type.NUMBER, + description=( + "Max distance threshold (default:" + f" {self._distance_threshold})." + ), + ), + }, + required=["query"], + ), + ) + + def _build_query( + self, query_text: str, embedding: List[float], **kwargs: Any + ) -> VectorRangeQuery: + """Build a VectorRangeQuery for distance-based search. + + Args: + query_text: The original query text (unused for range search). + embedding: The vector embedding of the query text. + **kwargs: Additional parameters (e.g., distance_threshold). + + Returns: + A VectorRangeQuery configured for range search. 
+ """ + distance_threshold = kwargs.get( + "distance_threshold", self._distance_threshold + ) + + query_kwargs: dict[str, Any] = { + "vector": embedding, + "vector_field_name": self._vector_field_name, + "distance_threshold": distance_threshold, + "num_results": self._num_results, + "return_fields": self._return_fields, + "filter_expression": self._filter_expression, + "dtype": self._dtype, + "return_score": self._return_score, + "dialect": self._dialect, + "sort_by": self._sort_by, + "in_order": self._in_order, + "normalize_vector_distance": self._normalize_vector_distance, + } + + if self._epsilon is not None: + query_kwargs["epsilon"] = self._epsilon + + return VectorRangeQuery(**query_kwargs) diff --git a/src/google/adk_community/tools/redis/text_search_tool.py b/src/google/adk_community/tools/redis/text_search_tool.py new file mode 100644 index 0000000..4746aa9 --- /dev/null +++ b/src/google/adk_community/tools/redis/text_search_tool.py @@ -0,0 +1,187 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Redis full-text search tool using BM25.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from google.adk.tools import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.query import TextQuery + +# Type alias for sort specification +SortSpec = Optional[ + Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] +] + + +class RedisTextSearchTool(BaseTool): + """Full-text search tool using BM25 scoring. + + This tool performs keyword-based full-text search using BM25 scoring. + Unlike vector search, it doesn't require embeddings - it matches + documents based on keyword relevance. + + Example: + ```python + from redisvl.index import SearchIndex + from google.adk_community.tools.redis import RedisTextSearchTool + + index = SearchIndex.from_yaml("schema.yaml") + + tool = RedisTextSearchTool( + index=index, + text_field_name="content", + num_results=10, + return_fields=["title", "content"], + ) + + agent = Agent(model="gemini-2.0-flash", tools=[tool]) + ``` + """ + + def __init__( + self, + *, + index: Union[SearchIndex, AsyncSearchIndex], + text_field_name: str = "content", + text_scorer: str = "BM25STD", + num_results: int = 10, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + return_score: bool = True, + dialect: int = 2, + sort_by: SortSpec = None, + in_order: bool = False, + stopwords: Optional[Union[str, Set[str]]] = "english", + name: str = "redis_text_search", + description: str = "Search for documents using keyword matching.", + ): + """Initialize the text search tool. + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex to query. 
+ text_field_name: The name of the text field to search. + text_scorer: The text scoring algorithm (default: "BM25STD"). + num_results: Default number of results to return (default: 10). + return_fields: Optional list of fields to return in results. + filter_expression: Optional filter expression to narrow results. + return_score: Whether to return the text score (default: True). + dialect: The RediSearch query dialect (default: 2). + sort_by: Field(s) to order results by. + in_order: Require query terms in same order (default: False). + stopwords: Stopwords to remove from query (default: "english"). + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). + """ + super().__init__(name=name, description=description) + self._index = index + self._text_field_name = text_field_name + self._text_scorer = text_scorer + self._num_results = num_results + self._return_fields = return_fields + self._filter_expression = filter_expression + self._return_score = return_score + self._dialect = dialect + self._sort_by = sort_by + self._in_order = in_order + self._stopwords = stopwords + self._is_async_index = isinstance(index, AsyncSearchIndex) + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM.""" + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + "num_results": types.Schema( + type=types.Type.INTEGER, + description=( + "Number of results to return (default:" + f" {self._num_results})." + ), + ), + }, + required=["query"], + ), + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> Dict[str, Any]: + """Execute the text search query. + + Args: + args: Arguments from the LLM, must include 'query'. 
+ tool_context: The tool execution context. + + Returns: + A dictionary with status, count, and results. + """ + query_text = args.get("query", "") + + if not query_text: + return {"status": "error", "error": "Query text is required."} + + try: + num_results = args.get("num_results", self._num_results) + + text_query = TextQuery( + text=query_text, + text_field_name=self._text_field_name, + text_scorer=self._text_scorer, + filter_expression=self._filter_expression, + return_fields=self._return_fields, + num_results=num_results, + return_score=self._return_score, + dialect=self._dialect, + sort_by=self._sort_by, + in_order=self._in_order, + stopwords=self._stopwords, + ) + + # Execute the query - handle both sync and async indexes + if self._is_async_index: + results = await self._index.query(text_query) + else: + results = await asyncio.to_thread(self._index.query, text_query) + + formatted_results = [dict(r) for r in results] if results else [] + + return { + "status": "success", + "count": len(formatted_results), + "results": formatted_results, + } + + except Exception as e: + return {"status": "error", "error": str(e)} diff --git a/src/google/adk_community/tools/redis/vector_search_tool.py b/src/google/adk_community/tools/redis/vector_search_tool.py new file mode 100644 index 0000000..860bdff --- /dev/null +++ b/src/google/adk_community/tools/redis/vector_search_tool.py @@ -0,0 +1,213 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Redis vector similarity search tool using RedisVL.""" + +from __future__ import annotations + +from typing import Any +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from google.genai import types +from redisvl.index import AsyncSearchIndex +from redisvl.index import SearchIndex +from redisvl.query import VectorQuery +from redisvl.utils.vectorize import BaseVectorizer + +from .base_search_tool import BaseRedisSearchTool + +# Type alias for sort specification +SortSpec = Optional[ + Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] +] + + +class RedisVectorSearchTool(BaseRedisSearchTool): + """Vector similarity search tool using RedisVL. + + This tool performs K-nearest neighbor (KNN) vector similarity search + over a Redis index. It embeds the query text using the provided + vectorizer and finds the most similar documents. + + Example: + ```python + from redisvl.index import SearchIndex + from redisvl.utils.vectorize import HFTextVectorizer + from redisvl.query.filter import Tag + from google.adk_community.tools.redis import RedisVectorSearchTool + + index = SearchIndex.from_yaml("schema.yaml") + vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + + tool = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + num_results=5, + return_fields=["title", "content", "url"], + filter_expression=Tag("category") == "redis", # Optional filter + ) + + # Use with an agent + agent = Agent(model="gemini-2.0-flash", tools=[tool]) + ``` + """ + + def __init__( + self, + *, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: BaseVectorizer, + vector_field_name: str = "embedding", + num_results: int = 10, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + dtype: str = "float32", + return_score: bool = True, + dialect: int = 2, + sort_by: SortSpec = None, + in_order: bool = False, + hybrid_policy: Optional[str] = None, + batch_size: 
Optional[int] = None, + ef_runtime: Optional[int] = None, + epsilon: Optional[float] = None, + search_window_size: Optional[int] = None, + use_search_history: Optional[str] = None, + search_buffer_capacity: Optional[int] = None, + normalize_vector_distance: bool = False, + name: str = "redis_vector_search", + description: str = "Search for semantically similar documents using vector similarity with Redis.", + ): + """Initialize the vector search tool. + + Args: + index: The RedisVL SearchIndex to query. + vectorizer: The vectorizer for embedding queries. + vector_field_name: The name of the vector field in the index. + num_results: Default number of results to return (default: 10). + return_fields: Optional list of fields to return in results. + filter_expression: Optional RedisVL FilterExpression to narrow results. + dtype: The dtype of the vector (default: "float32"). + return_score: Whether to return the vector distance (default: True). + dialect: The RediSearch query dialect (default: 2). + sort_by: Field(s) to order results by. Can be str, tuple, or list. + in_order: Require query terms in same order as document (default: False). + hybrid_policy: Filter application policy - "BATCHES" or "ADHOC_BF". + batch_size: Batch size when hybrid_policy is "BATCHES". + ef_runtime: HNSW exploration factor at query time (higher = better recall). + epsilon: Range search approximation factor for HNSW/SVS-VAMANA indexes. + search_window_size: SVS-VAMANA search window size (higher = better recall). + use_search_history: SVS-VAMANA history mode - "OFF", "ON", or "AUTO". + search_buffer_capacity: SVS-VAMANA 2-level compression tuning parameter. + normalize_vector_distance: Convert distance to similarity score 0-1 (default: False). + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). 
+ """ + super().__init__( + name=name, + description=description, + index=index, + vectorizer=vectorizer, + return_fields=return_fields, + ) + self._vector_field_name = vector_field_name + self._num_results = num_results + self._filter_expression = filter_expression + self._dtype = dtype + self._return_score = return_score + self._dialect = dialect + self._sort_by = sort_by + self._in_order = in_order + self._hybrid_policy = hybrid_policy + self._batch_size = batch_size + self._ef_runtime = ef_runtime + self._epsilon = epsilon + self._search_window_size = search_window_size + self._use_search_history = use_search_history + self._search_buffer_capacity = search_buffer_capacity + self._normalize_vector_distance = normalize_vector_distance + + def _get_declaration(self) -> types.FunctionDeclaration: + """Get the function declaration for the LLM.""" + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="The search query text.", + ), + "num_results": types.Schema( + type=types.Type.INTEGER, + description=( + "Number of results to return (default:" + f" {self._num_results})." + ), + ), + }, + required=["query"], + ), + ) + + def _build_query( + self, query_text: str, embedding: List[float], **kwargs: Any + ) -> VectorQuery: + """Build a VectorQuery for KNN search. + + Args: + query_text: The original query text (unused for vector search). + embedding: The vector embedding of the query text. + **kwargs: Additional parameters (e.g., num_results). + + Returns: + A VectorQuery configured for KNN search. 
+ """ + num_results = kwargs.get("num_results", self._num_results) + + # Build query kwargs, only including optional params if set + query_kwargs: dict[str, Any] = { + "vector": embedding, + "vector_field_name": self._vector_field_name, + "num_results": num_results, + "return_fields": self._return_fields, + "filter_expression": self._filter_expression, + "dtype": self._dtype, + "return_score": self._return_score, + "dialect": self._dialect, + "sort_by": self._sort_by, + "in_order": self._in_order, + "normalize_vector_distance": self._normalize_vector_distance, + } + + # Add optional parameters only if set (for version compatibility) + if self._hybrid_policy is not None: + query_kwargs["hybrid_policy"] = self._hybrid_policy + if self._batch_size is not None: + query_kwargs["batch_size"] = self._batch_size + if self._ef_runtime is not None: + query_kwargs["ef_runtime"] = self._ef_runtime + if self._epsilon is not None: + query_kwargs["epsilon"] = self._epsilon + if self._search_window_size is not None: + query_kwargs["search_window_size"] = self._search_window_size + if self._use_search_history is not None: + query_kwargs["use_search_history"] = self._use_search_history + if self._search_buffer_capacity is not None: + query_kwargs["search_buffer_capacity"] = self._search_buffer_capacity + + return VectorQuery(**query_kwargs) diff --git a/tests/unittests/tools/__init__.py b/tests/unittests/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/tools/redis/__init__.py b/tests/unittests/tools/redis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/tools/redis/test_hybrid_search_tool.py b/tests/unittests/tools/redis/test_hybrid_search_tool.py new file mode 100644 index 0000000..85fe1c9 --- /dev/null +++ b/tests/unittests/tools/redis/test_hybrid_search_tool.py @@ -0,0 +1,176 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this 
file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RedisHybridSearchTool.""" + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +import pytest + +# Skip all tests if redisvl is not installed +pytest.importorskip("redisvl") + +from redisvl.index import SearchIndex +from redisvl.query import HybridQuery +from redisvl.utils.vectorize import BaseVectorizer + +from google.adk_community.tools.redis import RedisHybridSearchTool + + +@pytest.fixture +def mock_vectorizer(): + """Mock RedisVL vectorizer.""" + vectorizer = MagicMock(spec=BaseVectorizer) + vectorizer.embed = MagicMock(return_value=[0.1] * 384) + vectorizer.aembed = AsyncMock(return_value=[0.1] * 384) + return vectorizer + + +@pytest.fixture +def mock_index(): + """Mock RedisVL SearchIndex.""" + index = MagicMock(spec=SearchIndex) + index.query = MagicMock( + return_value=[ + {"title": "Test Doc", "content": "Test content", "score": 0.9} + ] + ) + return index + + +@pytest.fixture +def hybrid_search_tool(mock_index, mock_vectorizer): + """Create RedisHybridSearchTool instance for testing.""" + return RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + text_field_name="content", + num_results=5, + return_fields=["title", "content"], + ) + + +class TestRedisHybridSearchToolInit: + """Tests for RedisHybridSearchTool initialization.""" + + def test_default_parameters(self, mock_index, mock_vectorizer): + """Test default parameter values.""" + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + ) + assert 
tool._text_field_name == "content" + assert tool._vector_field_name == "embedding" + assert tool._text_scorer == "BM25STD" + assert tool._combination_method is None + assert tool._linear_alpha == 0.3 + assert tool._rrf_window == 20 + assert tool._rrf_constant == 60 + assert tool._num_results == 10 + assert tool._filter_expression is None + assert tool._dtype == "float32" + assert tool._stopwords == "english" + + def test_custom_parameters(self, mock_index, mock_vectorizer): + """Test custom parameter values.""" + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + text_field_name="description", + vector_field_name="vec", + text_scorer="TFIDF", + combination_method="LINEAR", + linear_alpha=0.7, + rrf_window=30, + rrf_constant=80, + num_results=20, + return_fields=["title", "url"], + dtype="float64", + stopwords={"the", "a", "an"}, + ) + assert tool._text_field_name == "description" + assert tool._vector_field_name == "vec" + assert tool._text_scorer == "TFIDF" + assert tool._combination_method == "LINEAR" + assert tool._linear_alpha == 0.7 + assert tool._rrf_window == 30 + assert tool._rrf_constant == 80 + assert tool._num_results == 20 + assert tool._return_fields == ["title", "url"] + assert tool._dtype == "float64" + assert tool._stopwords == {"the", "a", "an"} + + def test_custom_name_and_description(self, mock_index, mock_vectorizer): + """Test custom tool name and description.""" + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + name="custom_hybrid", + description="Custom hybrid search", + ) + assert tool.name == "custom_hybrid" + assert tool.description == "Custom hybrid search" + + +def _hybrid_query_available(): + """Check if HybridQuery dependencies are available.""" + try: + from redis.commands.search.hybrid_query import CombineResultsMethod + from redis.commands.search.hybrid_query import HybridPostProcessingConfig + + return True + except (ImportError, ModuleNotFoundError): + return False + 
+ +class TestRedisHybridSearchToolBuildQuery: + """Tests for _build_query method.""" + + @pytest.mark.skipif( + not _hybrid_query_available(), + reason="HybridQuery requires redis-py>=7.1.0 and Redis>=8.4.0", + ) + def test_build_query_basic(self, hybrid_search_tool): + """Test basic query building.""" + embedding = [0.1] * 384 + query = hybrid_search_tool._build_query("test query", embedding) + + assert isinstance(query, HybridQuery) + + @pytest.mark.skipif( + not _hybrid_query_available(), + reason="HybridQuery requires redis-py>=7.1.0 and Redis>=8.4.0", + ) + def test_build_query_with_num_results_override(self, hybrid_search_tool): + """Test query building with num_results override.""" + embedding = [0.1] * 384 + query = hybrid_search_tool._build_query( + "test query", embedding, num_results=15 + ) + + assert query._num_results == 15 + + +class TestRedisHybridSearchToolDeclaration: + """Tests for _get_declaration method.""" + + def test_get_declaration(self, hybrid_search_tool): + """Test function declaration generation.""" + declaration = hybrid_search_tool._get_declaration() + + assert declaration.name == "redis_hybrid_search" + assert "query" in declaration.parameters.properties + assert "num_results" in declaration.parameters.properties + assert "query" in declaration.parameters.required diff --git a/tests/unittests/tools/redis/test_range_search_tool.py b/tests/unittests/tools/redis/test_range_search_tool.py new file mode 100644 index 0000000..db0173c --- /dev/null +++ b/tests/unittests/tools/redis/test_range_search_tool.py @@ -0,0 +1,157 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RedisRangeSearchTool.""" + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +import pytest + +# Skip all tests if redisvl is not installed +pytest.importorskip("redisvl") + +from redisvl.index import SearchIndex +from redisvl.query import VectorRangeQuery +from redisvl.utils.vectorize import BaseVectorizer + +from google.adk_community.tools.redis import RedisRangeSearchTool + + +@pytest.fixture +def mock_vectorizer(): + """Mock RedisVL vectorizer.""" + vectorizer = MagicMock(spec=BaseVectorizer) + vectorizer.embed = MagicMock(return_value=[0.1] * 384) + vectorizer.aembed = AsyncMock(return_value=[0.1] * 384) + return vectorizer + + +@pytest.fixture +def mock_index(): + """Mock RedisVL SearchIndex.""" + index = MagicMock(spec=SearchIndex) + index.query = MagicMock( + return_value=[{ + "title": "Test Doc", + "content": "Test content", + "vector_distance": 0.1, + }] + ) + return index + + +@pytest.fixture +def range_search_tool(mock_index, mock_vectorizer): + """Create RedisRangeSearchTool instance for testing.""" + return RedisRangeSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + distance_threshold=0.3, + num_results=5, + return_fields=["title", "content"], + ) + + +class TestRedisRangeSearchToolInit: + """Tests for RedisRangeSearchTool initialization.""" + + def test_default_parameters(self, mock_index, mock_vectorizer): + """Test default parameter values.""" + tool = RedisRangeSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + ) + assert tool._vector_field_name == "embedding" + assert 
tool._distance_threshold == 0.2 + assert tool._num_results == 10 + assert tool._dtype == "float32" + assert tool._return_score is True + assert tool._dialect == 2 + assert tool._in_order is False + assert tool._normalize_vector_distance is False + assert tool._filter_expression is None + assert tool._sort_by is None + assert tool._epsilon is None + + def test_custom_parameters(self, mock_index, mock_vectorizer): + """Test custom parameter values.""" + tool = RedisRangeSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + vector_field_name="vec", + distance_threshold=0.5, + num_results=20, + return_fields=["title", "url"], + dtype="float64", + return_score=False, + dialect=3, + in_order=True, + normalize_vector_distance=True, + epsilon=0.01, + ) + assert tool._vector_field_name == "vec" + assert tool._distance_threshold == 0.5 + assert tool._num_results == 20 + assert tool._return_fields == ["title", "url"] + assert tool._dtype == "float64" + assert tool._return_score is False + assert tool._dialect == 3 + assert tool._in_order is True + assert tool._normalize_vector_distance is True + assert tool._epsilon == 0.01 + + def test_custom_name_and_description(self, mock_index, mock_vectorizer): + """Test custom tool name and description.""" + tool = RedisRangeSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + name="custom_range", + description="Custom range search", + ) + assert tool.name == "custom_range" + assert tool.description == "Custom range search" + + +class TestRedisRangeSearchToolBuildQuery: + """Tests for _build_query method.""" + + def test_build_query_basic(self, range_search_tool): + """Test basic query building.""" + embedding = [0.1] * 384 + query = range_search_tool._build_query("test query", embedding) + + assert isinstance(query, VectorRangeQuery) + + def test_build_query_with_threshold_override(self, range_search_tool): + """Test query building with distance_threshold override.""" + embedding = [0.1] * 384 + query = 
range_search_tool._build_query( + "test query", embedding, distance_threshold=0.8 + ) + + assert query._distance_threshold == 0.8 + + +class TestRedisRangeSearchToolDeclaration: + """Tests for _get_declaration method.""" + + def test_get_declaration(self, range_search_tool): + """Test function declaration generation.""" + declaration = range_search_tool._get_declaration() + + assert declaration.name == "redis_range_search" + assert "query" in declaration.parameters.properties + assert "distance_threshold" in declaration.parameters.properties + assert "query" in declaration.parameters.required diff --git a/tests/unittests/tools/redis/test_text_search_tool.py b/tests/unittests/tools/redis/test_text_search_tool.py new file mode 100644 index 0000000..5144a70 --- /dev/null +++ b/tests/unittests/tools/redis/test_text_search_tool.py @@ -0,0 +1,174 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RedisTextSearchTool.""" + +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +# Skip all tests if redisvl is not installed +pytest.importorskip("redisvl") + +from redisvl.index import SearchIndex + +from google.adk_community.tools.redis import RedisTextSearchTool + + +@pytest.fixture +def mock_index(): + """Mock RedisVL SearchIndex.""" + index = MagicMock(spec=SearchIndex) + index.query = MagicMock( + return_value=[ + {"title": "Test Doc", "content": "Test content", "score": 0.9} + ] + ) + return index + + +@pytest.fixture +def text_search_tool(mock_index): + """Create RedisTextSearchTool instance for testing.""" + return RedisTextSearchTool( + index=mock_index, + text_field_name="content", + num_results=5, + return_fields=["title", "content"], + stopwords=None, # Avoid nltk dependency in tests + ) + + +class TestRedisTextSearchToolInit: + """Tests for RedisTextSearchTool initialization.""" + + def test_default_parameters(self, mock_index): + """Test default parameter values.""" + tool = RedisTextSearchTool(index=mock_index) + assert tool._text_field_name == "content" + assert tool._text_scorer == "BM25STD" + assert tool._num_results == 10 + assert tool._return_score is True + assert tool._dialect == 2 + assert tool._in_order is False + assert tool._stopwords == "english" + assert tool._filter_expression is None + assert tool._sort_by is None + assert tool._return_fields is None + + def test_custom_parameters(self, mock_index): + """Test custom parameter values.""" + tool = RedisTextSearchTool( + index=mock_index, + text_field_name="description", + text_scorer="TFIDF", + num_results=20, + return_fields=["title", "url"], + return_score=False, + dialect=3, + in_order=True, + stopwords={"the", "a", "an"}, + ) + assert tool._text_field_name == "description" + assert tool._text_scorer == "TFIDF" + assert tool._num_results == 20 + assert tool._return_fields == ["title", "url"] + assert tool._return_score is False + 
assert tool._dialect == 3 + assert tool._in_order is True + assert tool._stopwords == {"the", "a", "an"} + + def test_custom_name_and_description(self, mock_index): + """Test custom tool name and description.""" + tool = RedisTextSearchTool( + index=mock_index, + name="custom_text", + description="Custom text search", + ) + assert tool.name == "custom_text" + assert tool.description == "Custom text search" + + def test_no_vectorizer_required(self, mock_index): + """Test that TextSearchTool does not require a vectorizer.""" + # This should work without any vectorizer + tool = RedisTextSearchTool(index=mock_index) + assert not hasattr(tool, "_vectorizer") + + +class TestRedisTextSearchToolDeclaration: + """Tests for _get_declaration method.""" + + def test_get_declaration(self, text_search_tool): + """Test function declaration generation.""" + declaration = text_search_tool._get_declaration() + + assert declaration.name == "redis_text_search" + assert "query" in declaration.parameters.properties + assert "num_results" in declaration.parameters.properties + assert "query" in declaration.parameters.required + + +class TestRedisTextSearchToolRunAsync: + """Tests for run_async method.""" + + @pytest.mark.asyncio + @patch("google.adk_community.tools.redis.text_search_tool.asyncio.to_thread") + async def test_run_async_success( + self, mock_to_thread, text_search_tool, mock_index + ): + """Test successful search execution.""" + mock_to_thread.return_value = [ + {"title": "Test Doc", "content": "Test content", "score": 0.9} + ] + mock_context = MagicMock() + result = await text_search_tool.run_async( + args={"query": "test query"}, + tool_context=mock_context, + ) + + assert result["status"] == "success" + assert result["count"] == 1 + assert len(result["results"]) == 1 + mock_to_thread.assert_called_once() + + @pytest.mark.asyncio + async def test_run_async_empty_query(self, text_search_tool): + """Test error handling for empty query.""" + mock_context = MagicMock() + 
result = await text_search_tool.run_async( + args={"query": ""}, + tool_context=mock_context, + ) + + assert result["status"] == "error" + assert "required" in result["error"].lower() + + @pytest.mark.asyncio + @patch("google.adk_community.tools.redis.text_search_tool.asyncio.to_thread") + async def test_run_async_with_num_results( + self, mock_to_thread, text_search_tool, mock_index + ): + """Test search with custom num_results.""" + mock_to_thread.return_value = [ + {"title": "Test Doc", "content": "Test content", "score": 0.9} + ] + mock_context = MagicMock() + await text_search_tool.run_async( + args={"query": "test", "num_results": 15}, + tool_context=mock_context, + ) + + # Verify to_thread was called + mock_to_thread.assert_called_once() diff --git a/tests/unittests/tools/redis/test_vector_search_tool.py b/tests/unittests/tools/redis/test_vector_search_tool.py new file mode 100644 index 0000000..a3ad514 --- /dev/null +++ b/tests/unittests/tools/redis/test_vector_search_tool.py @@ -0,0 +1,274 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RedisVectorSearchTool.""" + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +# Skip all tests if redisvl is not installed +pytest.importorskip("redisvl") + +from redisvl.index import SearchIndex +from redisvl.query import VectorQuery +from redisvl.utils.vectorize import BaseVectorizer + +from google.adk_community.tools.redis import RedisVectorSearchTool + + +@pytest.fixture +def mock_vectorizer(): + """Mock RedisVL vectorizer.""" + vectorizer = MagicMock(spec=BaseVectorizer) + vectorizer.embed = MagicMock(return_value=[0.1] * 384) + vectorizer.aembed = AsyncMock(return_value=[0.1] * 384) + return vectorizer + + +@pytest.fixture +def mock_index(): + """Mock RedisVL SearchIndex.""" + index = MagicMock(spec=SearchIndex) + index.search = MagicMock( + return_value=[{ + "title": "Test Doc", + "content": "Test content", + "vector_distance": 0.1, + }] + ) + return index + + +@pytest.fixture +def vector_search_tool(mock_index, mock_vectorizer): + """Create RedisVectorSearchTool instance for testing.""" + return RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + num_results=5, + return_fields=["title", "content"], + ) + + +class TestRedisVectorSearchToolInit: + """Tests for RedisVectorSearchTool initialization.""" + + def test_default_parameters(self, mock_index, mock_vectorizer): + """Test default parameter values.""" + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + ) + assert tool._vector_field_name == "embedding" + assert tool._num_results == 10 + assert tool._dtype == "float32" + assert tool._return_score is True + assert tool._dialect == 2 + assert tool._in_order is False + assert tool._normalize_vector_distance is False + assert tool._filter_expression is None + assert tool._sort_by is None + assert tool._hybrid_policy is None + assert tool._batch_size is None + assert tool._ef_runtime is None + assert tool._epsilon is 
None + assert tool._search_window_size is None + assert tool._use_search_history is None + assert tool._search_buffer_capacity is None + + def test_custom_parameters(self, mock_index, mock_vectorizer): + """Test custom parameter values.""" + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + vector_field_name="custom_embedding", + num_results=20, + return_fields=["title", "content", "url"], + dtype="float64", + return_score=False, + dialect=3, + in_order=True, + normalize_vector_distance=True, + hybrid_policy="BATCHES", + batch_size=100, + ef_runtime=200, + epsilon=0.01, + search_window_size=50, + use_search_history="ON", + search_buffer_capacity=1000, + ) + assert tool._vector_field_name == "custom_embedding" + assert tool._num_results == 20 + assert tool._return_fields == ["title", "content", "url"] + assert tool._dtype == "float64" + assert tool._return_score is False + assert tool._dialect == 3 + assert tool._in_order is True + assert tool._normalize_vector_distance is True + assert tool._hybrid_policy == "BATCHES" + assert tool._batch_size == 100 + assert tool._ef_runtime == 200 + assert tool._epsilon == 0.01 + assert tool._search_window_size == 50 + assert tool._use_search_history == "ON" + assert tool._search_buffer_capacity == 1000 + + def test_custom_name_and_description(self, mock_index, mock_vectorizer): + """Test custom tool name and description.""" + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + name="custom_search", + description="Custom search description", + ) + assert tool.name == "custom_search" + assert tool.description == "Custom search description" + + +class TestRedisVectorSearchToolBuildQuery: + """Tests for _build_query method.""" + + def test_build_query_basic(self, vector_search_tool): + """Test basic query building.""" + embedding = [0.1] * 384 + query = vector_search_tool._build_query("test query", embedding) + + assert isinstance(query, VectorQuery) + # VectorQuery uses 
private attributes + assert query._vector == embedding + assert query._vector_field_name == "embedding" + assert query._num_results == 5 + + def test_build_query_with_num_results_override(self, vector_search_tool): + """Test query building with num_results override.""" + embedding = [0.1] * 384 + query = vector_search_tool._build_query( + "test query", embedding, num_results=15 + ) + + assert query._num_results == 15 + + def test_build_query_passes_core_parameters( + self, mock_index, mock_vectorizer + ): + """Test that core VectorQuery parameters are passed through.""" + # Use a string filter expression (valid type for VectorQuery) + filter_str = "@category:{redis}" + + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + vector_field_name="vec", + num_results=10, + return_fields=["title"], + filter_expression=filter_str, + dtype="float16", + return_score=False, + dialect=3, + in_order=True, + normalize_vector_distance=True, + ) + + embedding = [0.1] * 384 + query = tool._build_query("test", embedding) + + # Check core attributes that are exposed on VectorQuery + assert query._vector_field_name == "vec" + assert query._num_results == 10 + assert query._return_fields == ["title"] + assert query._dtype == "float16" + assert query._dialect == 3 + assert query._in_order is True + + def test_stores_optional_parameters(self, mock_index, mock_vectorizer): + """Test that optional parameters are stored correctly.""" + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + hybrid_policy="ADHOC_BF", + batch_size=50, + ef_runtime=100, + epsilon=0.05, + search_window_size=25, + use_search_history="AUTO", + search_buffer_capacity=500, + ) + + # Verify parameters are stored on the tool + assert tool._hybrid_policy == "ADHOC_BF" + assert tool._batch_size == 50 + assert tool._ef_runtime == 100 + assert tool._epsilon == 0.05 + assert tool._search_window_size == 25 + assert tool._use_search_history == "AUTO" + assert 
tool._search_buffer_capacity == 500 + + +class TestRedisVectorSearchToolDeclaration: + """Tests for _get_declaration method.""" + + def test_get_declaration(self, vector_search_tool): + """Test function declaration generation.""" + declaration = vector_search_tool._get_declaration() + + assert declaration.name == "redis_vector_search" + assert "query" in declaration.parameters.properties + assert "num_results" in declaration.parameters.properties + assert "query" in declaration.parameters.required + + +class TestRedisVectorSearchToolImportError: + """Tests for import error handling.""" + + def test_import_error_message(self): + """Test that helpful error message is shown when redisvl not installed.""" + with patch.dict("sys.modules", {"redisvl": None}): + # This test verifies the error message format + # The actual import error is tested by the try/except in __init__.py + pass + + +class TestRedisVectorSearchToolFilterExpression: + """Tests for filter expression handling.""" + + def test_filter_expression_stored_on_tool(self, mock_index, mock_vectorizer): + """Test that filter expression is stored on the tool.""" + filter_str = "@category:{redis}" + + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + filter_expression=filter_str, + ) + + assert tool._filter_expression == filter_str + + def test_filter_expression_passed_to_query(self, mock_index, mock_vectorizer): + """Test that filter expression is correctly passed to VectorQuery.""" + filter_str = "@category:{redis}" + + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + filter_expression=filter_str, + ) + + embedding = [0.1] * 384 + query = tool._build_query("test", embedding) + + # VectorQuery wraps string filters in FilterExpression + assert str(query._filter_expression) == filter_str From e3ecb52957d5be4d42ebf718301f49f23d42dc21 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 7 Jan 2026 19:49:07 +0200 Subject: [PATCH 02/16] chore: update 
redisvl dependency to >=0.13.2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fb43714..6172336 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ test = [ "pytest-asyncio>=1.2.0", ] redis-vl = [ - "redisvl>=0.4.0", + "redisvl>=0.13.2", "nltk>=3.8.0", "sentence-transformers>=2.2.0", ] From c427907b09f50ed900f738e9f3f07d940d918609 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 7 Jan 2026 19:49:15 +0200 Subject: [PATCH 03/16] feat(redis): add RedisVL 0.13.2 HybridQuery parameters - Add vector_param_name, yield_text_score_as, vector_search_method - Add knn_ef_runtime, range_radius, range_epsilon, yield_vsim_score_as - Add yield_combined_score_as for score output control - Document Redis >= 8.4.0 and redis-py >= 7.1.0 requirements --- .../tools/redis/hybrid_search_tool.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py index 50492ff..2b1bbdd 100644 --- a/src/google/adk_community/tools/redis/hybrid_search_tool.py +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -36,9 +36,13 @@ class RedisHybridSearchTool(BaseRedisSearchTool): """Hybrid search tool combining vector similarity and BM25 text search. This tool performs a hybrid search that combines semantic vector similarity - with keyword-based BM25 text matching. This is useful when you want to - leverage both the semantic understanding of embeddings and the precision - of keyword matching. + with keyword-based BM25 text matching using Redis's native FT.HYBRID command. + This is useful when you want to leverage both the semantic understanding of + embeddings and the precision of keyword matching. 
+ + Requirements: + - Redis >= 8.4.0 (for native FT.HYBRID command support) + - redis-py >= 7.1.0 Example: ```python @@ -69,11 +73,19 @@ def __init__( vectorizer: BaseVectorizer, text_field_name: str = "content", vector_field_name: str = "embedding", + vector_param_name: str = "vector", text_scorer: str = "BM25STD", + yield_text_score_as: Optional[str] = None, + vector_search_method: Optional[str] = None, + knn_ef_runtime: int = 10, + range_radius: Optional[float] = None, + range_epsilon: float = 0.01, + yield_vsim_score_as: Optional[str] = None, combination_method: Optional[str] = None, linear_alpha: float = 0.3, rrf_window: int = 20, rrf_constant: int = 60, + yield_combined_score_as: Optional[str] = None, num_results: int = 10, return_fields: Optional[List[str]] = None, filter_expression: Optional[Any] = None, @@ -90,11 +102,19 @@ def __init__( vectorizer: The vectorizer for embedding queries. text_field_name: The name of the text field for BM25 search. vector_field_name: The name of the vector field for similarity search. + vector_param_name: Name of the parameter substitution for vector blob. text_scorer: The text scoring algorithm (default: "BM25STD"). + yield_text_score_as: Field name to yield the text score as. + vector_search_method: Vector search method - "KNN" or "RANGE". + knn_ef_runtime: Exploration factor for HNSW when using KNN (default: 10). + range_radius: Search radius when using RANGE vector search. + range_epsilon: Epsilon for RANGE search accuracy (default: 0.01). + yield_vsim_score_as: Field name to yield the vector similarity score as. combination_method: Score combination method - "RRF" or "LINEAR". linear_alpha: Weight of text score when using LINEAR (default: 0.3). rrf_window: Window size for RRF combination (default: 20). rrf_constant: Constant for RRF combination (default: 60). + yield_combined_score_as: Field name to yield the combined score as. num_results: Default number of results to return (default: 10). 
return_fields: Optional list of fields to return in results. filter_expression: Optional filter expression to narrow results. @@ -113,11 +133,19 @@ def __init__( ) self._text_field_name = text_field_name self._vector_field_name = vector_field_name + self._vector_param_name = vector_param_name self._text_scorer = text_scorer + self._yield_text_score_as = yield_text_score_as + self._vector_search_method = vector_search_method + self._knn_ef_runtime = knn_ef_runtime + self._range_radius = range_radius + self._range_epsilon = range_epsilon + self._yield_vsim_score_as = yield_vsim_score_as self._combination_method = combination_method self._linear_alpha = linear_alpha self._rrf_window = rrf_window self._rrf_constant = rrf_constant + self._yield_combined_score_as = yield_combined_score_as self._num_results = num_results self._filter_expression = filter_expression self._dtype = dtype @@ -168,12 +196,20 @@ def _build_query( text_field_name=self._text_field_name, vector=embedding, vector_field_name=self._vector_field_name, + vector_param_name=self._vector_param_name, text_scorer=self._text_scorer, + yield_text_score_as=self._yield_text_score_as, + vector_search_method=self._vector_search_method, + knn_ef_runtime=self._knn_ef_runtime, + range_radius=self._range_radius, + range_epsilon=self._range_epsilon, + yield_vsim_score_as=self._yield_vsim_score_as, + filter_expression=self._filter_expression, combination_method=self._combination_method, - linear_alpha=self._linear_alpha, rrf_window=self._rrf_window, rrf_constant=self._rrf_constant, - filter_expression=self._filter_expression, + linear_alpha=self._linear_alpha, + yield_combined_score_as=self._yield_combined_score_as, dtype=self._dtype, num_results=num_results, return_fields=self._return_fields, From baa5ff5d1919a14bfd72ec21b616f0ddfbe6c12a Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 7 Jan 2026 19:49:20 +0200 Subject: [PATCH 04/16] docs: update GitHub issue with redisvl>=0.13.2 requirement --- 
github_issue_redis_search_tools.md | 92 ++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 github_issue_redis_search_tools.md diff --git a/github_issue_redis_search_tools.md b/github_issue_redis_search_tools.md new file mode 100644 index 0000000..ac6385d --- /dev/null +++ b/github_issue_redis_search_tools.md @@ -0,0 +1,92 @@ +# [Proposal] Add RedisVL Search Tools for Knowledge Base Retrieval + +## Is your feature request related to a problem? Please describe. + +ADK has retrieval tools for Google Cloud (Vertex AI Search, Discovery Engine), but **no self-hosted option** for developers who need full control over their infrastructure. RAG is a fundamental pattern for building useful agents, and Redis is already running in most organizations; these tools let developers add vector search, keyword search, and hybrid retrieval to their existing Redis infrastructure. [RedisVL](https://github.com/redis/redis-vl-python) is the official Redis Vector Search library that's also widely adopted. + +Moreover, the repo has `RedisSessionService` for session persistence, but **there are no tools for agents to search Redis-based knowledge bases** without building it themselves. + +## Describe the solution you'd like + +Add `redisvl` as an optional dependency with four search tools that wrap RedisVL's query capabilities as ADK `BaseTool` implementations. + +**pyproject.toml change:** +```toml +[project.optional-dependencies] +redis-vl = [ + "redisvl>=0.13.2", + "nltk>=3.8.0", + "sentence-transformers>=2.2.0", +] +``` + +**Installation:** +```bash +# Existing functionality unchanged +pip install google-adk-community + +# Opt-in to vector search capabilities +pip install google-adk-community[redis-vl] +``` + +This aligns with the community repository's stated philosophy: + +> "This approach allows the **core ADK to remain stable and lightweight**, while giving the community the freedom to build and share powerful extensions." 
+ +### Tools Provided + +| Tool | Search Type | Use Case | +|------|-------------|----------| +| `RedisVectorSearchTool` | KNN vector similarity | Semantic/conceptual queries | +| `RedisTextSearchTool` | BM25 full-text | Exact terms, acronyms, API names | +| `RedisHybridSearchTool` | Vector + BM25 combined | Best of both worlds | +| `RedisRangeSearchTool` | Distance threshold | Exhaustive retrieval, quality filtering | + +### Developer Experience + +```python +from google.adk import Agent +from google.adk_community.tools.redis import RedisVectorSearchTool +from redisvl.index import SearchIndex +from redisvl.utils.vectorize import HFTextVectorizer + +index = SearchIndex.from_yaml("schema.yaml") +index.connect("redis://localhost:6379") +vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + +tool = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + num_results=5, + return_fields=["title", "content", "url"], +) + +agent = Agent(model="gemini-2.0-flash", tools=[tool]) +``` + +### Common Features Across All Tools + +- **Filtering**: Tag, numeric, and geo filters via `filter_expression` +- **Field selection**: Control returned fields via `return_fields` +- **Async support**: Works with both `SearchIndex` and `AsyncSearchIndex` +- **Score normalization**: Convert distances to 0-1 similarity via `normalize_vector_distance=True` +- **Full parameter exposure**: All RedisVL query parameters are configurable + +## Describe alternatives you've considered + +1. **Implement vector search with raw Redis commands** — Would duplicate existing, maintained code in RedisVL. Users would get a degraded experience compared to using RedisVL directly. + +2. **Require users to install RedisVL separately** — Creates friction and doesn't provide ADK-native abstractions like `BaseTool` wrappers with proper function declarations for LLMs. + +3. **Use a different vector database** — Redis is already widely deployed in many organizations. 
Adding vector search to existing Redis infrastructure is lower friction than adopting a new database. + + + +## Why RedisVL? + +[RedisVL](https://github.com/redis/redis-vl-python) is the official Redis vector library (~50MB footprint). It provides: +- Schema-driven index management +- Multiple query types (vector, text, hybrid, range) +- Built-in vectorizers (HuggingFace, OpenAI, Cohere, etc.) +- Both sync and async APIs + From c4d32471da86591d3a071263feb59e602e9f2a3c Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Thu, 8 Jan 2026 14:44:18 +0200 Subject: [PATCH 05/16] fix: update default model to gemini-2.5-flash Address PR #43 code review feedback from Gemini Code Assist. Update model name from gemini-2.0-flash-001 to gemini-2.5-flash in: - Sample agent (agent.py) - Tool docstrings (vector, text, hybrid, range search tools) - GitHub issue document --- .../samples/redis_vl_search/redis_vl_search_agent/agent.py | 2 +- github_issue_redis_search_tools.md | 2 +- src/google/adk_community/tools/redis/hybrid_search_tool.py | 2 +- src/google/adk_community/tools/redis/range_search_tool.py | 2 +- src/google/adk_community/tools/redis/text_search_tool.py | 2 +- src/google/adk_community/tools/redis/vector_search_tool.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py b/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py index dc508c2..8a03f2f 100644 --- a/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py +++ b/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py @@ -135,7 +135,7 @@ def create_agent() -> Agent: tools = get_search_tools(index, vectorizer) return Agent( - model="gemini-2.0-flash-001", + model="gemini-2.5-flash", name="root_agent", instruction=INSTRUCTION, tools=tools, diff --git a/github_issue_redis_search_tools.md b/github_issue_redis_search_tools.md index ac6385d..aafe801 100644 --- a/github_issue_redis_search_tools.md 
+++ b/github_issue_redis_search_tools.md @@ -61,7 +61,7 @@ tool = RedisVectorSearchTool( return_fields=["title", "content", "url"], ) -agent = Agent(model="gemini-2.0-flash", tools=[tool]) +agent = Agent(model="gemini-2.5-flash", tools=[tool]) ``` ### Common Features Across All Tools diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py index 2b1bbdd..011df4f 100644 --- a/src/google/adk_community/tools/redis/hybrid_search_tool.py +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -62,7 +62,7 @@ class RedisHybridSearchTool(BaseRedisSearchTool): return_fields=["title", "content"], ) - agent = Agent(model="gemini-2.0-flash", tools=[tool]) + agent = Agent(model="gemini-2.5-flash", tools=[tool]) ``` """ diff --git a/src/google/adk_community/tools/redis/range_search_tool.py b/src/google/adk_community/tools/redis/range_search_tool.py index cebe512..bb55bdc 100644 --- a/src/google/adk_community/tools/redis/range_search_tool.py +++ b/src/google/adk_community/tools/redis/range_search_tool.py @@ -60,7 +60,7 @@ class RedisRangeSearchTool(BaseRedisSearchTool): return_fields=["title", "content"], ) - agent = Agent(model="gemini-2.0-flash", tools=[tool]) + agent = Agent(model="gemini-2.5-flash", tools=[tool]) ``` """ diff --git a/src/google/adk_community/tools/redis/text_search_tool.py b/src/google/adk_community/tools/redis/text_search_tool.py index 4746aa9..fe5ac9b 100644 --- a/src/google/adk_community/tools/redis/text_search_tool.py +++ b/src/google/adk_community/tools/redis/text_search_tool.py @@ -59,7 +59,7 @@ class RedisTextSearchTool(BaseTool): return_fields=["title", "content"], ) - agent = Agent(model="gemini-2.0-flash", tools=[tool]) + agent = Agent(model="gemini-2.5-flash", tools=[tool]) ``` """ diff --git a/src/google/adk_community/tools/redis/vector_search_tool.py b/src/google/adk_community/tools/redis/vector_search_tool.py index 860bdff..5a6b1c9 100644 --- 
a/src/google/adk_community/tools/redis/vector_search_tool.py +++ b/src/google/adk_community/tools/redis/vector_search_tool.py @@ -62,7 +62,7 @@ class RedisVectorSearchTool(BaseRedisSearchTool): ) # Use with an agent - agent = Agent(model="gemini-2.0-flash", tools=[tool]) + agent = Agent(model="gemini-2.5-flash", tools=[tool]) ``` """ From cd22ab408501abc18f564858b7d4343f55ea69c5 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Thu, 8 Jan 2026 16:40:19 +0200 Subject: [PATCH 06/16] refactor: introduce two-tier hierarchy for Redis search tools - Split BaseRedisSearchTool into shared base + VectorizedSearchTool - BaseRedisSearchTool now contains shared logic (_execute_query, _run_search) - VectorizedSearchTool requires vectorizer (no longer Optional) - RedisTextSearchTool now extends BaseRedisSearchTool (inherits shared logic) - Vector-based tools extend VectorizedSearchTool This addresses PR review feedback: - 2b: Vectorizer is now required in VectorizedSearchTool (fail-fast) - 2c: Eliminates ~14 lines of duplicated code via inheritance --- src/google/adk_community/tools/__init__.py | 4 + .../adk_community/tools/redis/__init__.py | 2 + .../tools/redis/base_search_tool.py | 158 +++++++++++++----- .../tools/redis/hybrid_search_tool.py | 4 +- .../tools/redis/range_search_tool.py | 4 +- .../tools/redis/text_search_tool.py | 44 ++--- .../tools/redis/vector_search_tool.py | 4 +- 7 files changed, 139 insertions(+), 81 deletions(-) diff --git a/src/google/adk_community/tools/__init__.py b/src/google/adk_community/tools/__init__.py index f737b7a..3c667ed 100644 --- a/src/google/adk_community/tools/__init__.py +++ b/src/google/adk_community/tools/__init__.py @@ -20,6 +20,7 @@ __all__ = [ "BaseRedisSearchTool", + "VectorizedSearchTool", "RedisVectorSearchTool", "RedisHybridSearchTool", "RedisRangeSearchTool", @@ -29,6 +30,7 @@ # Redis tool names for lazy loading _REDIS_TOOLS = { "BaseRedisSearchTool", + "VectorizedSearchTool", "RedisVectorSearchTool", 
"RedisHybridSearchTool", "RedisRangeSearchTool", @@ -45,9 +47,11 @@ def __getattr__(name: str): from .redis import RedisRangeSearchTool from .redis import RedisTextSearchTool from .redis import RedisVectorSearchTool + from .redis import VectorizedSearchTool globals().update({ "BaseRedisSearchTool": BaseRedisSearchTool, + "VectorizedSearchTool": VectorizedSearchTool, "RedisVectorSearchTool": RedisVectorSearchTool, "RedisHybridSearchTool": RedisHybridSearchTool, "RedisRangeSearchTool": RedisRangeSearchTool, diff --git a/src/google/adk_community/tools/redis/__init__.py b/src/google/adk_community/tools/redis/__init__.py index 027592f..45b8e8f 100644 --- a/src/google/adk_community/tools/redis/__init__.py +++ b/src/google/adk_community/tools/redis/__init__.py @@ -40,6 +40,7 @@ try: from .base_search_tool import BaseRedisSearchTool + from .base_search_tool import VectorizedSearchTool from .hybrid_search_tool import RedisHybridSearchTool from .range_search_tool import RedisRangeSearchTool from .text_search_tool import RedisTextSearchTool @@ -52,6 +53,7 @@ __all__ = [ "BaseRedisSearchTool", + "VectorizedSearchTool", "RedisVectorSearchTool", "RedisHybridSearchTool", "RedisRangeSearchTool", diff --git a/src/google/adk_community/tools/redis/base_search_tool.py b/src/google/adk_community/tools/redis/base_search_tool.py index 24d303f..bdb2170 100644 --- a/src/google/adk_community/tools/redis/base_search_tool.py +++ b/src/google/adk_community/tools/redis/base_search_tool.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Base class for Redis search tools using RedisVL.""" +"""Base classes for Redis search tools using RedisVL.""" from __future__ import annotations from abc import abstractmethod import asyncio from typing import Any +from typing import Callable +from typing import Coroutine from typing import Dict from typing import List from typing import Optional @@ -33,15 +35,16 @@ class BaseRedisSearchTool(BaseTool): - """Base class for Redis search tools using RedisVL. + """Base class for ALL Redis search tools using RedisVL. - This class provides common functionality for all Redis search tools: - - Index and vectorizer management - - Common error handling - - Standard response format + This class provides common functionality shared by all Redis search tools: + - Index management (sync and async) + - Query execution with proper async handling + - Standard response formatting + - Error handling - Subclasses must implement `_build_query()` to create the appropriate - RedisVL query object for their search type. + Subclasses should use `_run_search` to execute queries with consistent + error handling and response formatting. """ def __init__( @@ -50,7 +53,6 @@ def __init__( name: str, description: str, index: Union[SearchIndex, AsyncSearchIndex], - vectorizer: Optional[BaseVectorizer] = None, return_fields: Optional[List[str]] = None, ): """Initialize the base Redis search tool. @@ -59,12 +61,10 @@ def __init__( name: The name of the tool (exposed to LLM). description: The description of the tool (exposed to LLM). index: The RedisVL SearchIndex or AsyncSearchIndex to query. - vectorizer: Optional vectorizer for embedding queries. return_fields: Optional list of fields to return in results. 
""" super().__init__(name=name, description=description) self._index = index - self._vectorizer = vectorizer self._return_fields = return_fields self._is_async_index = isinstance(index, AsyncSearchIndex) @@ -89,6 +89,104 @@ def _get_declaration(self) -> types.FunctionDeclaration: ), ) + async def _execute_query(self, query: Any) -> List[Dict[str, Any]]: + """Execute a RedisVL query and return formatted results. + + Args: + query: A RedisVL query object (VectorQuery, TextQuery, etc.) + + Returns: + List of result dictionaries. + """ + if self._is_async_index: + results = await self._index.query(query) + else: + # Run sync query in thread pool to avoid blocking + results = await asyncio.to_thread(self._index.query, query) + + return [dict(r) for r in results] if results else [] + + async def _run_search( + self, + args: Dict[str, Any], + build_query_fn: Callable[[str, Dict[str, Any]], Coroutine[Any, Any, Any]], + ) -> Dict[str, Any]: + """Execute a search with consistent validation, execution, and formatting. + + This is a template method that handles: + - Query text validation + - Query building (via the provided async function) + - Query execution + - Response formatting + - Error handling + + Args: + args: Arguments from the LLM, must include 'query'. + build_query_fn: Async function that takes (query_text, args) and + returns a RedisVL query object. + + Returns: + A dictionary with status, count, and results (or error). 
+ """ + query_text = args.get("query", "") + + if not query_text: + return {"status": "error", "error": "Query text is required."} + + try: + # Build the query using the provided function + redisvl_query = await build_query_fn(query_text, args) + + # Execute and format results + results = await self._execute_query(redisvl_query) + + return { + "status": "success", + "count": len(results), + "results": results, + } + + except Exception as e: + return {"status": "error", "error": str(e)} + + +class VectorizedSearchTool(BaseRedisSearchTool): + """Base class for Redis search tools that require vector embeddings. + + This class extends BaseRedisSearchTool with: + - Required vectorizer for embedding queries + - Abstract _build_query method for subclasses to implement + + Use this as the base class for vector-based search tools like + VectorSearchTool, HybridSearchTool, and RangeSearchTool. + """ + + def __init__( + self, + *, + name: str, + description: str, + index: Union[SearchIndex, AsyncSearchIndex], + vectorizer: BaseVectorizer, + return_fields: Optional[List[str]] = None, + ): + """Initialize the vectorized search tool. + + Args: + name: The name of the tool (exposed to LLM). + description: The description of the tool (exposed to LLM). + index: The RedisVL SearchIndex or AsyncSearchIndex to query. + vectorizer: The vectorizer for embedding queries (required). + return_fields: Optional list of fields to return in results. + """ + super().__init__( + name=name, + description=description, + index=index, + return_fields=return_fields, + ) + self._vectorizer = vectorizer + @abstractmethod def _build_query( self, query_text: str, embedding: List[float], **kwargs: Any @@ -108,7 +206,7 @@ def _build_query( async def run_async( self, *, args: Dict[str, Any], tool_context: ToolContext ) -> Dict[str, Any]: - """Execute the search query. + """Execute the vector-based search query. Args: args: Arguments from the LLM, must include 'query'. 
@@ -117,39 +215,9 @@ async def run_async( Returns: A dictionary with status, count, and results. """ - query_text = args.get("query", "") - - if not query_text: - return {"status": "error", "error": "Query text is required."} - - try: - # Embed the query text - if self._vectorizer is None: - return { - "status": "error", - "error": "Vectorizer is required for this search type.", - } + async def build_query_fn(query_text: str, args: Dict[str, Any]) -> Any: embedding = await self._vectorizer.aembed(query_text) + return self._build_query(query_text, embedding, **args) - # Build the query (subclass-specific) - redisvl_query = self._build_query(query_text, embedding, **args) - - # Execute the query - handle both sync and async indexes - if self._is_async_index: - results = await self._index.query(redisvl_query) - else: - # Run sync query in thread pool to avoid blocking - results = await asyncio.to_thread(self._index.query, redisvl_query) - - # Format results - formatted_results = [dict(r) for r in results] if results else [] - - return { - "status": "success", - "count": len(formatted_results), - "results": formatted_results, - } - - except Exception as e: - return {"status": "error", "error": str(e)} + return await self._run_search(args, build_query_fn) diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py index 011df4f..df395f7 100644 --- a/src/google/adk_community/tools/redis/hybrid_search_tool.py +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -29,10 +29,10 @@ from redisvl.query import HybridQuery from redisvl.utils.vectorize import BaseVectorizer -from .base_search_tool import BaseRedisSearchTool +from .base_search_tool import VectorizedSearchTool -class RedisHybridSearchTool(BaseRedisSearchTool): +class RedisHybridSearchTool(VectorizedSearchTool): """Hybrid search tool combining vector similarity and BM25 text search. 
This tool performs a hybrid search that combines semantic vector similarity diff --git a/src/google/adk_community/tools/redis/range_search_tool.py b/src/google/adk_community/tools/redis/range_search_tool.py index bb55bdc..179c8a6 100644 --- a/src/google/adk_community/tools/redis/range_search_tool.py +++ b/src/google/adk_community/tools/redis/range_search_tool.py @@ -28,7 +28,7 @@ from redisvl.query import VectorRangeQuery from redisvl.utils.vectorize import BaseVectorizer -from .base_search_tool import BaseRedisSearchTool +from .base_search_tool import VectorizedSearchTool # Type alias for sort specification SortSpec = Optional[ @@ -36,7 +36,7 @@ ] -class RedisRangeSearchTool(BaseRedisSearchTool): +class RedisRangeSearchTool(VectorizedSearchTool): """Vector range search tool using distance threshold. This tool finds all documents within a specified distance threshold diff --git a/src/google/adk_community/tools/redis/text_search_tool.py b/src/google/adk_community/tools/redis/text_search_tool.py index fe5ac9b..ba7c24f 100644 --- a/src/google/adk_community/tools/redis/text_search_tool.py +++ b/src/google/adk_community/tools/redis/text_search_tool.py @@ -16,7 +16,6 @@ from __future__ import annotations -import asyncio from typing import Any from typing import Dict from typing import List @@ -25,20 +24,21 @@ from typing import Tuple from typing import Union -from google.adk.tools import BaseTool from google.adk.tools.tool_context import ToolContext from google.genai import types from redisvl.index import AsyncSearchIndex from redisvl.index import SearchIndex from redisvl.query import TextQuery +from .base_search_tool import BaseRedisSearchTool + # Type alias for sort specification SortSpec = Optional[ Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] ] -class RedisTextSearchTool(BaseTool): +class RedisTextSearchTool(BaseRedisSearchTool): """Full-text search tool using BM25 scoring. This tool performs keyword-based full-text search using BM25 scoring. 
@@ -97,19 +97,21 @@ def __init__( name: The name of the tool (exposed to LLM). description: The description of the tool (exposed to LLM). """ - super().__init__(name=name, description=description) - self._index = index + super().__init__( + name=name, + description=description, + index=index, + return_fields=return_fields, + ) self._text_field_name = text_field_name self._text_scorer = text_scorer self._num_results = num_results - self._return_fields = return_fields self._filter_expression = filter_expression self._return_score = return_score self._dialect = dialect self._sort_by = sort_by self._in_order = in_order self._stopwords = stopwords - self._is_async_index = isinstance(index, AsyncSearchIndex) def _get_declaration(self) -> types.FunctionDeclaration: """Get the function declaration for the LLM.""" @@ -147,15 +149,12 @@ async def run_async( Returns: A dictionary with status, count, and results. """ - query_text = args.get("query", "") - if not query_text: - return {"status": "error", "error": "Query text is required."} - - try: + async def build_query_fn( + query_text: str, args: Dict[str, Any] + ) -> TextQuery: num_results = args.get("num_results", self._num_results) - - text_query = TextQuery( + return TextQuery( text=query_text, text_field_name=self._text_field_name, text_scorer=self._text_scorer, @@ -169,19 +168,4 @@ async def run_async( stopwords=self._stopwords, ) - # Execute the query - handle both sync and async indexes - if self._is_async_index: - results = await self._index.query(text_query) - else: - results = await asyncio.to_thread(self._index.query, text_query) - - formatted_results = [dict(r) for r in results] if results else [] - - return { - "status": "success", - "count": len(formatted_results), - "results": formatted_results, - } - - except Exception as e: - return {"status": "error", "error": str(e)} + return await self._run_search(args, build_query_fn) diff --git a/src/google/adk_community/tools/redis/vector_search_tool.py 
b/src/google/adk_community/tools/redis/vector_search_tool.py index 5a6b1c9..f9471d1 100644 --- a/src/google/adk_community/tools/redis/vector_search_tool.py +++ b/src/google/adk_community/tools/redis/vector_search_tool.py @@ -28,7 +28,7 @@ from redisvl.query import VectorQuery from redisvl.utils.vectorize import BaseVectorizer -from .base_search_tool import BaseRedisSearchTool +from .base_search_tool import VectorizedSearchTool # Type alias for sort specification SortSpec = Optional[ @@ -36,7 +36,7 @@ ] -class RedisVectorSearchTool(BaseRedisSearchTool): +class RedisVectorSearchTool(VectorizedSearchTool): """Vector similarity search tool using RedisVL. This tool performs K-nearest neighbor (KNN) vector similarity search From 8618ceb6edfebe8787689d47574eb8468c882793 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Thu, 8 Jan 2026 16:42:30 +0200 Subject: [PATCH 07/16] style: run autoformat (isort + pyink) --- src/google/adk_community/__init__.py | 1 + src/google/adk_community/memory/__init__.py | 1 - .../memory/open_memory_service.py | 108 +- src/google/adk_community/memory/utils.py | 5 +- .../sessions/redis_session_service.py | 503 ++++---- src/google/adk_community/sessions/utils.py | 28 +- tests/unittests/memory/__init__.py | 1 - .../memory/test_open_memory_service.py | 98 +- .../sessions/test_redis_session_service.py | 1093 +++++++++-------- 9 files changed, 926 insertions(+), 912 deletions(-) diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..98c0e33 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -15,4 +15,5 @@ from . import memory from . import sessions from . 
import version + __version__ = version.__version__ diff --git a/src/google/adk_community/memory/__init__.py b/src/google/adk_community/memory/__init__.py index 1f3442c..fbcf66b 100644 --- a/src/google/adk_community/memory/__init__.py +++ b/src/google/adk_community/memory/__init__.py @@ -21,4 +21,3 @@ "OpenMemoryService", "OpenMemoryServiceConfig", ] - diff --git a/src/google/adk_community/memory/open_memory_service.py b/src/google/adk_community/memory/open_memory_service.py index 92c1ae6..4e949cf 100644 --- a/src/google/adk_community/memory/open_memory_service.py +++ b/src/google/adk_community/memory/open_memory_service.py @@ -19,27 +19,27 @@ from typing import Optional from typing import TYPE_CHECKING -import httpx -from google.genai import types -from pydantic import BaseModel -from pydantic import Field -from typing_extensions import override - from google.adk.memory import _utils from google.adk.memory.base_memory_service import BaseMemoryService from google.adk.memory.base_memory_service import SearchMemoryResponse from google.adk.memory.memory_entry import MemoryEntry +from google.genai import types +import httpx +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override from .utils import extract_text_from_event if TYPE_CHECKING: from google.adk.sessions.session import Session -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) + class OpenMemoryService(BaseMemoryService): """Memory service implementation using OpenMemory. - + See https://openmemory.cavira.app/ for more information. """ @@ -55,7 +55,7 @@ def __init__( base_url: Base URL of the OpenMemory instance (default: http://localhost:3000). api_key: API key for authentication. **Required** - must be provided. config: OpenMemoryServiceConfig instance. If None, uses defaults. - + Raises: ValueError: If api_key is not provided or is empty. 
""" @@ -64,7 +64,7 @@ def __init__( "api_key is required for OpenMemory. " "Provide an API key when initializing OpenMemoryService." ) - self._base_url = base_url.rstrip('/') + self._base_url = base_url.rstrip("/") self._api_key = api_key self._config = config or OpenMemoryServiceConfig() @@ -81,14 +81,12 @@ def _determine_salience(self, author: Optional[str]) -> float: else: return self._config.default_salience - def _prepare_memory_data( - self, event, content_text: str, session - ) -> dict: + def _prepare_memory_data(self, event, content_text: str, session) -> dict: """Prepare memory data structure for OpenMemory API.""" timestamp_str = None if event.timestamp: timestamp_str = _utils.format_timestamp(event.timestamp) - + # Embed author and timestamp in content for search retrieval # Format: [Author: user, Time: 2025-11-04T10:32:01] Content text enriched_content = content_text @@ -97,11 +95,11 @@ def _prepare_memory_data( metadata_parts.append(f"Author: {event.author}") if timestamp_str: metadata_parts.append(f"Time: {timestamp_str}") - + if metadata_parts: metadata_prefix = "[" + ", ".join(metadata_parts) + "] " enriched_content = metadata_prefix + content_text - + metadata = { "app_name": session.app_name, "user_id": session.user_id, @@ -110,13 +108,13 @@ def _prepare_memory_data( "invocation_id": event.invocation_id, "author": event.author, "timestamp": event.timestamp, - "source": "adk_session" + "source": "adk_session", } - + memory_data = { "content": enriched_content, "metadata": metadata, - "salience": self._determine_salience(event.author) + "salience": self._determine_salience(event.author), } if self._config.enable_metadata_tags: @@ -138,7 +136,7 @@ async def add_session_to_memory(self, session: Session): async with httpx.AsyncClient(timeout=self._config.timeout) as http_client: headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}" + "Authorization": f"Bearer {self._api_key}", } for event in session.events: @@ 
-155,16 +153,14 @@ async def add_session_to_memory(self, session: Session): "tags": memory_data.get("tags", []), "metadata": memory_data.get("metadata", {}), "salience": memory_data.get("salience", 0.5), - "user_id": session.user_id + "user_id": session.user_id, } - + response = await http_client.post( - f"{self._base_url}/memory/add", - json=payload, - headers=headers + f"{self._base_url}/memory/add", json=payload, headers=headers ) response.raise_for_status() - + memories_added += 1 logger.debug("Added memory for event %s", event.id) except httpx.HTTPStatusError as e: @@ -176,24 +172,24 @@ async def add_session_to_memory(self, session: Session): ) except httpx.RequestError as e: logger.error( - "Failed to add memory for event %s due to request error: %s", event.id, e + "Failed to add memory for event %s due to request error: %s", + event.id, + e, ) except Exception as e: - logger.error("Failed to add memory for event %s due to unexpected error: %s", event.id, e) + logger.error( + "Failed to add memory for event %s due to unexpected error: %s", + event.id, + e, + ) - logger.info( - "Added %d memories from session %s", memories_added, session.id - ) + logger.info("Added %d memories from session %s", memories_added, session.id) def _build_search_payload( self, app_name: str, user_id: str, query: str ) -> dict: """Build search payload for OpenMemory query API.""" - payload = { - "query": query, - "k": self._config.search_top_k, - "filter": {} - } + payload = {"query": query, "k": self._config.search_top_k, "filter": {}} payload["filter"]["user_id"] = user_id @@ -204,7 +200,7 @@ def _build_search_payload( def _convert_to_memory_entry(self, result: dict) -> Optional[MemoryEntry]: """Convert OpenMemory result to MemoryEntry. 
- + Extracts author and timestamp from enriched content format: [Author: user, Time: 2025-11-04T10:32:01] Content text """ @@ -213,28 +209,24 @@ def _convert_to_memory_entry(self, result: dict) -> Optional[MemoryEntry]: author = None timestamp = None clean_content = raw_content - + # Parse enriched content format to extract metadata - match = re.match(r'^\[([^\]]+)\]\s+(.*)', raw_content, re.DOTALL) + match = re.match(r"^\[([^\]]+)\]\s+(.*)", raw_content, re.DOTALL) if match: metadata_str = match.group(1) clean_content = match.group(2) - - author_match = re.search(r'Author:\s*([^,\]]+)', metadata_str) + + author_match = re.search(r"Author:\s*([^,\]]+)", metadata_str) if author_match: author = author_match.group(1).strip() - - time_match = re.search(r'Time:\s*([^,\]]+)', metadata_str) + + time_match = re.search(r"Time:\s*([^,\]]+)", metadata_str) if time_match: timestamp = time_match.group(1).strip() - + content = types.Content(parts=[types.Part(text=clean_content)]) - return MemoryEntry( - content=content, - author=author, - timestamp=timestamp - ) + return MemoryEntry(content=content, author=author, timestamp=timestamp) except (KeyError, ValueError) as e: logger.debug("Failed to convert result to MemoryEntry: %s", e) return None @@ -247,25 +239,27 @@ async def search_memory( try: search_payload = self._build_search_payload(app_name, user_id, query) memories = [] - + async with httpx.AsyncClient(timeout=self._config.timeout) as http_client: headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}" + "Authorization": f"Bearer {self._api_key}", } - + logger.debug("Query payload: %s", search_payload) - + response = await http_client.post( f"{self._base_url}/memory/query", json=search_payload, - headers=headers + headers=headers, ) response.raise_for_status() result = response.json() - - logger.debug("Query returned %d matches", len(result.get("matches", []))) - + + logger.debug( + "Query returned %d matches", 
len(result.get("matches", [])) + ) + for match in result.get("matches", []): memory_entry = self._convert_to_memory_entry(match) if memory_entry: diff --git a/src/google/adk_community/memory/utils.py b/src/google/adk_community/memory/utils.py index 0b78206..48058ce 100644 --- a/src/google/adk_community/memory/utils.py +++ b/src/google/adk_community/memory/utils.py @@ -33,9 +33,8 @@ def extract_text_from_event(event) -> str: # Filter out thought parts and only extract text # This prevents metadata like thoughtSignature from being stored text_parts = [ - part.text - for part in event.content.parts + part.text + for part in event.content.parts if part.text and not part.thought ] return ' '.join(text_parts) - diff --git a/src/google/adk_community/sessions/redis_session_service.py b/src/google/adk_community/sessions/redis_session_service.py index bd8e289..075b342 100644 --- a/src/google/adk_community/sessions/redis_session_service.py +++ b/src/google/adk_community/sessions/redis_session_service.py @@ -17,23 +17,21 @@ import bisect import logging import time +from typing import Any +from typing import Optional import uuid -from typing import Any, Optional +from google.adk.events.event import Event +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.base_session_service import ListSessionsResponse +from google.adk.sessions.session import Session +from google.adk.sessions.state import State import orjson import redis.asyncio as redis from redis.crc import key_slot from typing_extensions import override -from google.adk.events.event import Event -from google.adk.sessions.base_session_service import ( - BaseSessionService, - GetSessionConfig, - ListSessionsResponse, -) -from google.adk.sessions.session import Session -from google.adk.sessions.state import State - from .utils import _json_serializer logger = logging.getLogger("google_adk." 
+ __name__) @@ -42,256 +40,261 @@ def _session_serializer(obj: Session) -> bytes: - """Serialize ADK Session to JSON bytes.""" - return orjson.dumps(obj.model_dump(), default=_json_serializer) + """Serialize ADK Session to JSON bytes.""" + return orjson.dumps(obj.model_dump(), default=_json_serializer) class RedisKeys: - """Helper to generate Redis keys consistently.""" + """Helper to generate Redis keys consistently.""" - @staticmethod - def session(session_id: str) -> str: - return f"session:{session_id}" + @staticmethod + def session(session_id: str) -> str: + return f"session:{session_id}" - @staticmethod - def user_sessions(app_name: str, user_id: str) -> str: - return f"{State.APP_PREFIX}:{app_name}:{user_id}" + @staticmethod + def user_sessions(app_name: str, user_id: str) -> str: + return f"{State.APP_PREFIX}:{app_name}:{user_id}" - @staticmethod - def app_state(app_name: str) -> str: - return f"{State.APP_PREFIX}{app_name}" + @staticmethod + def app_state(app_name: str) -> str: + return f"{State.APP_PREFIX}{app_name}" - @staticmethod - def user_state(app_name: str, user_id: str) -> str: - return f"{State.USER_PREFIX}{app_name}:{user_id}" + @staticmethod + def user_state(app_name: str, user_id: str) -> str: + return f"{State.USER_PREFIX}{app_name}:{user_id}" class RedisSessionService(BaseSessionService): - """A Redis-backed implementation of the session service.""" - - def __init__( - self, - host="localhost", - port=6379, - db=0, - uri=None, - cluster_uri=None, - expire=DEFAULT_EXPIRATION, - **kwargs, - ): - self.expire = expire - - if cluster_uri: - self.cache = redis.RedisCluster.from_url(cluster_uri, **kwargs) - elif uri: - self.cache = redis.Redis.from_url(uri, **kwargs) - else: - self.cache = redis.Redis(host=host, port=port, db=db, **kwargs) - - async def health_check(self) -> bool: - try: - await self.cache.ping() - return True - except redis.RedisError: - return False - - @override - async def create_session( - self, - *, - app_name: str, - 
user_id: str, - state: Optional[dict[str, Any]] = None, - session_id: Optional[str] = None, - ) -> Session: - session_id = ( - session_id.strip() - if session_id and session_id.strip() - else str(uuid.uuid4()) - ) - session = Session( - app_name=app_name, - user_id=user_id, - id=session_id, - state=state or {}, - last_update_time=time.time(), - ) - - user_sessions_key = RedisKeys.user_sessions(app_name, user_id) - session_key = RedisKeys.session(session_id) - - async with self.cache.pipeline(transaction=False) as pipe: - pipe.sadd(user_sessions_key, session_id) - pipe.expire(user_sessions_key, self.expire) - pipe.set( - session_key, - _session_serializer(session), - ex=self.expire, + """A Redis-backed implementation of the session service.""" + + def __init__( + self, + host="localhost", + port=6379, + db=0, + uri=None, + cluster_uri=None, + expire=DEFAULT_EXPIRATION, + **kwargs, + ): + self.expire = expire + + if cluster_uri: + self.cache = redis.RedisCluster.from_url(cluster_uri, **kwargs) + elif uri: + self.cache = redis.Redis.from_url(uri, **kwargs) + else: + self.cache = redis.Redis(host=host, port=port, db=db, **kwargs) + + async def health_check(self) -> bool: + try: + await self.cache.ping() + return True + except redis.RedisError: + return False + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else str(uuid.uuid4()) + ) + session = Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=state or {}, + last_update_time=time.time(), + ) + + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + session_key = RedisKeys.session(session_id) + + async with self.cache.pipeline(transaction=False) as pipe: + pipe.sadd(user_sessions_key, session_id) + pipe.expire(user_sessions_key, self.expire) + pipe.set( + session_key, + 
_session_serializer(session), + ex=self.expire, + ) + await pipe.execute() + + return await self._merge_state(app_name, user_id, session) + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + session_key = RedisKeys.session(session_id) + raw_session = await self.cache.get(session_key) + if not raw_session: + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + await self.cache.srem(user_sessions_key, session_id) + return None + + try: + session_dict = orjson.loads(raw_session) + session = Session.model_validate(session_dict) + except (orjson.JSONDecodeError, Exception) as e: + logger.error(f"Error decoding session {session_id}: {e}") + return None + + if config: + if config.num_recent_events: + session.events = session.events[-config.num_recent_events :] + if config.after_timestamp: + timestamps = [e.timestamp for e in session.events] + start_index = bisect.bisect_left(timestamps, config.after_timestamp) + session.events = session.events[start_index:] + + return await self._merge_state(app_name, user_id, session) + + @override + async def list_sessions( + self, *, app_name: str, user_id: str + ) -> ListSessionsResponse: + sessions = await self._load_sessions(app_name, user_id) + sessions_without_events = [] + + for session_data in sessions.values(): + session = Session.model_validate(session_data) + session.events = [] + session.state = {} + sessions_without_events.append(session) + + return ListSessionsResponse(sessions=sessions_without_events) + + @override + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + user_sessions_key = RedisKeys.user_sessions(app_name, user_id) + session_key = RedisKeys.session(session_id) + + async with self.cache.pipeline(transaction=False) as pipe: + pipe.srem(user_sessions_key, session_id) + pipe.delete(session_key) + await pipe.execute() + + @override + 
async def append_event(self, session: Session, event: Event) -> Event: + await super().append_event(session=session, event=event) + session.last_update_time = event.timestamp + + async with self.cache.pipeline(transaction=False) as pipe: + user_sessions_key = RedisKeys.user_sessions( + session.app_name, session.user_id + ) + pipe.expire(user_sessions_key, self.expire) + + if event.actions and event.actions.state_delta: + for key, value in event.actions.state_delta.items(): + if key.startswith(State.APP_PREFIX): + pipe.hset( + RedisKeys.app_state(session.app_name), + key.removeprefix(State.APP_PREFIX), + orjson.dumps(value), ) - await pipe.execute() - - return await self._merge_state(app_name, user_id, session) - - @override - async def get_session( - self, - *, - app_name: str, - user_id: str, - session_id: str, - config: Optional[GetSessionConfig] = None, - ) -> Optional[Session]: - session_key = RedisKeys.session(session_id) - raw_session = await self.cache.get(session_key) - if not raw_session: - user_sessions_key = RedisKeys.user_sessions(app_name, user_id) - await self.cache.srem(user_sessions_key, session_id) - return None - - try: - session_dict = orjson.loads(raw_session) - session = Session.model_validate(session_dict) - except (orjson.JSONDecodeError, Exception) as e: - logger.error(f"Error decoding session {session_id}: {e}") - return None - - if config: - if config.num_recent_events: - session.events = session.events[-config.num_recent_events :] - if config.after_timestamp: - timestamps = [e.timestamp for e in session.events] - start_index = bisect.bisect_left(timestamps, config.after_timestamp) - session.events = session.events[start_index:] - - return await self._merge_state(app_name, user_id, session) - - @override - async def list_sessions( - self, *, app_name: str, user_id: str - ) -> ListSessionsResponse: - sessions = await self._load_sessions(app_name, user_id) - sessions_without_events = [] - - for session_data in sessions.values(): - session = 
Session.model_validate(session_data) - session.events = [] - session.state = {} - sessions_without_events.append(session) - - return ListSessionsResponse(sessions=sessions_without_events) - - @override - async def delete_session( - self, *, app_name: str, user_id: str, session_id: str - ) -> None: - user_sessions_key = RedisKeys.user_sessions(app_name, user_id) - session_key = RedisKeys.session(session_id) - - async with self.cache.pipeline(transaction=False) as pipe: - pipe.srem(user_sessions_key, session_id) - pipe.delete(session_key) - await pipe.execute() - - @override - async def append_event(self, session: Session, event: Event) -> Event: - await super().append_event(session=session, event=event) - session.last_update_time = event.timestamp - - async with self.cache.pipeline(transaction=False) as pipe: - user_sessions_key = RedisKeys.user_sessions( - session.app_name, session.user_id - ) - pipe.expire(user_sessions_key, self.expire) - - if event.actions and event.actions.state_delta: - for key, value in event.actions.state_delta.items(): - if key.startswith(State.APP_PREFIX): - pipe.hset( - RedisKeys.app_state(session.app_name), - key.removeprefix(State.APP_PREFIX), - orjson.dumps(value), - ) - if key.startswith(State.USER_PREFIX): - pipe.hset( - RedisKeys.user_state(session.app_name, session.user_id), - key.removeprefix(State.USER_PREFIX), - orjson.dumps(value), - ) - - pipe.set( - RedisKeys.session(session.id), - _session_serializer(session), - ex=self.expire, - ) - await pipe.execute() - - return event - - async def _merge_state( - self, app_name: str, user_id: str, session: Session - ) -> Session: - app_state = await self.cache.hgetall(RedisKeys.app_state(app_name)) - for k, v in app_state.items(): - session.state[State.APP_PREFIX + k.decode()] = orjson.loads(v) - - user_state = await self.cache.hgetall(RedisKeys.user_state(app_name, user_id)) - for k, v in user_state.items(): - session.state[State.USER_PREFIX + k.decode()] = orjson.loads(v) - - return 
session - - async def _load_sessions(self, app_name: str, user_id: str) -> dict[str, dict]: - key = RedisKeys.user_sessions(app_name, user_id) - try: - session_ids_bytes = await self.cache.smembers(key) - if not session_ids_bytes: - return {} - - session_ids = [s.decode() for s in session_ids_bytes] - session_keys = [RedisKeys.session(sid) for sid in session_ids] - - # Group by slot for Redis Cluster - slot_groups: dict[int, list[str]] = {} - for k in session_keys: - slot = key_slot(k.encode()) - slot_groups.setdefault(slot, []).append(k) - - async def fetch_group(keys: list[str]): - async with self.cache.pipeline(transaction=False) as pipe: - for k in keys: - pipe.get(k) - return await pipe.execute() - - results_per_group = await asyncio.gather( - *(fetch_group(keys) for keys in slot_groups.values()) + if key.startswith(State.USER_PREFIX): + pipe.hset( + RedisKeys.user_state(session.app_name, session.user_id), + key.removeprefix(State.USER_PREFIX), + orjson.dumps(value), ) - raw_sessions = [] - for group_keys, group_results in zip( - slot_groups.values(), results_per_group - ): - raw_sessions.extend(zip(group_keys, group_results)) - - sessions = {} - sessions_to_cleanup = [] - for key_name, raw_session in raw_sessions: - session_id = key_name.split(":", 1)[1] - if raw_session: - try: - sessions[session_id] = orjson.loads(raw_session) - except orjson.JSONDecodeError as e: - logger.error(f"Error decoding session {session_id}: {e}") - else: - logger.warning( - "Session ID %s found in user set but session data is missing. 
Cleaning up.", - session_id, - ) - sessions_to_cleanup.append(session_id) - - if sessions_to_cleanup: - await self.cache.srem(key, *sessions_to_cleanup) - - return sessions - except redis.RedisError as e: - logger.error(f"Error loading sessions for {user_id}: {e}") - return {} + pipe.set( + RedisKeys.session(session.id), + _session_serializer(session), + ex=self.expire, + ) + await pipe.execute() + + return event + + async def _merge_state( + self, app_name: str, user_id: str, session: Session + ) -> Session: + app_state = await self.cache.hgetall(RedisKeys.app_state(app_name)) + for k, v in app_state.items(): + session.state[State.APP_PREFIX + k.decode()] = orjson.loads(v) + + user_state = await self.cache.hgetall( + RedisKeys.user_state(app_name, user_id) + ) + for k, v in user_state.items(): + session.state[State.USER_PREFIX + k.decode()] = orjson.loads(v) + + return session + + async def _load_sessions( + self, app_name: str, user_id: str + ) -> dict[str, dict]: + key = RedisKeys.user_sessions(app_name, user_id) + try: + session_ids_bytes = await self.cache.smembers(key) + if not session_ids_bytes: + return {} + + session_ids = [s.decode() for s in session_ids_bytes] + session_keys = [RedisKeys.session(sid) for sid in session_ids] + + # Group by slot for Redis Cluster + slot_groups: dict[int, list[str]] = {} + for k in session_keys: + slot = key_slot(k.encode()) + slot_groups.setdefault(slot, []).append(k) + + async def fetch_group(keys: list[str]): + async with self.cache.pipeline(transaction=False) as pipe: + for k in keys: + pipe.get(k) + return await pipe.execute() + + results_per_group = await asyncio.gather( + *(fetch_group(keys) for keys in slot_groups.values()) + ) + + raw_sessions = [] + for group_keys, group_results in zip( + slot_groups.values(), results_per_group + ): + raw_sessions.extend(zip(group_keys, group_results)) + + sessions = {} + sessions_to_cleanup = [] + for key_name, raw_session in raw_sessions: + session_id = key_name.split(":", 1)[1] 
+ if raw_session: + try: + sessions[session_id] = orjson.loads(raw_session) + except orjson.JSONDecodeError as e: + logger.error(f"Error decoding session {session_id}: {e}") + else: + logger.warning( + "Session ID %s found in user set but session data is missing." + " Cleaning up.", + session_id, + ) + sessions_to_cleanup.append(session_id) + + if sessions_to_cleanup: + await self.cache.srem(key, *sessions_to_cleanup) + + return sessions + except redis.RedisError as e: + logger.error(f"Error loading sessions for {user_id}: {e}") + return {} diff --git a/src/google/adk_community/sessions/utils.py b/src/google/adk_community/sessions/utils.py index bc53d2b..132c773 100644 --- a/src/google/adk_community/sessions/utils.py +++ b/src/google/adk_community/sessions/utils.py @@ -20,18 +20,18 @@ def _json_serializer(obj): - """Fallback serializer to handle non-JSON-compatible types.""" - if isinstance(obj, set): - return list(obj) - if isinstance(obj, bytes): - try: - return base64.b64encode(obj).decode("ascii") - except Exception: - return repr(obj) - if isinstance(obj, (datetime.datetime, datetime.date)): - return obj.isoformat() - if isinstance(obj, uuid.UUID): - return str(obj) - if isinstance(obj, Decimal): - return float(obj) + """Fallback serializer to handle non-JSON-compatible types.""" + if isinstance(obj, set): + return list(obj) + if isinstance(obj, bytes): + try: + return base64.b64encode(obj).decode("ascii") + except Exception: + return repr(obj) + if isinstance(obj, (datetime.datetime, datetime.date)): + return obj.isoformat() + if isinstance(obj, uuid.UUID): return str(obj) + if isinstance(obj, Decimal): + return float(obj) + return str(obj) diff --git a/tests/unittests/memory/__init__.py b/tests/unittests/memory/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/memory/__init__.py +++ b/tests/unittests/memory/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/memory/test_open_memory_service.py b/tests/unittests/memory/test_open_memory_service.py index 74cb05e..718f6c6 100644 --- a/tests/unittests/memory/test_open_memory_service.py +++ b/tests/unittests/memory/test_open_memory_service.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock +from unittest.mock import MagicMock from unittest.mock import patch from google.adk.events.event import Event -from google.adk_community.memory.open_memory_service import ( - OpenMemoryService, - OpenMemoryServiceConfig, -) from google.adk.sessions.session import Session from google.genai import types import pytest +from google.adk_community.memory.open_memory_service import OpenMemoryService +from google.adk_community.memory.open_memory_service import OpenMemoryServiceConfig + MOCK_APP_NAME = 'test-app' MOCK_USER_ID = 'test-user' MOCK_SESSION_ID = 'session-1' @@ -39,7 +39,9 @@ invocation_id='inv-1', author='user', timestamp=12345, - content=types.Content(parts=[types.Part(text='Hello, I like Python.')]), + content=types.Content( + parts=[types.Part(text='Hello, I like Python.')] + ), ), Event( id='event-2', @@ -47,7 +49,9 @@ author='model', timestamp=12346, content=types.Content( - parts=[types.Part(text='Python is a great programming language.')] + parts=[ + types.Part(text='Python is a great programming language.') + ] ), ), # Empty event, should be ignored @@ -85,10 +89,12 @@ @pytest.fixture def mock_httpx_client(): """Mock httpx.AsyncClient for testing.""" - with patch('google.adk_community.memory.open_memory_service.httpx.AsyncClient') as mock_client_class: + with patch( + 'google.adk_community.memory.open_memory_service.httpx.AsyncClient' + ) as mock_client_class: mock_client = MagicMock() 
mock_response = MagicMock() - mock_response.json.return_value = {"matches": []} + mock_response.json.return_value = {'matches': []} mock_response.raise_for_status = MagicMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) @@ -107,14 +113,10 @@ def memory_service(mock_httpx_client): def memory_service_with_config(mock_httpx_client): """Create OpenMemoryService with custom config.""" config = OpenMemoryServiceConfig( - search_top_k=5, - user_content_salience=0.9, - model_content_salience=0.6 + search_top_k=5, user_content_salience=0.9, model_content_salience=0.6 ) return OpenMemoryService( - base_url='http://localhost:3000', - api_key='test-key', - config=config + base_url='http://localhost:3000', api_key='test-key', config=config ) @@ -139,7 +141,7 @@ def test_custom_config(self): user_content_salience=0.9, model_content_salience=0.75, default_salience=0.5, - enable_metadata_tags=False + enable_metadata_tags=False, ) assert config.search_top_k == 20 assert config.timeout == 10.0 @@ -158,18 +160,20 @@ def test_config_validation_search_top_k(self): def test_api_key_required(self): """Test that API key is required.""" - with pytest.raises(ValueError, match="api_key is required"): - OpenMemoryService(base_url="http://localhost:3000", api_key="") - - with pytest.raises(ValueError, match="api_key is required"): - OpenMemoryService(base_url="http://localhost:3000") + with pytest.raises(ValueError, match='api_key is required'): + OpenMemoryService(base_url='http://localhost:3000', api_key='') + + with pytest.raises(ValueError, match='api_key is required'): + OpenMemoryService(base_url='http://localhost:3000') class TestOpenMemoryService: """Tests for OpenMemoryService.""" @pytest.mark.asyncio - async def test_add_session_to_memory_success(self, memory_service, mock_httpx_client): + async def test_add_session_to_memory_success( + self, memory_service, mock_httpx_client + ): """Test successful addition 
of session memories.""" await memory_service.add_session_to_memory(MOCK_SESSION) @@ -220,9 +224,7 @@ async def test_add_session_uses_config_salience( assert request_data['salience'] == 0.6 # Custom model salience @pytest.mark.asyncio - async def test_add_session_without_metadata_tags( - self, mock_httpx_client - ): + async def test_add_session_without_metadata_tags(self, mock_httpx_client): """Test adding memories without metadata tags.""" config = OpenMemoryServiceConfig(enable_metadata_tags=False) memory_service = OpenMemoryService( @@ -236,7 +238,9 @@ async def test_add_session_without_metadata_tags( assert request_data.get('tags', []) == [] @pytest.mark.asyncio - async def test_add_session_error_handling(self, memory_service, mock_httpx_client): + async def test_add_session_error_handling( + self, memory_service, mock_httpx_client + ): """Test error handling during memory addition.""" mock_httpx_client.post.side_effect = Exception('API Error') @@ -254,20 +258,23 @@ async def test_search_memory_success(self, memory_service, mock_httpx_client): mock_response.json.return_value = { 'matches': [ { - 'content': '[Author: user, Time: 2025-01-01T00:00:00] Python is great', + 'content': ( + '[Author: user, Time: 2025-01-01T00:00:00] Python is great' + ), }, { - 'content': '[Author: model, Time: 2025-01-01T00:01:00] I like programming', - } + 'content': ( + '[Author: model, Time: 2025-01-01T00:01:00] I like' + ' programming' + ), + }, ] } mock_response.raise_for_status = MagicMock() mock_httpx_client.post = AsyncMock(return_value=mock_response) result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='Python programming' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='Python programming' ) # Verify API call @@ -276,7 +283,7 @@ async def test_search_memory_success(self, memory_service, mock_httpx_client): assert request_data['query'] == 'Python programming' assert request_data['k'] == 10 assert 
request_data['filter']['user_id'] == MOCK_USER_ID - assert f"app:{MOCK_APP_NAME}" in request_data['filter']['tags'] + assert f'app:{MOCK_APP_NAME}' in request_data['filter']['tags'] # Verify results (content should be cleaned of metadata prefix) assert len(result.memories) == 2 @@ -293,26 +300,24 @@ async def test_search_memory_applies_filters( # Mock response - server-side filtering ensures only matching results mock_response = MagicMock() mock_response.json.return_value = { - 'matches': [ - { - 'content': '[Author: model, Time: 2025-01-01T00:01:00] I like programming', - } - ] + 'matches': [{ + 'content': ( + '[Author: model, Time: 2025-01-01T00:01:00] I like programming' + ), + }] } mock_response.raise_for_status = MagicMock() mock_httpx_client.post = AsyncMock(return_value=mock_response) result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='test query' ) # Verify filters were passed correctly call_args = mock_httpx_client.post.call_args request_data = call_args.kwargs['json'] assert request_data['filter']['user_id'] == MOCK_USER_ID - assert f"app:{MOCK_APP_NAME}" in request_data['filter']['tags'] + assert f'app:{MOCK_APP_NAME}' in request_data['filter']['tags'] # Should return filtered results assert len(result.memories) == 1 @@ -329,9 +334,7 @@ async def test_search_memory_respects_top_k( mock_httpx_client.post = AsyncMock(return_value=mock_response) await memory_service_with_config.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='test query' ) call_args = mock_httpx_client.post.call_args @@ -346,11 +349,8 @@ async def test_search_memory_error_handling( mock_httpx_client.post.side_effect = Exception('API Error') result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, 
user_id=MOCK_USER_ID, query='test query' ) # Should return empty results on error assert len(result.memories) == 0 - diff --git a/tests/unittests/sessions/test_redis_session_service.py b/tests/unittests/sessions/test_redis_session_service.py index dad0867..2354561 100644 --- a/tests/unittests/sessions/test_redis_session_service.py +++ b/tests/unittests/sessions/test_redis_session_service.py @@ -12,549 +12,568 @@ # See the License for the specific language governing permissions and # limitations under the License. -import orjson -from datetime import datetime, timezone -import pytest -import pytest_asyncio -from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime +from datetime import timezone +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.sessions.base_session_service import GetSessionConfig -from google.adk_community.sessions.redis_session_service import RedisSessionService from google.genai import types +import orjson +import pytest +import pytest_asyncio +from google.adk_community.sessions.redis_session_service import RedisSessionService -class TestRedisSessionService: - """Test cases for RedisSessionService.""" - - @pytest_asyncio.fixture - async def redis_service(self): - """Create a Redis session service for testing.""" - with patch("redis.asyncio.Redis") as mock_redis: - mock_client = AsyncMock() - mock_redis.return_value = mock_client - service = RedisSessionService() - service.cache = mock_client - yield service - - @pytest_asyncio.fixture - async def redis_cluster_service(self): - """Create a Redis cluster session service for testing.""" - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - cluster_uri = "redis://redis-node1:6379" - service = 
RedisSessionService(cluster_uri=cluster_uri) - service.cache = mock_client - yield service - - @pytest_asyncio.fixture - async def redis_cluster_uri_service(self): - """Create a Redis cluster session service using URI for testing.""" - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - cluster_uri = "redis://node1:6379,node2:6379" - service = RedisSessionService(cluster_uri=cluster_uri) - service.cache = mock_client - yield service - - def _setup_redis_mocks(self, redis_service, sessions_data=None): - """Helper to set up Redis mocks for the new storage strategy.""" - if sessions_data is None: - sessions_data = {} - - session_ids = list(sessions_data.keys()) - redis_service.cache.smembers = AsyncMock( - return_value={sid.encode() for sid in session_ids} - ) - - # Mock the new cluster-aware pipeline approach - session_values = [ - orjson.dumps(sessions_data[sid]) if sid in sessions_data else None - for sid in session_ids - ] - # For backward compatibility with mget approach (still used in some tests) - redis_service.cache.mget = AsyncMock(return_value=session_values) - - # Mock pipeline for the new cluster approach - if session_ids: - # Group sessions as the actual implementation does - results_per_group = [] - for i in range(len(session_ids)): - results_per_group.append([session_values[i]]) - - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(side_effect=results_per_group) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - else: - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(return_value=[]) - 
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - - redis_service.cache.srem = AsyncMock() - redis_service.cache.get = AsyncMock(return_value=None) # Default to no session - - # Additional pipeline operations for create/update operations - if not session_ids: - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.set = MagicMock(return_value=mock_pipe) # Allow chaining - mock_pipe.sadd = MagicMock(return_value=mock_pipe) - mock_pipe.expire = MagicMock(return_value=mock_pipe) - mock_pipe.delete = MagicMock(return_value=mock_pipe) - mock_pipe.srem = MagicMock(return_value=mock_pipe) - mock_pipe.hset = MagicMock(return_value=mock_pipe) - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(return_value=[]) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - - redis_service.cache.hgetall = AsyncMock(return_value={}) - redis_service.cache.hset = AsyncMock() - - @pytest.mark.asyncio - async def test_get_empty_session(self, redis_service): - """Test getting a non-existent session.""" - self._setup_redis_mocks(redis_service) - - session = await redis_service.get_session( - app_name="test_app", user_id="test_user", session_id="nonexistent" - ) - - assert session is None - - @pytest.mark.asyncio - async def test_create_get_session(self, redis_service): - """Test session creation and retrieval.""" - app_name = "test_app" - user_id = "test_user" - state = {"key": "value"} - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session( - app_name=app_name, user_id=user_id, state=state - ) - - assert session.app_name == app_name - assert session.user_id == user_id - assert session.id 
is not None - assert session.state == state - - # Allow tiny float/clock rounding differences (~1ms) - assert ( - session.last_update_time - <= datetime.now().astimezone(timezone.utc).timestamp() + 0.001 - ) - - # Mock individual session retrieval - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - got_session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert got_session.app_name == session.app_name - assert got_session.user_id == session.user_id - assert got_session.id == session.id - assert got_session.state == session.state - - @pytest.mark.asyncio - async def test_create_and_list_sessions(self, redis_service): - """Test creating multiple sessions and listing them. - - list_sessions() is expected to return lightweight session summaries, - i.e., with events and state stripped for performance. - """ - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session_ids = ["session" + str(i) for i in range(3)] - sessions_data = {} - - for i, session_id in enumerate(session_ids): - session = await redis_service.create_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - state={"key": "value" + session_id}, - ) - # Add at least one event to ensure list_sessions actually strips them. 
- session.events.append(Event(author="user", timestamp=float(i + 1))) - sessions_data[session_id] = session.model_dump() - - # Now mock Redis to return those sessions (with events present in storage) - self._setup_redis_mocks(redis_service, sessions_data) - - list_sessions_response = await redis_service.list_sessions( - app_name=app_name, user_id=user_id - ) - sessions = list_sessions_response.sessions - - assert len(sessions) == len(session_ids) - returned_session_ids = {s.id for s in sessions} - assert returned_session_ids == set(session_ids) - - for s in sessions: - # list_sessions returns summaries: events and state removed for perf. - assert len(s.events) == 0 - assert s.state == {} - - @pytest.mark.asyncio - async def test_session_state_management(self, redis_service): - """Test session state management with app, user, and temp state.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - state={"initial_key": "initial_value"}, - ) - - event = Event( - invocation_id="invocation", - author="user", - content=types.Content(role="user", parts=[types.Part(text="text")]), - actions=EventActions( - state_delta={ - "app:key": "app_value", - "user:key1": "user_value", - "temp:key": "temp_value", - "initial_key": "updated_value", - } +class TestRedisSessionService: + """Test cases for RedisSessionService.""" + + @pytest_asyncio.fixture + async def redis_service(self): + """Create a Redis session service for testing.""" + with patch("redis.asyncio.Redis") as mock_redis: + mock_client = AsyncMock() + mock_redis.return_value = mock_client + service = RedisSessionService() + service.cache = mock_client + yield service + + @pytest_asyncio.fixture + async def redis_cluster_service(self): + """Create a Redis cluster session service for testing.""" + with 
patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + cluster_uri = "redis://redis-node1:6379" + service = RedisSessionService(cluster_uri=cluster_uri) + service.cache = mock_client + yield service + + @pytest_asyncio.fixture + async def redis_cluster_uri_service(self): + """Create a Redis cluster session service using URI for testing.""" + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + cluster_uri = "redis://node1:6379,node2:6379" + service = RedisSessionService(cluster_uri=cluster_uri) + service.cache = mock_client + yield service + + def _setup_redis_mocks(self, redis_service, sessions_data=None): + """Helper to set up Redis mocks for the new storage strategy.""" + if sessions_data is None: + sessions_data = {} + + session_ids = list(sessions_data.keys()) + redis_service.cache.smembers = AsyncMock( + return_value={sid.encode() for sid in session_ids} + ) + + # Mock the new cluster-aware pipeline approach + session_values = [ + orjson.dumps(sessions_data[sid]) if sid in sessions_data else None + for sid in session_ids + ] + + # For backward compatibility with mget approach (still used in some tests) + redis_service.cache.mget = AsyncMock(return_value=session_values) + + # Mock pipeline for the new cluster approach + if session_ids: + # Group sessions as the actual implementation does + results_per_group = [] + for i in range(len(session_ids)): + results_per_group.append([session_values[i]]) + + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(side_effect=results_per_group) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + 
return_value=mock_context_manager + ) + else: + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(return_value=[]) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + + redis_service.cache.srem = AsyncMock() + redis_service.cache.get = AsyncMock( + return_value=None + ) # Default to no session + + # Additional pipeline operations for create/update operations + if not session_ids: + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.set = MagicMock(return_value=mock_pipe) # Allow chaining + mock_pipe.sadd = MagicMock(return_value=mock_pipe) + mock_pipe.expire = MagicMock(return_value=mock_pipe) + mock_pipe.delete = MagicMock(return_value=mock_pipe) + mock_pipe.srem = MagicMock(return_value=mock_pipe) + mock_pipe.hset = MagicMock(return_value=mock_pipe) + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(return_value=[]) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + + redis_service.cache.hgetall = AsyncMock(return_value={}) + redis_service.cache.hset = AsyncMock() + + @pytest.mark.asyncio + async def test_get_empty_session(self, redis_service): + """Test getting a non-existent session.""" + self._setup_redis_mocks(redis_service) + + session = await redis_service.get_session( + app_name="test_app", user_id="test_user", session_id="nonexistent" + ) + + assert session is None + + @pytest.mark.asyncio + async def test_create_get_session(self, redis_service): + """Test session creation and retrieval.""" + app_name = "test_app" + user_id = "test_user" + state = {"key": "value"} + + 
self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id is not None + assert session.state == state + + # Allow tiny float/clock rounding differences (~1ms) + assert ( + session.last_update_time + <= datetime.now().astimezone(timezone.utc).timestamp() + 0.001 + ) + + # Mock individual session retrieval + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + got_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert got_session.app_name == session.app_name + assert got_session.user_id == session.user_id + assert got_session.id == session.id + assert got_session.state == session.state + + @pytest.mark.asyncio + async def test_create_and_list_sessions(self, redis_service): + """Test creating multiple sessions and listing them. + + list_sessions() is expected to return lightweight session summaries, + i.e., with events and state stripped for performance. + """ + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session_ids = ["session" + str(i) for i in range(3)] + sessions_data = {} + + for i, session_id in enumerate(session_ids): + session = await redis_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={"key": "value" + session_id}, + ) + # Add at least one event to ensure list_sessions actually strips them. 
+ session.events.append(Event(author="user", timestamp=float(i + 1))) + sessions_data[session_id] = session.model_dump() + + # Now mock Redis to return those sessions (with events present in storage) + self._setup_redis_mocks(redis_service, sessions_data) + + list_sessions_response = await redis_service.list_sessions( + app_name=app_name, user_id=user_id + ) + sessions = list_sessions_response.sessions + + assert len(sessions) == len(session_ids) + returned_session_ids = {s.id for s in sessions} + assert returned_session_ids == set(session_ids) + + for s in sessions: + # list_sessions returns summaries: events and state removed for perf. + assert len(s.events) == 0 + assert s.state == {} + + @pytest.mark.asyncio + async def test_session_state_management(self, redis_service): + """Test session state management with app, user, and temp state.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={"initial_key": "initial_value"}, + ) + + event = Event( + invocation_id="invocation", + author="user", + content=types.Content(role="user", parts=[types.Part(text="text")]), + actions=EventActions( + state_delta={ + "app:key": "app_value", + "user:key1": "user_value", + "temp:key": "temp_value", + "initial_key": "updated_value", + } + ), + ) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + await redis_service.append_event(session=session, event=event) + + assert session.state.get("app:key") == "app_value" + assert session.state.get("user:key1") == "user_value" + assert session.state.get("initial_key") == "updated_value" + assert session.state.get("temp:key") is None # Temp state filtered + + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.hset.assert_any_call( + 
"app:test_app", "key", orjson.dumps("app_value") + ) + pipe_mock.hset.assert_any_call( + "user:test_app:test_user", "key1", orjson.dumps("user_value") + ) + + @pytest.mark.asyncio + async def test_append_event_with_bytes(self, redis_service): + """Test appending events with binary content and serialization roundtrip.""" + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id + ) + + test_content = types.Content( + role="user", + parts=[ + types.Part.from_bytes( + data=b"test_image_data", mime_type="image/png" ), - ) - - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - await redis_service.append_event(session=session, event=event) - - assert session.state.get("app:key") == "app_value" - assert session.state.get("user:key1") == "user_value" - assert session.state.get("initial_key") == "updated_value" - assert session.state.get("temp:key") is None # Temp state filtered - - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.hset.assert_any_call("app:test_app", "key", orjson.dumps("app_value")) - pipe_mock.hset.assert_any_call( - "user:test_app:test_user", "key1", orjson.dumps("user_value") - ) - - @pytest.mark.asyncio - async def test_append_event_with_bytes(self, redis_service): - """Test appending events with binary content and serialization roundtrip.""" - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session(app_name=app_name, user_id=user_id) - - test_content = types.Content( - role="user", - parts=[ - types.Part.from_bytes( - data=b"test_image_data", mime_type="image/png" - ), - ], - ) - test_grounding_metadata = types.GroundingMetadata( - search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") - ) - event = Event( - 
invocation_id="invocation", - author="user", - content=test_content, - grounding_metadata=test_grounding_metadata, - ) - - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - await redis_service.append_event(session=session, event=event) - - # Verify the event was appended to in-memory session - assert len(session.events) == 1 - assert session.events[0].content == test_content - assert session.events[0].grounding_metadata == test_grounding_metadata - - # Test serialization/deserialization roundtrip to ensure binary data is preserved - # Simulate what happens when session is stored and retrieved from Redis - serialized_session = session.model_dump_json() - - redis_service.cache.get = AsyncMock(return_value=serialized_session.encode()) - - retrieved_session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert retrieved_session is not None - assert len(retrieved_session.events) == 1 - - # Verify the binary content was preserved through serialization - retrieved_event = retrieved_session.events[0] - assert retrieved_event.content.parts[0].inline_data.data == b"test_image_data" - assert ( - retrieved_event.content.parts[0].inline_data.mime_type - == "image/png" - ) - assert ( - retrieved_event.grounding_metadata.search_entry_point.sdk_blob - == b"test_sdk_blob" - ) - - @pytest.mark.asyncio - async def test_get_session_with_config(self, redis_service): - """Test getting session with configuration filters.""" - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session(app_name=app_name, user_id=user_id) - - # Add multiple events with different timestamps - num_test_events = 5 - for i in range(1, num_test_events + 1): - event = Event(author="user", timestamp=float(i)) - session.events.append(event) - - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - 
# Test num_recent_events filter - config = GetSessionConfig(num_recent_events=3) - filtered_session = await redis_service.get_session( - app_name=app_name, - user_id=user_id, - session_id=session.id, - config=config, - ) - - assert len(filtered_session.events) == 3 - assert filtered_session.events[0].timestamp == 3.0 # Last 3 events - - # Test after_timestamp filter - config = GetSessionConfig(after_timestamp=3.0) - filtered_session = await redis_service.get_session( - app_name=app_name, - user_id=user_id, - session_id=session.id, - config=config, - ) - - assert len(filtered_session.events) == 3 # Events 3, 4, 5 - assert filtered_session.events[0].timestamp == 3.0 - - @pytest.mark.asyncio - async def test_delete_session(self, redis_service): - """Test session deletion.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - self._setup_redis_mocks(redis_service) # Empty sessions - await redis_service.delete_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - ) - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.execute.assert_called() - - redis_service.cache.pipeline.reset_mock() - self._setup_redis_mocks(redis_service) - - await redis_service.delete_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - ) - - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.execute.assert_called() - - @pytest.mark.asyncio - async def test_cluster_health_check(self, redis_cluster_service): - """Test health check for Redis cluster.""" - redis_cluster_service.cache.ping = AsyncMock(return_value=True) - - result = await redis_cluster_service.health_check() - assert result is True - redis_cluster_service.cache.ping.assert_called_once() - - @pytest.mark.asyncio - async def test_cluster_health_check_failure(self, redis_cluster_service): - """Test health check failure for Redis 
cluster.""" - from redis import RedisError - - redis_cluster_service.cache.ping = AsyncMock( - side_effect=RedisError("Connection failed") - ) - - result = await redis_cluster_service.health_check() - assert result is False - - @pytest.mark.asyncio - async def test_cluster_create_and_get_session(self, redis_cluster_service): - """Test session creation and retrieval in cluster mode.""" - app_name = "cluster_test_app" - user_id = "cluster_test_user" - state = {"cluster_key": "cluster_value"} - - self._setup_redis_mocks(redis_cluster_service) - - session = await redis_cluster_service.create_session( - app_name=app_name, user_id=user_id, state=state - ) - - assert session.app_name == app_name - assert session.user_id == user_id - assert session.id is not None - assert session.state == state - - # Mock individual session retrieval - redis_cluster_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - got_session = await redis_cluster_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert got_session.app_name == session.app_name - assert got_session.user_id == session.user_id - assert got_session.id == session.id - assert got_session.state == session.state - - @pytest.mark.asyncio - async def test_cluster_uri_initialization(self, redis_cluster_uri_service): - """Test Redis cluster initialization with URI.""" - assert redis_cluster_uri_service.cache is not None - - @pytest.mark.asyncio - async def test_cluster_error_handling(self, redis_cluster_service): - """Test error handling in cluster operations.""" - from redis import RedisError - - app_name = "test_app" - user_id = "test_user" - - # Mock Redis error during session loading - redis_cluster_service.cache.smembers = AsyncMock( - side_effect=RedisError("Cluster error") - ) - - sessions_response = await redis_cluster_service.list_sessions( - app_name=app_name, user_id=user_id - ) - - assert len(sessions_response.sessions) == 0 - - 
@pytest.mark.asyncio - async def test_cluster_connection_validation(self): - """Test cluster connection validation during initialization.""" - cluster_uri = "redis://redis-node1:6379" - - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - - service = RedisSessionService(cluster_uri=cluster_uri) - assert service.cache is not None - mock_redis_cluster.assert_called_once() - - @pytest.mark.asyncio - async def test_cluster_session_cleanup_on_error(self, redis_cluster_service): - """Test session cleanup when corrupted data is found in cluster.""" - app_name = "test_app" - user_id = "test_user" - - # Setup mock with corrupted session data - valid_session_data = { - "app_name": "test_app", - "user_id": "test_user", - "id": "session1", - "state": {}, - "events": [], - "last_update_time": 1234567890, - } - redis_cluster_service.cache.smembers = AsyncMock( - return_value={b"session1", b"session2"} - ) - - # Mock the pipeline for cluster approach - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock( - side_effect=[ - [orjson.dumps(valid_session_data)], # session1 result - [None], # session2 result (missing) - ] - ) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_cluster_service.cache.pipeline = MagicMock( - return_value=mock_context_manager - ) - redis_cluster_service.cache.srem = AsyncMock() - redis_cluster_service.cache.hgetall = AsyncMock(return_value={}) - - sessions_response = await redis_cluster_service.list_sessions( - app_name=app_name, user_id=user_id - ) - - redis_cluster_service.cache.srem.assert_called() - assert len(sessions_response.sessions) == 1 - - @pytest.mark.asyncio - async def test_decode_responses_handling(self, redis_service): - """Test proper handling of 
decode_responses setting.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - # Test with bytes response (decode_responses=False) - session_data = ( - '{"app_name": "test_app", "user_id": "test_user", "id": "test_session", ' - '"state": {}, "events": [], "last_update_time": 1234567890}' - ) - redis_service.cache.get = AsyncMock(return_value=session_data.encode()) - redis_service.cache.hgetall = AsyncMock(return_value={}) - - session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - - assert session is not None - assert session.app_name == app_name - assert session.user_id == user_id + ], + ) + test_grounding_metadata = types.GroundingMetadata( + search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") + ) + event = Event( + invocation_id="invocation", + author="user", + content=test_content, + grounding_metadata=test_grounding_metadata, + ) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + await redis_service.append_event(session=session, event=event) + + # Verify the event was appended to in-memory session + assert len(session.events) == 1 + assert session.events[0].content == test_content + assert session.events[0].grounding_metadata == test_grounding_metadata + + # Test serialization/deserialization roundtrip to ensure binary data is preserved + # Simulate what happens when session is stored and retrieved from Redis + serialized_session = session.model_dump_json() + + redis_service.cache.get = AsyncMock( + return_value=serialized_session.encode() + ) + + retrieved_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) == 1 + + # Verify the binary content was preserved through serialization + retrieved_event = retrieved_session.events[0] + assert ( + 
retrieved_event.content.parts[0].inline_data.data == b"test_image_data" + ) + assert retrieved_event.content.parts[0].inline_data.mime_type == "image/png" + assert ( + retrieved_event.grounding_metadata.search_entry_point.sdk_blob + == b"test_sdk_blob" + ) + + @pytest.mark.asyncio + async def test_get_session_with_config(self, redis_service): + """Test getting session with configuration filters.""" + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id + ) + + # Add multiple events with different timestamps + num_test_events = 5 + for i in range(1, num_test_events + 1): + event = Event(author="user", timestamp=float(i)) + session.events.append(event) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + # Test num_recent_events filter + config = GetSessionConfig(num_recent_events=3) + filtered_session = await redis_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=session.id, + config=config, + ) + + assert len(filtered_session.events) == 3 + assert filtered_session.events[0].timestamp == 3.0 # Last 3 events + + # Test after_timestamp filter + config = GetSessionConfig(after_timestamp=3.0) + filtered_session = await redis_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=session.id, + config=config, + ) + + assert len(filtered_session.events) == 3 # Events 3, 4, 5 + assert filtered_session.events[0].timestamp == 3.0 + + @pytest.mark.asyncio + async def test_delete_session(self, redis_service): + """Test session deletion.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + self._setup_redis_mocks(redis_service) # Empty sessions + await redis_service.delete_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await 
pipeline_mock.__aenter__() + pipe_mock.execute.assert_called() + + redis_service.cache.pipeline.reset_mock() + self._setup_redis_mocks(redis_service) + + await redis_service.delete_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.execute.assert_called() + + @pytest.mark.asyncio + async def test_cluster_health_check(self, redis_cluster_service): + """Test health check for Redis cluster.""" + redis_cluster_service.cache.ping = AsyncMock(return_value=True) + + result = await redis_cluster_service.health_check() + assert result is True + redis_cluster_service.cache.ping.assert_called_once() + + @pytest.mark.asyncio + async def test_cluster_health_check_failure(self, redis_cluster_service): + """Test health check failure for Redis cluster.""" + from redis import RedisError + + redis_cluster_service.cache.ping = AsyncMock( + side_effect=RedisError("Connection failed") + ) + + result = await redis_cluster_service.health_check() + assert result is False + + @pytest.mark.asyncio + async def test_cluster_create_and_get_session(self, redis_cluster_service): + """Test session creation and retrieval in cluster mode.""" + app_name = "cluster_test_app" + user_id = "cluster_test_user" + state = {"cluster_key": "cluster_value"} + + self._setup_redis_mocks(redis_cluster_service) + + session = await redis_cluster_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id is not None + assert session.state == state + + # Mock individual session retrieval + redis_cluster_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + got_session = await redis_cluster_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert got_session.app_name == session.app_name 
+ assert got_session.user_id == session.user_id + assert got_session.id == session.id + assert got_session.state == session.state + + @pytest.mark.asyncio + async def test_cluster_uri_initialization(self, redis_cluster_uri_service): + """Test Redis cluster initialization with URI.""" + assert redis_cluster_uri_service.cache is not None + + @pytest.mark.asyncio + async def test_cluster_error_handling(self, redis_cluster_service): + """Test error handling in cluster operations.""" + from redis import RedisError + + app_name = "test_app" + user_id = "test_user" + + # Mock Redis error during session loading + redis_cluster_service.cache.smembers = AsyncMock( + side_effect=RedisError("Cluster error") + ) + + sessions_response = await redis_cluster_service.list_sessions( + app_name=app_name, user_id=user_id + ) + + assert len(sessions_response.sessions) == 0 + + @pytest.mark.asyncio + async def test_cluster_connection_validation(self): + """Test cluster connection validation during initialization.""" + cluster_uri = "redis://redis-node1:6379" + + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + + service = RedisSessionService(cluster_uri=cluster_uri) + assert service.cache is not None + mock_redis_cluster.assert_called_once() + + @pytest.mark.asyncio + async def test_cluster_session_cleanup_on_error(self, redis_cluster_service): + """Test session cleanup when corrupted data is found in cluster.""" + app_name = "test_app" + user_id = "test_user" + + # Setup mock with corrupted session data + valid_session_data = { + "app_name": "test_app", + "user_id": "test_user", + "id": "session1", + "state": {}, + "events": [], + "last_update_time": 1234567890, + } + redis_cluster_service.cache.smembers = AsyncMock( + return_value={b"session1", b"session2"} + ) + + # Mock the pipeline for cluster approach + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = 
MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock( + side_effect=[ + [orjson.dumps(valid_session_data)], # session1 result + [None], # session2 result (missing) + ] + ) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_cluster_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + redis_cluster_service.cache.srem = AsyncMock() + redis_cluster_service.cache.hgetall = AsyncMock(return_value={}) + + sessions_response = await redis_cluster_service.list_sessions( + app_name=app_name, user_id=user_id + ) + + redis_cluster_service.cache.srem.assert_called() + assert len(sessions_response.sessions) == 1 + + @pytest.mark.asyncio + async def test_decode_responses_handling(self, redis_service): + """Test proper handling of decode_responses setting.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Test with bytes response (decode_responses=False) + session_data = ( + '{"app_name": "test_app", "user_id": "test_user", "id": "test_session",' + ' "state": {}, "events": [], "last_update_time": 1234567890}' + ) + redis_service.cache.get = AsyncMock(return_value=session_data.encode()) + redis_service.cache.hgetall = AsyncMock(return_value={}) + + session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is not None + assert session.app_name == app_name + assert session.user_id == user_id From e0636fdc71e2ade01026b90e2bcd7565a61d0061 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Fri, 9 Jan 2026 19:49:48 +0200 Subject: [PATCH 08/16] feat(redis): add Pydantic config classes for Redis search tools Add configuration classes to group query-specific parameters: - RedisVectorQueryConfig: 15 parameters for vector similarity search - RedisRangeQueryConfig: 11 parameters for range-based search - RedisHybridQueryConfig: 22 parameters for hybrid 
text+vector search - RedisTextQueryConfig: 8 parameters for full-text search Each config class includes: - Pydantic validation with Field() constraints - ConfigDict(extra='forbid') to catch typos - to_query_kwargs() method for query construction - Comprehensive docstrings for each parameter --- .../adk_community/tools/redis/config.py | 359 ++++++++++++++++++ 1 file changed, 359 insertions(+) create mode 100644 src/google/adk_community/tools/redis/config.py diff --git a/src/google/adk_community/tools/redis/config.py b/src/google/adk_community/tools/redis/config.py new file mode 100644 index 0000000..350aa05 --- /dev/null +++ b/src/google/adk_community/tools/redis/config.py @@ -0,0 +1,359 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration classes for Redis search tools. + +This module provides Pydantic configuration classes for Redis search tools, +following the pattern established by BigQueryToolConfig in the upstream +google/adk-python repository and OpenMemoryServiceConfig in this repository. 
+""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field + +# Type alias for sort specification +SortSpec = Optional[Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]]] + + +class RedisVectorQueryConfig(BaseModel): + """Configuration for Redis vector similarity search queries. + + This config groups all query-specific parameters for VectorQuery, + separating them from tool-level concerns like index and vectorizer. + + Attributes: + vector_field_name: Name of the vector field in the index. + num_results: Number of results to return (default: 10). + dtype: Data type of the vector (default: "float32"). + return_score: Whether to return vector distance scores. + dialect: RediSearch query dialect version. + sort_by: Field(s) to order results by. + in_order: Require query terms in same order as document. + normalize_vector_distance: Convert distance to 0-1 similarity score. + hybrid_policy: Filter application policy - "BATCHES" or "ADHOC_BF". + batch_size: Batch size when hybrid_policy is "BATCHES". + ef_runtime: HNSW exploration factor at query time. + epsilon: Range search approximation factor for HNSW/SVS-VAMANA. + search_window_size: SVS-VAMANA search window size. + use_search_history: SVS-VAMANA history mode - "OFF", "ON", or "AUTO". + search_buffer_capacity: SVS-VAMANA 2-level compression tuning. 
+ """ + + model_config = ConfigDict(extra="forbid") + + # Core query parameters + vector_field_name: str = Field(default="embedding") + num_results: int = Field(default=10, ge=1) + dtype: str = Field(default="float32") + return_score: bool = Field(default=True) + dialect: int = Field(default=2, ge=1) + sort_by: SortSpec = Field(default=None) + in_order: bool = Field(default=False) + normalize_vector_distance: bool = Field(default=False) + + # Version-dependent parameters (RedisVL 0.13.2+) + # These are excluded from query kwargs when None for backward compatibility + hybrid_policy: Optional[str] = Field(default=None) + batch_size: Optional[int] = Field(default=None, ge=1) + ef_runtime: Optional[int] = Field(default=None, ge=1) + epsilon: Optional[float] = Field(default=None, ge=0.0) + search_window_size: Optional[int] = Field(default=None, ge=1) + use_search_history: Optional[str] = Field(default=None) + search_buffer_capacity: Optional[int] = Field(default=None, ge=1) + + def to_query_kwargs( + self, vector: List[float], filter_expression: Optional[Any] = None + ) -> dict[str, Any]: + """Convert config to VectorQuery kwargs, excluding None version-dependent params. + + Args: + vector: The query vector embedding. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for VectorQuery constructor. 
+ """ + # Core parameters always included + kwargs: dict[str, Any] = { + "vector": vector, + "vector_field_name": self.vector_field_name, + "num_results": self.num_results, + "dtype": self.dtype, + "return_score": self.return_score, + "dialect": self.dialect, + "sort_by": self.sort_by, + "in_order": self.in_order, + "normalize_vector_distance": self.normalize_vector_distance, + "filter_expression": filter_expression, + } + + # Version-dependent parameters: only include if not None + version_dependent = { + "hybrid_policy": self.hybrid_policy, + "batch_size": self.batch_size, + "ef_runtime": self.ef_runtime, + "epsilon": self.epsilon, + "search_window_size": self.search_window_size, + "use_search_history": self.use_search_history, + "search_buffer_capacity": self.search_buffer_capacity, + } + for key, value in version_dependent.items(): + if value is not None: + kwargs[key] = value + + return kwargs + + +class RedisRangeQueryConfig(BaseModel): + """Configuration for Redis vector range search queries. + + Range search finds all documents within a specified distance threshold + from the query vector, rather than returning a fixed number of results. + + Attributes: + vector_field_name: Name of the vector field in the index. + distance_threshold: Maximum distance for results (default: 0.2). + num_results: Maximum number of results to return. + dtype: Data type of the vector (default: "float32"). + return_score: Whether to return vector distance scores. + dialect: RediSearch query dialect version. + sort_by: Field(s) to order results by. + in_order: Require query terms in same order as document. + normalize_vector_distance: Convert distance to 0-1 similarity score. + epsilon: Range search approximation factor for HNSW/SVS-VAMANA. 
+ """ + + model_config = ConfigDict(extra="forbid") + + vector_field_name: str = Field(default="embedding") + distance_threshold: float = Field(default=0.2, ge=0.0) + num_results: int = Field(default=10, ge=1) + dtype: str = Field(default="float32") + return_score: bool = Field(default=True) + dialect: int = Field(default=2, ge=1) + sort_by: SortSpec = Field(default=None) + in_order: bool = Field(default=False) + normalize_vector_distance: bool = Field(default=False) + + # Version-dependent parameter + epsilon: Optional[float] = Field(default=None, ge=0.0) + + def to_query_kwargs( + self, vector: List[float], filter_expression: Optional[Any] = None + ) -> dict[str, Any]: + """Convert config to VectorRangeQuery kwargs. + + Args: + vector: The query vector embedding. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for VectorRangeQuery constructor. + """ + kwargs: dict[str, Any] = { + "vector": vector, + "vector_field_name": self.vector_field_name, + "distance_threshold": self.distance_threshold, + "num_results": self.num_results, + "dtype": self.dtype, + "return_score": self.return_score, + "dialect": self.dialect, + "sort_by": self.sort_by, + "in_order": self.in_order, + "normalize_vector_distance": self.normalize_vector_distance, + "filter_expression": filter_expression, + } + + # Version-dependent: only include if not None + if self.epsilon is not None: + kwargs["epsilon"] = self.epsilon + + return kwargs + + +class RedisTextQueryConfig(BaseModel): + """Configuration for Redis full-text search queries. + + This config groups all query-specific parameters for TextQuery, + using BM25 scoring for keyword-based search. + + Attributes: + text_field_name: Name of the text field to search. + text_scorer: Text scoring algorithm (default: "BM25STD"). + num_results: Number of results to return (default: 10). + return_score: Whether to return the text score. + dialect: RediSearch query dialect version. 
+ sort_by: Field(s) to order results by. + in_order: Require query terms in same order as document. + stopwords: Stopwords to remove from query (default: "english"). + """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + text_field_name: str = Field(default="content") + text_scorer: str = Field(default="BM25STD") + num_results: int = Field(default=10, ge=1) + return_score: bool = Field(default=True) + dialect: int = Field(default=2, ge=1) + sort_by: SortSpec = Field(default=None) + in_order: bool = Field(default=False) + stopwords: Optional[Union[str, Set[str]]] = Field(default="english") + + def to_query_kwargs( + self, + text: str, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + ) -> dict[str, Any]: + """Convert config to TextQuery kwargs. + + Args: + text: The query text for BM25 matching. + return_fields: Optional list of fields to return. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for TextQuery constructor. + """ + return { + "text": text, + "text_field_name": self.text_field_name, + "text_scorer": self.text_scorer, + "num_results": self.num_results, + "return_score": self.return_score, + "dialect": self.dialect, + "sort_by": self.sort_by, + "in_order": self.in_order, + "stopwords": self.stopwords, + "return_fields": return_fields, + "filter_expression": filter_expression, + } + + +class RedisHybridQueryConfig(BaseModel): + """Configuration for Redis hybrid search queries. + + Hybrid search combines semantic vector similarity with keyword-based + BM25 text matching using Redis's native FT.HYBRID command. + + Requirements: + - Redis >= 8.4.0 (for native FT.HYBRID command support) + - redis-py >= 7.1.0 + + Attributes: + text_field_name: Name of the text field for BM25 search. + vector_field_name: Name of the vector field for similarity search. + vector_param_name: Name of the parameter substitution for vector blob. 
+ text_scorer: Text scoring algorithm (default: "BM25STD"). + yield_text_score_as: Field name to yield the text score as. + vector_search_method: Vector search method - "KNN" or "RANGE". + knn_ef_runtime: Exploration factor for HNSW when using KNN. + range_radius: Search radius when using RANGE vector search. + range_epsilon: Epsilon for RANGE search accuracy. + yield_vsim_score_as: Field name to yield vector similarity score as. + combination_method: Score combination method - "RRF" or "LINEAR". + linear_alpha: Weight of text score when using LINEAR. + rrf_window: Window size for RRF combination. + rrf_constant: Constant for RRF combination. + yield_combined_score_as: Field name to yield combined score as. + num_results: Number of results to return. + dtype: Data type of the vector. + stopwords: Stopwords to remove from query. + text_weights: Optional field weights for text scoring. + """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + # Text search parameters + text_field_name: str = Field(default="content") + text_scorer: str = Field(default="BM25STD") + yield_text_score_as: Optional[str] = Field(default=None) + stopwords: Optional[Union[str, Set[str]]] = Field(default="english") + text_weights: Optional[Dict[str, float]] = Field(default=None) + + # Vector search parameters + vector_field_name: str = Field(default="embedding") + vector_param_name: str = Field(default="vector") + vector_search_method: Optional[str] = Field(default=None) + knn_ef_runtime: int = Field(default=10, ge=1) + range_radius: Optional[float] = Field(default=None) + range_epsilon: float = Field(default=0.01, ge=0.0) + yield_vsim_score_as: Optional[str] = Field(default=None) + dtype: str = Field(default="float32") + + # Score combination parameters + combination_method: Optional[str] = Field(default=None) + linear_alpha: float = Field(default=0.3, ge=0.0, le=1.0) + rrf_window: int = Field(default=20, ge=1) + rrf_constant: int = Field(default=60, ge=1) + 
yield_combined_score_as: Optional[str] = Field(default=None) + + # Result parameters + num_results: int = Field(default=10, ge=1) + + def to_query_kwargs( + self, + text: str, + vector: List[float], + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + ) -> dict[str, Any]: + """Convert config to HybridQuery kwargs. + + Args: + text: The query text for BM25 matching. + vector: The query vector embedding. + return_fields: Optional list of fields to return. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for HybridQuery constructor. + """ + return { + "text": text, + "text_field_name": self.text_field_name, + "vector": vector, + "vector_field_name": self.vector_field_name, + "vector_param_name": self.vector_param_name, + "text_scorer": self.text_scorer, + "yield_text_score_as": self.yield_text_score_as, + "vector_search_method": self.vector_search_method, + "knn_ef_runtime": self.knn_ef_runtime, + "range_radius": self.range_radius, + "range_epsilon": self.range_epsilon, + "yield_vsim_score_as": self.yield_vsim_score_as, + "filter_expression": filter_expression, + "combination_method": self.combination_method, + "rrf_window": self.rrf_window, + "rrf_constant": self.rrf_constant, + "linear_alpha": self.linear_alpha, + "yield_combined_score_as": self.yield_combined_score_as, + "dtype": self.dtype, + "num_results": self.num_results, + "return_fields": return_fields, + "stopwords": self.stopwords, + "text_weights": self.text_weights, + } + From 8e42778179252a50d6a5ac9023b0cfc63a59393e Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Fri, 9 Jan 2026 19:49:57 +0200 Subject: [PATCH 09/16] refactor(redis): use config objects in Redis search tool constructors Refactor all four Redis search tools to accept config objects: - RedisVectorSearchTool: 22 params -> 7 params - RedisHybridSearchTool: 26 params -> 7 params - RedisRangeSearchTool: 14 params -> 7 params - RedisTextSearchTool: 12 
params -> 6 params Benefits: - Reduces constructor parameter explosion - Groups query-specific parameters logically - Follows BigQueryToolConfig and OpenMemoryServiceConfig patterns - Enables Pydantic validation of parameters --- .../tools/redis/hybrid_search_tool.py | 112 ++++------------- .../tools/redis/range_search_tool.py | 82 ++++--------- .../tools/redis/text_search_tool.py | 70 ++++------- .../tools/redis/vector_search_tool.py | 116 +++++------------- 4 files changed, 104 insertions(+), 276 deletions(-) diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py index df395f7..621fe9c 100644 --- a/src/google/adk_community/tools/redis/hybrid_search_tool.py +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -17,10 +17,8 @@ from __future__ import annotations from typing import Any -from typing import Dict from typing import List from typing import Optional -from typing import Set from typing import Union from google.genai import types @@ -30,6 +28,7 @@ from redisvl.utils.vectorize import BaseVectorizer from .base_search_tool import VectorizedSearchTool +from .config import RedisHybridQueryConfig class RedisHybridSearchTool(VectorizedSearchTool): @@ -48,17 +47,24 @@ class RedisHybridSearchTool(VectorizedSearchTool): ```python from redisvl.index import SearchIndex from redisvl.utils.vectorize import HFTextVectorizer - from google.adk_community.tools.redis import RedisHybridSearchTool + from google.adk_community.tools.redis import ( + RedisHybridSearchTool, + RedisHybridQueryConfig, + ) index = SearchIndex.from_yaml("schema.yaml") vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") - tool = RedisHybridSearchTool( - index=index, - vectorizer=vectorizer, + # Using config object (recommended) + config = RedisHybridQueryConfig( text_field_name="content", linear_alpha=0.7, # 70% text, 30% vector num_results=10, + ) + tool = RedisHybridSearchTool( + index=index, + 
vectorizer=vectorizer, + config=config, return_fields=["title", "content"], ) @@ -71,27 +77,9 @@ def __init__( *, index: Union[SearchIndex, AsyncSearchIndex], vectorizer: BaseVectorizer, - text_field_name: str = "content", - vector_field_name: str = "embedding", - vector_param_name: str = "vector", - text_scorer: str = "BM25STD", - yield_text_score_as: Optional[str] = None, - vector_search_method: Optional[str] = None, - knn_ef_runtime: int = 10, - range_radius: Optional[float] = None, - range_epsilon: float = 0.01, - yield_vsim_score_as: Optional[str] = None, - combination_method: Optional[str] = None, - linear_alpha: float = 0.3, - rrf_window: int = 20, - rrf_constant: int = 60, - yield_combined_score_as: Optional[str] = None, - num_results: int = 10, + config: Optional[RedisHybridQueryConfig] = None, return_fields: Optional[List[str]] = None, filter_expression: Optional[Any] = None, - dtype: str = "float32", - stopwords: Optional[Union[str, Set[str]]] = "english", - text_weights: Optional[Dict[str, float]] = None, name: str = "redis_hybrid_search", description: str = "Search using both semantic similarity and keyword matching.", ): @@ -100,27 +88,11 @@ def __init__( Args: index: The RedisVL SearchIndex or AsyncSearchIndex to query. vectorizer: The vectorizer for embedding queries. - text_field_name: The name of the text field for BM25 search. - vector_field_name: The name of the vector field for similarity search. - vector_param_name: Name of the parameter substitution for vector blob. - text_scorer: The text scoring algorithm (default: "BM25STD"). - yield_text_score_as: Field name to yield the text score as. - vector_search_method: Vector search method - "KNN" or "RANGE". - knn_ef_runtime: Exploration factor for HNSW when using KNN (default: 10). - range_radius: Search radius when using RANGE vector search. - range_epsilon: Epsilon for RANGE search accuracy (default: 0.01). - yield_vsim_score_as: Field name to yield the vector similarity score as. 
- combination_method: Score combination method - "RRF" or "LINEAR". - linear_alpha: Weight of text score when using LINEAR (default: 0.3). - rrf_window: Window size for RRF combination (default: 20). - rrf_constant: Constant for RRF combination (default: 60). - yield_combined_score_as: Field name to yield the combined score as. - num_results: Default number of results to return (default: 10). + config: Configuration for query parameters. If None, uses defaults. + See RedisHybridQueryConfig for available options including + text_field_name, vector_field_name, linear_alpha, and more. return_fields: Optional list of fields to return in results. filter_expression: Optional filter expression to narrow results. - dtype: The dtype of the vector (default: "float32"). - stopwords: Stopwords to remove from query (default: "english"). - text_weights: Optional field weights for text scoring. name: The name of the tool (exposed to LLM). description: The description of the tool (exposed to LLM). """ @@ -131,26 +103,8 @@ def __init__( vectorizer=vectorizer, return_fields=return_fields, ) - self._text_field_name = text_field_name - self._vector_field_name = vector_field_name - self._vector_param_name = vector_param_name - self._text_scorer = text_scorer - self._yield_text_score_as = yield_text_score_as - self._vector_search_method = vector_search_method - self._knn_ef_runtime = knn_ef_runtime - self._range_radius = range_radius - self._range_epsilon = range_epsilon - self._yield_vsim_score_as = yield_vsim_score_as - self._combination_method = combination_method - self._linear_alpha = linear_alpha - self._rrf_window = rrf_window - self._rrf_constant = rrf_constant - self._yield_combined_score_as = yield_combined_score_as - self._num_results = num_results + self._config = config or RedisHybridQueryConfig() self._filter_expression = filter_expression - self._dtype = dtype - self._stopwords = stopwords - self._text_weights = text_weights def _get_declaration(self) -> 
types.FunctionDeclaration: """Get the function declaration for the LLM.""" @@ -168,7 +122,7 @@ def _get_declaration(self) -> types.FunctionDeclaration: type=types.Type.INTEGER, description=( "Number of results to return (default:" - f" {self._num_results})." + f" {self._config.num_results})." ), ), }, @@ -189,30 +143,16 @@ def _build_query( Returns: A HybridQuery configured for hybrid search. """ - num_results = kwargs.get("num_results", self._num_results) + # Allow runtime override of num_results + num_results = kwargs.get("num_results", self._config.num_results) - return HybridQuery( + # Get query kwargs from config + query_kwargs = self._config.to_query_kwargs( text=query_text, - text_field_name=self._text_field_name, vector=embedding, - vector_field_name=self._vector_field_name, - vector_param_name=self._vector_param_name, - text_scorer=self._text_scorer, - yield_text_score_as=self._yield_text_score_as, - vector_search_method=self._vector_search_method, - knn_ef_runtime=self._knn_ef_runtime, - range_radius=self._range_radius, - range_epsilon=self._range_epsilon, - yield_vsim_score_as=self._yield_vsim_score_as, - filter_expression=self._filter_expression, - combination_method=self._combination_method, - rrf_window=self._rrf_window, - rrf_constant=self._rrf_constant, - linear_alpha=self._linear_alpha, - yield_combined_score_as=self._yield_combined_score_as, - dtype=self._dtype, - num_results=num_results, return_fields=self._return_fields, - stopwords=self._stopwords, - text_weights=self._text_weights, + filter_expression=self._filter_expression, ) + query_kwargs["num_results"] = num_results + + return HybridQuery(**query_kwargs) diff --git a/src/google/adk_community/tools/redis/range_search_tool.py b/src/google/adk_community/tools/redis/range_search_tool.py index 179c8a6..dac714e 100644 --- a/src/google/adk_community/tools/redis/range_search_tool.py +++ b/src/google/adk_community/tools/redis/range_search_tool.py @@ -19,7 +19,6 @@ from typing import Any from 
typing import List from typing import Optional -from typing import Tuple from typing import Union from google.genai import types @@ -29,11 +28,7 @@ from redisvl.utils.vectorize import BaseVectorizer from .base_search_tool import VectorizedSearchTool - -# Type alias for sort specification -SortSpec = Optional[ - Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] -] +from .config import RedisRangeQueryConfig class RedisRangeSearchTool(VectorizedSearchTool): @@ -48,15 +43,22 @@ class RedisRangeSearchTool(VectorizedSearchTool): ```python from redisvl.index import SearchIndex from redisvl.utils.vectorize import HFTextVectorizer - from google.adk_community.tools.redis import RedisRangeSearchTool + from google.adk_community.tools.redis import ( + RedisRangeSearchTool, + RedisRangeQueryConfig, + ) index = SearchIndex.from_yaml("schema.yaml") vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + # Using config object (recommended) + config = RedisRangeQueryConfig( + distance_threshold=0.3, # Only return docs within 0.3 distance + ) tool = RedisRangeSearchTool( index=index, vectorizer=vectorizer, - distance_threshold=0.3, # Only return docs within 0.3 distance + config=config, return_fields=["title", "content"], ) @@ -69,18 +71,9 @@ def __init__( *, index: Union[SearchIndex, AsyncSearchIndex], vectorizer: BaseVectorizer, - vector_field_name: str = "embedding", - distance_threshold: float = 0.2, - num_results: int = 10, + config: Optional[RedisRangeQueryConfig] = None, return_fields: Optional[List[str]] = None, filter_expression: Optional[Any] = None, - dtype: str = "float32", - return_score: bool = True, - dialect: int = 2, - sort_by: SortSpec = None, - in_order: bool = False, - epsilon: Optional[float] = None, - normalize_vector_distance: bool = False, name: str = "redis_range_search", description: str = "Find all documents within a similarity threshold.", ): @@ -89,18 +82,11 @@ def __init__( Args: index: The RedisVL SearchIndex or AsyncSearchIndex 
to query. vectorizer: The vectorizer for embedding queries. - vector_field_name: The name of the vector field in the index. - distance_threshold: Maximum distance for results (default: 0.2). - num_results: Maximum number of results to return (default: 10). + config: Configuration for query parameters. If None, uses defaults. + See RedisRangeQueryConfig for available options including + distance_threshold, vector_field_name, and epsilon. return_fields: Optional list of fields to return in results. filter_expression: Optional filter expression to narrow results. - dtype: The dtype of the vector (default: "float32"). - return_score: Whether to return the vector distance (default: True). - dialect: The RediSearch query dialect (default: 2). - sort_by: Field(s) to order results by. - in_order: Require query terms in same order (default: False). - epsilon: Range search approximation factor for HNSW/SVS-VAMANA. - normalize_vector_distance: Convert distance to 0-1 similarity. name: The name of the tool (exposed to LLM). description: The description of the tool (exposed to LLM). """ @@ -111,17 +97,8 @@ def __init__( vectorizer=vectorizer, return_fields=return_fields, ) - self._vector_field_name = vector_field_name - self._distance_threshold = distance_threshold - self._num_results = num_results + self._config = config or RedisRangeQueryConfig() self._filter_expression = filter_expression - self._dtype = dtype - self._return_score = return_score - self._dialect = dialect - self._sort_by = sort_by - self._in_order = in_order - self._epsilon = epsilon - self._normalize_vector_distance = normalize_vector_distance def _get_declaration(self) -> types.FunctionDeclaration: """Get the function declaration for the LLM.""" @@ -139,7 +116,7 @@ def _get_declaration(self) -> types.FunctionDeclaration: type=types.Type.NUMBER, description=( "Max distance threshold (default:" - f" {self._distance_threshold})." + f" {self._config.distance_threshold})." 
), ), }, @@ -160,26 +137,17 @@ def _build_query( Returns: A VectorRangeQuery configured for range search. """ + # Allow runtime override of distance_threshold distance_threshold = kwargs.get( - "distance_threshold", self._distance_threshold + "distance_threshold", self._config.distance_threshold ) - query_kwargs: dict[str, Any] = { - "vector": embedding, - "vector_field_name": self._vector_field_name, - "distance_threshold": distance_threshold, - "num_results": self._num_results, - "return_fields": self._return_fields, - "filter_expression": self._filter_expression, - "dtype": self._dtype, - "return_score": self._return_score, - "dialect": self._dialect, - "sort_by": self._sort_by, - "in_order": self._in_order, - "normalize_vector_distance": self._normalize_vector_distance, - } - - if self._epsilon is not None: - query_kwargs["epsilon"] = self._epsilon + # Get query kwargs from config + query_kwargs = self._config.to_query_kwargs( + vector=embedding, + filter_expression=self._filter_expression, + ) + query_kwargs["return_fields"] = self._return_fields + query_kwargs["distance_threshold"] = distance_threshold return VectorRangeQuery(**query_kwargs) diff --git a/src/google/adk_community/tools/redis/text_search_tool.py b/src/google/adk_community/tools/redis/text_search_tool.py index ba7c24f..8d719b6 100644 --- a/src/google/adk_community/tools/redis/text_search_tool.py +++ b/src/google/adk_community/tools/redis/text_search_tool.py @@ -20,8 +20,6 @@ from typing import Dict from typing import List from typing import Optional -from typing import Set -from typing import Tuple from typing import Union from google.adk.tools.tool_context import ToolContext @@ -31,11 +29,7 @@ from redisvl.query import TextQuery from .base_search_tool import BaseRedisSearchTool - -# Type alias for sort specification -SortSpec = Optional[ - Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] -] +from .config import RedisTextQueryConfig class RedisTextSearchTool(BaseRedisSearchTool): 
@@ -48,14 +42,22 @@ class RedisTextSearchTool(BaseRedisSearchTool): Example: ```python from redisvl.index import SearchIndex - from google.adk_community.tools.redis import RedisTextSearchTool + from google.adk_community.tools.redis import ( + RedisTextSearchTool, + RedisTextQueryConfig, + ) index = SearchIndex.from_yaml("schema.yaml") - tool = RedisTextSearchTool( - index=index, + # Using config object (recommended) + config = RedisTextQueryConfig( text_field_name="content", num_results=10, + text_scorer="BM25STD", + ) + tool = RedisTextSearchTool( + index=index, + config=config, return_fields=["title", "content"], ) @@ -67,16 +69,9 @@ def __init__( self, *, index: Union[SearchIndex, AsyncSearchIndex], - text_field_name: str = "content", - text_scorer: str = "BM25STD", - num_results: int = 10, + config: Optional[RedisTextQueryConfig] = None, return_fields: Optional[List[str]] = None, filter_expression: Optional[Any] = None, - return_score: bool = True, - dialect: int = 2, - sort_by: SortSpec = None, - in_order: bool = False, - stopwords: Optional[Union[str, Set[str]]] = "english", name: str = "redis_text_search", description: str = "Search for documents using keyword matching.", ): @@ -84,16 +79,10 @@ def __init__( Args: index: The RedisVL SearchIndex or AsyncSearchIndex to query. - text_field_name: The name of the text field to search. - text_scorer: The text scoring algorithm (default: "BM25STD"). - num_results: Default number of results to return (default: 10). + config: Configuration for text query parameters. If not provided, + defaults will be used. return_fields: Optional list of fields to return in results. filter_expression: Optional filter expression to narrow results. - return_score: Whether to return the text score (default: True). - dialect: The RediSearch query dialect (default: 2). - sort_by: Field(s) to order results by. - in_order: Require query terms in same order (default: False). - stopwords: Stopwords to remove from query (default: "english"). 
name: The name of the tool (exposed to LLM). description: The description of the tool (exposed to LLM). """ @@ -103,15 +92,8 @@ def __init__( index=index, return_fields=return_fields, ) - self._text_field_name = text_field_name - self._text_scorer = text_scorer - self._num_results = num_results + self._config = config or RedisTextQueryConfig() self._filter_expression = filter_expression - self._return_score = return_score - self._dialect = dialect - self._sort_by = sort_by - self._in_order = in_order - self._stopwords = stopwords def _get_declaration(self) -> types.FunctionDeclaration: """Get the function declaration for the LLM.""" @@ -129,7 +111,7 @@ def _get_declaration(self) -> types.FunctionDeclaration: type=types.Type.INTEGER, description=( "Number of results to return (default:" - f" {self._num_results})." + f" {self._config.num_results})." ), ), }, @@ -153,19 +135,15 @@ async def run_async( async def build_query_fn( query_text: str, args: Dict[str, Any] ) -> TextQuery: - num_results = args.get("num_results", self._num_results) - return TextQuery( + # Get query kwargs from config + query_kwargs = self._config.to_query_kwargs( text=query_text, - text_field_name=self._text_field_name, - text_scorer=self._text_scorer, - filter_expression=self._filter_expression, return_fields=self._return_fields, - num_results=num_results, - return_score=self._return_score, - dialect=self._dialect, - sort_by=self._sort_by, - in_order=self._in_order, - stopwords=self._stopwords, + filter_expression=self._filter_expression, ) + # Allow LLM to override num_results + if "num_results" in args: + query_kwargs["num_results"] = args["num_results"] + return TextQuery(**query_kwargs) return await self._run_search(args, build_query_fn) diff --git a/src/google/adk_community/tools/redis/vector_search_tool.py b/src/google/adk_community/tools/redis/vector_search_tool.py index f9471d1..a000dcf 100644 --- a/src/google/adk_community/tools/redis/vector_search_tool.py +++ 
b/src/google/adk_community/tools/redis/vector_search_tool.py @@ -19,7 +19,6 @@ from typing import Any from typing import List from typing import Optional -from typing import Tuple from typing import Union from google.genai import types @@ -29,11 +28,7 @@ from redisvl.utils.vectorize import BaseVectorizer from .base_search_tool import VectorizedSearchTool - -# Type alias for sort specification -SortSpec = Optional[ - Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] -] +from .config import RedisVectorQueryConfig class RedisVectorSearchTool(VectorizedSearchTool): @@ -48,17 +43,25 @@ class RedisVectorSearchTool(VectorizedSearchTool): from redisvl.index import SearchIndex from redisvl.utils.vectorize import HFTextVectorizer from redisvl.query.filter import Tag - from google.adk_community.tools.redis import RedisVectorSearchTool + from google.adk_community.tools.redis import ( + RedisVectorSearchTool, + RedisVectorQueryConfig, + ) index = SearchIndex.from_yaml("schema.yaml") vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + # Using config object (recommended) + config = RedisVectorQueryConfig( + num_results=5, + ef_runtime=100, # Higher = better recall + ) tool = RedisVectorSearchTool( index=index, vectorizer=vectorizer, - num_results=5, + config=config, return_fields=["title", "content", "url"], - filter_expression=Tag("category") == "redis", # Optional filter + filter_expression=Tag("category") == "redis", ) # Use with an agent @@ -71,23 +74,9 @@ def __init__( *, index: Union[SearchIndex, AsyncSearchIndex], vectorizer: BaseVectorizer, - vector_field_name: str = "embedding", - num_results: int = 10, + config: Optional[RedisVectorQueryConfig] = None, return_fields: Optional[List[str]] = None, filter_expression: Optional[Any] = None, - dtype: str = "float32", - return_score: bool = True, - dialect: int = 2, - sort_by: SortSpec = None, - in_order: bool = False, - hybrid_policy: Optional[str] = None, - batch_size: Optional[int] = None, - 
ef_runtime: Optional[int] = None, - epsilon: Optional[float] = None, - search_window_size: Optional[int] = None, - use_search_history: Optional[str] = None, - search_buffer_capacity: Optional[int] = None, - normalize_vector_distance: bool = False, name: str = "redis_vector_search", description: str = "Search for semantically similar documents using vector similarity with Redis.", ): @@ -96,23 +85,12 @@ def __init__( Args: index: The RedisVL SearchIndex to query. vectorizer: The vectorizer for embedding queries. - vector_field_name: The name of the vector field in the index. - num_results: Default number of results to return (default: 10). + config: Configuration for query parameters. If None, uses defaults. + See RedisVectorQueryConfig for available options including + num_results, vector_field_name, dtype, and version-dependent + parameters like ef_runtime and hybrid_policy. return_fields: Optional list of fields to return in results. filter_expression: Optional RedisVL FilterExpression to narrow results. - dtype: The dtype of the vector (default: "float32"). - return_score: Whether to return the vector distance (default: True). - dialect: The RediSearch query dialect (default: 2). - sort_by: Field(s) to order results by. Can be str, tuple, or list. - in_order: Require query terms in same order as document (default: False). - hybrid_policy: Filter application policy - "BATCHES" or "ADHOC_BF". - batch_size: Batch size when hybrid_policy is "BATCHES". - ef_runtime: HNSW exploration factor at query time (higher = better recall). - epsilon: Range search approximation factor for HNSW/SVS-VAMANA indexes. - search_window_size: SVS-VAMANA search window size (higher = better recall). - use_search_history: SVS-VAMANA history mode - "OFF", "ON", or "AUTO". - search_buffer_capacity: SVS-VAMANA 2-level compression tuning parameter. - normalize_vector_distance: Convert distance to similarity score 0-1 (default: False). name: The name of the tool (exposed to LLM). 
description: The description of the tool (exposed to LLM). """ @@ -123,22 +101,8 @@ def __init__( vectorizer=vectorizer, return_fields=return_fields, ) - self._vector_field_name = vector_field_name - self._num_results = num_results + self._config = config or RedisVectorQueryConfig() self._filter_expression = filter_expression - self._dtype = dtype - self._return_score = return_score - self._dialect = dialect - self._sort_by = sort_by - self._in_order = in_order - self._hybrid_policy = hybrid_policy - self._batch_size = batch_size - self._ef_runtime = ef_runtime - self._epsilon = epsilon - self._search_window_size = search_window_size - self._use_search_history = use_search_history - self._search_buffer_capacity = search_buffer_capacity - self._normalize_vector_distance = normalize_vector_distance def _get_declaration(self) -> types.FunctionDeclaration: """Get the function declaration for the LLM.""" @@ -156,7 +120,7 @@ def _get_declaration(self) -> types.FunctionDeclaration: type=types.Type.INTEGER, description=( "Number of results to return (default:" - f" {self._num_results})." + f" {self._config.num_results})." ), ), }, @@ -177,37 +141,15 @@ def _build_query( Returns: A VectorQuery configured for KNN search. 
""" - num_results = kwargs.get("num_results", self._num_results) - - # Build query kwargs, only including optional params if set - query_kwargs: dict[str, Any] = { - "vector": embedding, - "vector_field_name": self._vector_field_name, - "num_results": num_results, - "return_fields": self._return_fields, - "filter_expression": self._filter_expression, - "dtype": self._dtype, - "return_score": self._return_score, - "dialect": self._dialect, - "sort_by": self._sort_by, - "in_order": self._in_order, - "normalize_vector_distance": self._normalize_vector_distance, - } - - # Add optional parameters only if set (for version compatibility) - if self._hybrid_policy is not None: - query_kwargs["hybrid_policy"] = self._hybrid_policy - if self._batch_size is not None: - query_kwargs["batch_size"] = self._batch_size - if self._ef_runtime is not None: - query_kwargs["ef_runtime"] = self._ef_runtime - if self._epsilon is not None: - query_kwargs["epsilon"] = self._epsilon - if self._search_window_size is not None: - query_kwargs["search_window_size"] = self._search_window_size - if self._use_search_history is not None: - query_kwargs["use_search_history"] = self._use_search_history - if self._search_buffer_capacity is not None: - query_kwargs["search_buffer_capacity"] = self._search_buffer_capacity + # Allow runtime override of num_results + num_results = kwargs.get("num_results", self._config.num_results) + + # Get query kwargs from config, with runtime overrides + query_kwargs = self._config.to_query_kwargs( + vector=embedding, + filter_expression=self._filter_expression, + ) + query_kwargs["return_fields"] = self._return_fields + query_kwargs["num_results"] = num_results return VectorQuery(**query_kwargs) From cb26dbbd57aef3e79f08f200fe8044f3f98ad2d8 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Fri, 9 Jan 2026 19:50:04 +0200 Subject: [PATCH 10/16] feat(redis): export config classes from tools module Add exports for all Redis config classes: - RedisVectorQueryConfig - 
RedisHybridQueryConfig - RedisRangeQueryConfig - RedisTextQueryConfig Supports lazy loading pattern for optional redisvl dependency. --- src/google/adk_community/tools/__init__.py | 24 +++++++++++++++++++ .../adk_community/tools/redis/__init__.py | 23 ++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/google/adk_community/tools/__init__.py b/src/google/adk_community/tools/__init__.py index 3c667ed..4000eb7 100644 --- a/src/google/adk_community/tools/__init__.py +++ b/src/google/adk_community/tools/__init__.py @@ -25,6 +25,9 @@ "RedisHybridSearchTool", "RedisRangeSearchTool", "RedisTextSearchTool", + "RedisVectorQueryConfig", + "RedisHybridQueryConfig", + "RedisRangeQueryConfig", ] # Redis tool names for lazy loading @@ -37,6 +40,14 @@ "RedisTextSearchTool", } +# Redis config names for lazy loading +_REDIS_CONFIGS = { + "RedisVectorQueryConfig", + "RedisHybridQueryConfig", + "RedisRangeQueryConfig", + "RedisTextQueryConfig", +} + def __getattr__(name: str): """Lazy load tools to provide helpful error messages.""" @@ -63,4 +74,17 @@ def __getattr__(name: str): f"{name} requires redisvl. 
" "Install with: pip install google-adk-community[redis-vl]" ) from e + if name in _REDIS_CONFIGS: + from .redis import RedisHybridQueryConfig + from .redis import RedisRangeQueryConfig + from .redis import RedisTextQueryConfig + from .redis import RedisVectorQueryConfig + + globals().update({ + "RedisVectorQueryConfig": RedisVectorQueryConfig, + "RedisHybridQueryConfig": RedisHybridQueryConfig, + "RedisRangeQueryConfig": RedisRangeQueryConfig, + "RedisTextQueryConfig": RedisTextQueryConfig, + }) + return globals()[name] raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/src/google/adk_community/tools/redis/__init__.py b/src/google/adk_community/tools/redis/__init__.py index 45b8e8f..cd8dd64 100644 --- a/src/google/adk_community/tools/redis/__init__.py +++ b/src/google/adk_community/tools/redis/__init__.py @@ -21,19 +21,30 @@ - `RedisRangeSearchTool`: Distance threshold-based vector search - `RedisTextSearchTool`: Full-text BM25 keyword search +Configuration classes for query parameters: + +- `RedisVectorQueryConfig`: Configuration for vector search queries +- `RedisHybridQueryConfig`: Configuration for hybrid search queries +- `RedisRangeQueryConfig`: Configuration for range search queries +- `RedisTextQueryConfig`: Configuration for text search queries + Example: ```python from redisvl.index import SearchIndex from redisvl.utils.vectorize import HFTextVectorizer - from google.adk_community.tools.redis import RedisVectorSearchTool + from google.adk_community.tools.redis import ( + RedisVectorSearchTool, + RedisVectorQueryConfig, + ) index = SearchIndex.from_yaml("schema.yaml") vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") + config = RedisVectorQueryConfig(num_results=5, ef_runtime=100) tool = RedisVectorSearchTool( index=index, vectorizer=vectorizer, - num_results=5, + config=config, ) ``` """ @@ -41,6 +52,10 @@ try: from .base_search_tool import BaseRedisSearchTool from .base_search_tool import 
VectorizedSearchTool + from .config import RedisHybridQueryConfig + from .config import RedisRangeQueryConfig + from .config import RedisTextQueryConfig + from .config import RedisVectorQueryConfig from .hybrid_search_tool import RedisHybridSearchTool from .range_search_tool import RedisRangeSearchTool from .text_search_tool import RedisTextSearchTool @@ -58,4 +73,8 @@ "RedisHybridSearchTool", "RedisRangeSearchTool", "RedisTextSearchTool", + "RedisVectorQueryConfig", + "RedisHybridQueryConfig", + "RedisRangeQueryConfig", + "RedisTextQueryConfig", ] From b4e81d00cf0f72d48943d6d27ffc7307d6d2454a Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Fri, 9 Jan 2026 19:50:12 +0200 Subject: [PATCH 11/16] test(redis): update tests to use config-based API Update all Redis search tool tests to use config objects: - test_vector_search_tool.py - test_hybrid_search_tool.py - test_range_search_tool.py - test_text_search_tool.py Tests now verify config object properties instead of direct attributes. 
--- .../tools/redis/test_hybrid_search_tool.py | 67 ++++++---- .../tools/redis/test_range_search_tool.py | 65 +++++---- .../tools/redis/test_text_search_tool.py | 55 ++++---- .../tools/redis/test_vector_search_tool.py | 125 ++++++++++-------- 4 files changed, 175 insertions(+), 137 deletions(-) diff --git a/tests/unittests/tools/redis/test_hybrid_search_tool.py b/tests/unittests/tools/redis/test_hybrid_search_tool.py index 85fe1c9..08de179 100644 --- a/tests/unittests/tools/redis/test_hybrid_search_tool.py +++ b/tests/unittests/tools/redis/test_hybrid_search_tool.py @@ -26,6 +26,7 @@ from redisvl.query import HybridQuery from redisvl.utils.vectorize import BaseVectorizer +from google.adk_community.tools.redis import RedisHybridQueryConfig from google.adk_community.tools.redis import RedisHybridSearchTool @@ -53,11 +54,14 @@ def mock_index(): @pytest.fixture def hybrid_search_tool(mock_index, mock_vectorizer): """Create RedisHybridSearchTool instance for testing.""" + config = RedisHybridQueryConfig( + text_field_name="content", + num_results=5, + ) return RedisHybridSearchTool( index=mock_index, vectorizer=mock_vectorizer, - text_field_name="content", - num_results=5, + config=config, return_fields=["title", "content"], ) @@ -66,28 +70,28 @@ class TestRedisHybridSearchToolInit: """Tests for RedisHybridSearchTool initialization.""" def test_default_parameters(self, mock_index, mock_vectorizer): - """Test default parameter values.""" + """Test default parameter values with default config.""" tool = RedisHybridSearchTool( index=mock_index, vectorizer=mock_vectorizer, ) - assert tool._text_field_name == "content" - assert tool._vector_field_name == "embedding" - assert tool._text_scorer == "BM25STD" - assert tool._combination_method is None - assert tool._linear_alpha == 0.3 - assert tool._rrf_window == 20 - assert tool._rrf_constant == 60 - assert tool._num_results == 10 + # Config defaults + assert tool._config.text_field_name == "content" + assert 
tool._config.vector_field_name == "embedding" + assert tool._config.text_scorer == "BM25STD" + assert tool._config.combination_method is None + assert tool._config.linear_alpha == 0.3 + assert tool._config.rrf_window == 20 + assert tool._config.rrf_constant == 60 + assert tool._config.num_results == 10 + assert tool._config.dtype == "float32" + assert tool._config.stopwords == "english" + # Tool-level defaults assert tool._filter_expression is None - assert tool._dtype == "float32" - assert tool._stopwords == "english" - def test_custom_parameters(self, mock_index, mock_vectorizer): - """Test custom parameter values.""" - tool = RedisHybridSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, + def test_custom_parameters_via_config(self, mock_index, mock_vectorizer): + """Test custom parameter values via config object.""" + config = RedisHybridQueryConfig( text_field_name="description", vector_field_name="vec", text_scorer="TFIDF", @@ -96,21 +100,26 @@ def test_custom_parameters(self, mock_index, mock_vectorizer): rrf_window=30, rrf_constant=80, num_results=20, - return_fields=["title", "url"], dtype="float64", stopwords={"the", "a", "an"}, ) - assert tool._text_field_name == "description" - assert tool._vector_field_name == "vec" - assert tool._text_scorer == "TFIDF" - assert tool._combination_method == "LINEAR" - assert tool._linear_alpha == 0.7 - assert tool._rrf_window == 30 - assert tool._rrf_constant == 80 - assert tool._num_results == 20 + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=config, + return_fields=["title", "url"], + ) + assert tool._config.text_field_name == "description" + assert tool._config.vector_field_name == "vec" + assert tool._config.text_scorer == "TFIDF" + assert tool._config.combination_method == "LINEAR" + assert tool._config.linear_alpha == 0.7 + assert tool._config.rrf_window == 30 + assert tool._config.rrf_constant == 80 + assert tool._config.num_results == 20 assert 
tool._return_fields == ["title", "url"] - assert tool._dtype == "float64" - assert tool._stopwords == {"the", "a", "an"} + assert tool._config.dtype == "float64" + assert tool._config.stopwords == {"the", "a", "an"} def test_custom_name_and_description(self, mock_index, mock_vectorizer): """Test custom tool name and description.""" diff --git a/tests/unittests/tools/redis/test_range_search_tool.py b/tests/unittests/tools/redis/test_range_search_tool.py index db0173c..a70f683 100644 --- a/tests/unittests/tools/redis/test_range_search_tool.py +++ b/tests/unittests/tools/redis/test_range_search_tool.py @@ -26,6 +26,7 @@ from redisvl.query import VectorRangeQuery from redisvl.utils.vectorize import BaseVectorizer +from google.adk_community.tools.redis import RedisRangeQueryConfig from google.adk_community.tools.redis import RedisRangeSearchTool @@ -55,11 +56,14 @@ def mock_index(): @pytest.fixture def range_search_tool(mock_index, mock_vectorizer): """Create RedisRangeSearchTool instance for testing.""" + config = RedisRangeQueryConfig( + distance_threshold=0.3, + num_results=5, + ) return RedisRangeSearchTool( index=mock_index, vectorizer=mock_vectorizer, - distance_threshold=0.3, - num_results=5, + config=config, return_fields=["title", "content"], ) @@ -68,32 +72,31 @@ class TestRedisRangeSearchToolInit: """Tests for RedisRangeSearchTool initialization.""" def test_default_parameters(self, mock_index, mock_vectorizer): - """Test default parameter values.""" + """Test default parameter values with default config.""" tool = RedisRangeSearchTool( index=mock_index, vectorizer=mock_vectorizer, ) - assert tool._vector_field_name == "embedding" - assert tool._distance_threshold == 0.2 - assert tool._num_results == 10 - assert tool._dtype == "float32" - assert tool._return_score is True - assert tool._dialect == 2 - assert tool._in_order is False - assert tool._normalize_vector_distance is False + # Config defaults + assert tool._config.vector_field_name == "embedding" + 
assert tool._config.distance_threshold == 0.2 + assert tool._config.num_results == 10 + assert tool._config.dtype == "float32" + assert tool._config.return_score is True + assert tool._config.dialect == 2 + assert tool._config.in_order is False + assert tool._config.normalize_vector_distance is False + assert tool._config.sort_by is None + assert tool._config.epsilon is None + # Tool-level defaults assert tool._filter_expression is None - assert tool._sort_by is None - assert tool._epsilon is None - def test_custom_parameters(self, mock_index, mock_vectorizer): - """Test custom parameter values.""" - tool = RedisRangeSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, + def test_custom_parameters_via_config(self, mock_index, mock_vectorizer): + """Test custom parameter values via config object.""" + config = RedisRangeQueryConfig( vector_field_name="vec", distance_threshold=0.5, num_results=20, - return_fields=["title", "url"], dtype="float64", return_score=False, dialect=3, @@ -101,16 +104,22 @@ def test_custom_parameters(self, mock_index, mock_vectorizer): normalize_vector_distance=True, epsilon=0.01, ) - assert tool._vector_field_name == "vec" - assert tool._distance_threshold == 0.5 - assert tool._num_results == 20 + tool = RedisRangeSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=config, + return_fields=["title", "url"], + ) + assert tool._config.vector_field_name == "vec" + assert tool._config.distance_threshold == 0.5 + assert tool._config.num_results == 20 assert tool._return_fields == ["title", "url"] - assert tool._dtype == "float64" - assert tool._return_score is False - assert tool._dialect == 3 - assert tool._in_order is True - assert tool._normalize_vector_distance is True - assert tool._epsilon == 0.01 + assert tool._config.dtype == "float64" + assert tool._config.return_score is False + assert tool._config.dialect == 3 + assert tool._config.in_order is True + assert tool._config.normalize_vector_distance is True + 
assert tool._config.epsilon == 0.01 def test_custom_name_and_description(self, mock_index, mock_vectorizer): """Test custom tool name and description.""" diff --git a/tests/unittests/tools/redis/test_text_search_tool.py b/tests/unittests/tools/redis/test_text_search_tool.py index 5144a70..1e44c12 100644 --- a/tests/unittests/tools/redis/test_text_search_tool.py +++ b/tests/unittests/tools/redis/test_text_search_tool.py @@ -24,6 +24,7 @@ from redisvl.index import SearchIndex +from google.adk_community.tools.redis import RedisTextQueryConfig from google.adk_community.tools.redis import RedisTextSearchTool @@ -42,53 +43,59 @@ def mock_index(): @pytest.fixture def text_search_tool(mock_index): """Create RedisTextSearchTool instance for testing.""" - return RedisTextSearchTool( - index=mock_index, + config = RedisTextQueryConfig( text_field_name="content", num_results=5, - return_fields=["title", "content"], stopwords=None, # Avoid nltk dependency in tests ) + return RedisTextSearchTool( + index=mock_index, + config=config, + return_fields=["title", "content"], + ) class TestRedisTextSearchToolInit: """Tests for RedisTextSearchTool initialization.""" def test_default_parameters(self, mock_index): - """Test default parameter values.""" + """Test default parameter values via config.""" tool = RedisTextSearchTool(index=mock_index) - assert tool._text_field_name == "content" - assert tool._text_scorer == "BM25STD" - assert tool._num_results == 10 - assert tool._return_score is True - assert tool._dialect == 2 - assert tool._in_order is False - assert tool._stopwords == "english" + assert tool._config.text_field_name == "content" + assert tool._config.text_scorer == "BM25STD" + assert tool._config.num_results == 10 + assert tool._config.return_score is True + assert tool._config.dialect == 2 + assert tool._config.in_order is False + assert tool._config.stopwords == "english" assert tool._filter_expression is None - assert tool._sort_by is None + assert tool._config.sort_by 
is None assert tool._return_fields is None - def test_custom_parameters(self, mock_index): - """Test custom parameter values.""" - tool = RedisTextSearchTool( - index=mock_index, + def test_custom_parameters_via_config(self, mock_index): + """Test custom parameter values via config object.""" + config = RedisTextQueryConfig( text_field_name="description", text_scorer="TFIDF", num_results=20, - return_fields=["title", "url"], return_score=False, dialect=3, in_order=True, stopwords={"the", "a", "an"}, ) - assert tool._text_field_name == "description" - assert tool._text_scorer == "TFIDF" - assert tool._num_results == 20 + tool = RedisTextSearchTool( + index=mock_index, + config=config, + return_fields=["title", "url"], + ) + assert tool._config.text_field_name == "description" + assert tool._config.text_scorer == "TFIDF" + assert tool._config.num_results == 20 assert tool._return_fields == ["title", "url"] - assert tool._return_score is False - assert tool._dialect == 3 - assert tool._in_order is True - assert tool._stopwords == {"the", "a", "an"} + assert tool._config.return_score is False + assert tool._config.dialect == 3 + assert tool._config.in_order is True + assert tool._config.stopwords == {"the", "a", "an"} def test_custom_name_and_description(self, mock_index): """Test custom tool name and description.""" diff --git a/tests/unittests/tools/redis/test_vector_search_tool.py b/tests/unittests/tools/redis/test_vector_search_tool.py index a3ad514..05900d9 100644 --- a/tests/unittests/tools/redis/test_vector_search_tool.py +++ b/tests/unittests/tools/redis/test_vector_search_tool.py @@ -27,6 +27,7 @@ from redisvl.query import VectorQuery from redisvl.utils.vectorize import BaseVectorizer +from google.adk_community.tools.redis import RedisVectorQueryConfig from google.adk_community.tools.redis import RedisVectorSearchTool @@ -56,10 +57,11 @@ def mock_index(): @pytest.fixture def vector_search_tool(mock_index, mock_vectorizer): """Create RedisVectorSearchTool 
instance for testing.""" + config = RedisVectorQueryConfig(num_results=5) return RedisVectorSearchTool( index=mock_index, vectorizer=mock_vectorizer, - num_results=5, + config=config, return_fields=["title", "content"], ) @@ -68,36 +70,35 @@ class TestRedisVectorSearchToolInit: """Tests for RedisVectorSearchTool initialization.""" def test_default_parameters(self, mock_index, mock_vectorizer): - """Test default parameter values.""" + """Test default parameter values with default config.""" tool = RedisVectorSearchTool( index=mock_index, vectorizer=mock_vectorizer, ) - assert tool._vector_field_name == "embedding" - assert tool._num_results == 10 - assert tool._dtype == "float32" - assert tool._return_score is True - assert tool._dialect == 2 - assert tool._in_order is False - assert tool._normalize_vector_distance is False + # Config defaults + assert tool._config.vector_field_name == "embedding" + assert tool._config.num_results == 10 + assert tool._config.dtype == "float32" + assert tool._config.return_score is True + assert tool._config.dialect == 2 + assert tool._config.in_order is False + assert tool._config.normalize_vector_distance is False + assert tool._config.sort_by is None + assert tool._config.hybrid_policy is None + assert tool._config.batch_size is None + assert tool._config.ef_runtime is None + assert tool._config.epsilon is None + assert tool._config.search_window_size is None + assert tool._config.use_search_history is None + assert tool._config.search_buffer_capacity is None + # Tool-level defaults assert tool._filter_expression is None - assert tool._sort_by is None - assert tool._hybrid_policy is None - assert tool._batch_size is None - assert tool._ef_runtime is None - assert tool._epsilon is None - assert tool._search_window_size is None - assert tool._use_search_history is None - assert tool._search_buffer_capacity is None - - def test_custom_parameters(self, mock_index, mock_vectorizer): - """Test custom parameter values.""" - tool = 
RedisVectorSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, + + def test_custom_parameters_via_config(self, mock_index, mock_vectorizer): + """Test custom parameter values via config object.""" + config = RedisVectorQueryConfig( vector_field_name="custom_embedding", num_results=20, - return_fields=["title", "content", "url"], dtype="float64", return_score=False, dialect=3, @@ -111,21 +112,27 @@ def test_custom_parameters(self, mock_index, mock_vectorizer): use_search_history="ON", search_buffer_capacity=1000, ) - assert tool._vector_field_name == "custom_embedding" - assert tool._num_results == 20 + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=config, + return_fields=["title", "content", "url"], + ) + assert tool._config.vector_field_name == "custom_embedding" + assert tool._config.num_results == 20 assert tool._return_fields == ["title", "content", "url"] - assert tool._dtype == "float64" - assert tool._return_score is False - assert tool._dialect == 3 - assert tool._in_order is True - assert tool._normalize_vector_distance is True - assert tool._hybrid_policy == "BATCHES" - assert tool._batch_size == 100 - assert tool._ef_runtime == 200 - assert tool._epsilon == 0.01 - assert tool._search_window_size == 50 - assert tool._use_search_history == "ON" - assert tool._search_buffer_capacity == 1000 + assert tool._config.dtype == "float64" + assert tool._config.return_score is False + assert tool._config.dialect == 3 + assert tool._config.in_order is True + assert tool._config.normalize_vector_distance is True + assert tool._config.hybrid_policy == "BATCHES" + assert tool._config.batch_size == 100 + assert tool._config.ef_runtime == 200 + assert tool._config.epsilon == 0.01 + assert tool._config.search_window_size == 50 + assert tool._config.use_search_history == "ON" + assert tool._config.search_buffer_capacity == 1000 def test_custom_name_and_description(self, mock_index, mock_vectorizer): """Test custom tool 
name and description.""" @@ -169,19 +176,22 @@ def test_build_query_passes_core_parameters( # Use a string filter expression (valid type for VectorQuery) filter_str = "@category:{redis}" - tool = RedisVectorSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, + config = RedisVectorQueryConfig( vector_field_name="vec", num_results=10, - return_fields=["title"], - filter_expression=filter_str, dtype="float16", return_score=False, dialect=3, in_order=True, normalize_vector_distance=True, ) + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=config, + return_fields=["title"], + filter_expression=filter_str, + ) embedding = [0.1] * 384 query = tool._build_query("test", embedding) @@ -194,11 +204,9 @@ def test_build_query_passes_core_parameters( assert query._dialect == 3 assert query._in_order is True - def test_stores_optional_parameters(self, mock_index, mock_vectorizer): - """Test that optional parameters are stored correctly.""" - tool = RedisVectorSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, + def test_stores_optional_parameters_in_config(self, mock_index, mock_vectorizer): + """Test that optional parameters are stored correctly in config.""" + config = RedisVectorQueryConfig( hybrid_policy="ADHOC_BF", batch_size=50, ef_runtime=100, @@ -207,15 +215,20 @@ def test_stores_optional_parameters(self, mock_index, mock_vectorizer): use_search_history="AUTO", search_buffer_capacity=500, ) + tool = RedisVectorSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=config, + ) - # Verify parameters are stored on the tool - assert tool._hybrid_policy == "ADHOC_BF" - assert tool._batch_size == 50 - assert tool._ef_runtime == 100 - assert tool._epsilon == 0.05 - assert tool._search_window_size == 25 - assert tool._use_search_history == "AUTO" - assert tool._search_buffer_capacity == 500 + # Verify parameters are stored in the config + assert tool._config.hybrid_policy == "ADHOC_BF" + assert 
tool._config.batch_size == 50 + assert tool._config.ef_runtime == 100 + assert tool._config.epsilon == 0.05 + assert tool._config.search_window_size == 25 + assert tool._config.use_search_history == "AUTO" + assert tool._config.search_buffer_capacity == 500 class TestRedisVectorSearchToolDeclaration: From 8bd8dc341c9c5b33bbd33a92a1775e9d1bff2358 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Fri, 9 Jan 2026 19:50:19 +0200 Subject: [PATCH 12/16] docs(samples): update redis_vl_search sample to use config-based API Update sample agent and README to use config objects: - RedisVectorQueryConfig for semantic_search - RedisTextQueryConfig for keyword_search - RedisRangeQueryConfig for range_search README examples updated to show config-based usage pattern. --- .../samples/redis_vl_search/README.md | 24 +++++++++++++------ .../redis_vl_search_agent/agent.py | 20 ++++++++++++---- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/contributing/samples/redis_vl_search/README.md b/contributing/samples/redis_vl_search/README.md index 4b44178..246e960 100644 --- a/contributing/samples/redis_vl_search/README.md +++ b/contributing/samples/redis_vl_search/README.md @@ -155,11 +155,13 @@ You can add filter expressions to narrow search results: ```python from redisvl.query.filter import Tag +from google.adk_community.tools.redis import RedisVectorSearchTool, RedisVectorQueryConfig +config = RedisVectorQueryConfig(num_results=5) redis_search = RedisVectorSearchTool( index=index, vectorizer=vectorizer, - num_results=5, + config=config, return_fields=["title", "content", "url", "category"], filter_expression=Tag("category") == "redis", # Only search Redis docs ) @@ -169,20 +171,20 @@ See [RedisVL Filter documentation](https://docs.redisvl.com/api/filter.html) for ### Advanced Query Options -`RedisVectorSearchTool` exposes all VectorQuery parameters: +`RedisVectorSearchTool` uses a `RedisVectorQueryConfig` object for query parameters: ```python -redis_search = 
RedisVectorSearchTool( - index=index, - vectorizer=vectorizer, +from google.adk_community.tools.redis import RedisVectorSearchTool, RedisVectorQueryConfig +from redisvl.query.filter import Tag + +# Configure query parameters via config object +config = RedisVectorQueryConfig( num_results=10, - return_fields=["title", "content"], # Query tuning dtype="float32", # Vector dtype return_score=True, # Include similarity score normalize_vector_distance=True, # Convert to 0-1 similarity # Hybrid filtering - filter_expression=Tag("category") == "redis", hybrid_policy="BATCHES", # or "ADHOC_BF" batch_size=100, # For BATCHES policy # HNSW tuning @@ -193,6 +195,14 @@ redis_search = RedisVectorSearchTool( use_search_history="AUTO", # "OFF", "ON", or "AUTO" search_buffer_capacity=30, # 2-level compression tuning ) + +redis_search = RedisVectorSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + return_fields=["title", "content"], + filter_expression=Tag("category") == "redis", +) ``` See [RedisVL Query documentation](https://docs.redisvl.com/api/query.html) for details. 
diff --git a/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py b/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py index 8a03f2f..d678e4f 100644 --- a/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py +++ b/contributing/samples/redis_vl_search/redis_vl_search_agent/agent.py @@ -20,8 +20,11 @@ from dotenv import load_dotenv from google.adk import Agent from google.adk_community.tools.redis import ( + RedisRangeQueryConfig, RedisRangeSearchTool, + RedisTextQueryConfig, RedisTextSearchTool, + RedisVectorQueryConfig, RedisVectorSearchTool, ) from redisvl.index import SearchIndex @@ -90,23 +93,30 @@ def get_index(schema_path: Path, redis_url: str) -> SearchIndex: def get_search_tools(index: SearchIndex, vectorizer: HFTextVectorizer) -> list: """Create search tools for the agent.""" + # Config objects group query-specific parameters + vector_config = RedisVectorQueryConfig(num_results=5) + text_config = RedisTextQueryConfig( + text_field_name="content", + num_results=5, + text_scorer="BM25STD", + ) + range_config = RedisRangeQueryConfig(distance_threshold=0.5) + return [ RedisVectorSearchTool( name="semantic_search", description="Semantic similarity search for conceptual queries.", index=index, vectorizer=vectorizer, - num_results=5, + config=vector_config, return_fields=RETURN_FIELDS, ), RedisTextSearchTool( name="keyword_search", description="Keyword search for exact terms and phrases.", index=index, - text_field_name="content", - num_results=5, + config=text_config, return_fields=RETURN_FIELDS, - text_scorer="BM25STD", ), RedisRangeSearchTool( name="range_search", @@ -119,7 +129,7 @@ def get_search_tools(index: SearchIndex, vectorizer: HFTextVectorizer) -> list: ), index=index, vectorizer=vectorizer, - distance_threshold=0.5, + config=range_config, return_fields=RETURN_FIELDS, ), ] From de730ed0d59ed25b18f1924f9016e76cdc6e745d Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Mon, 12 Jan 2026 13:34:52 +0200 
Subject: [PATCH 13/16] feat(redis): add version-aware hybrid search with dual config support - Add RedisAggregatedHybridQueryConfig for older RedisVL versions (<0.13.0) - RedisHybridSearchTool now auto-detects RedisVL version and uses appropriate query class - Native HybridQuery (FT.HYBRID) used for RedisVL >= 0.13.0, Redis >= 8.4.0 - AggregateHybridQuery (FT.AGGREGATE) used for older versions - Emit DeprecationWarning when using aggregate config on newer versions - Raise ValueError when using native config on older versions - Use packaging.version.parse() for version comparison (aligns with ADK patterns) - Update tests to cover both code paths and version detection --- src/google/adk_community/tools/__init__.py | 5 + .../adk_community/tools/redis/__init__.py | 5 +- .../adk_community/tools/redis/config.py | 122 ++++++- .../tools/redis/hybrid_search_tool.py | 139 ++++++-- .../tools/redis/test_hybrid_search_tool.py | 316 ++++++++++++------ 5 files changed, 452 insertions(+), 135 deletions(-) diff --git a/src/google/adk_community/tools/__init__.py b/src/google/adk_community/tools/__init__.py index 4000eb7..85ce859 100644 --- a/src/google/adk_community/tools/__init__.py +++ b/src/google/adk_community/tools/__init__.py @@ -27,7 +27,9 @@ "RedisTextSearchTool", "RedisVectorQueryConfig", "RedisHybridQueryConfig", + "RedisAggregatedHybridQueryConfig", "RedisRangeQueryConfig", + "RedisTextQueryConfig", ] # Redis tool names for lazy loading @@ -44,6 +46,7 @@ _REDIS_CONFIGS = { "RedisVectorQueryConfig", "RedisHybridQueryConfig", + "RedisAggregatedHybridQueryConfig", "RedisRangeQueryConfig", "RedisTextQueryConfig", } @@ -75,6 +78,7 @@ def __getattr__(name: str): "Install with: pip install google-adk-community[redis-vl]" ) from e if name in _REDIS_CONFIGS: + from .redis import RedisAggregatedHybridQueryConfig from .redis import RedisHybridQueryConfig from .redis import RedisRangeQueryConfig from .redis import RedisTextQueryConfig @@ -83,6 +87,7 @@ def __getattr__(name: str): 
globals().update({ "RedisVectorQueryConfig": RedisVectorQueryConfig, "RedisHybridQueryConfig": RedisHybridQueryConfig, + "RedisAggregatedHybridQueryConfig": RedisAggregatedHybridQueryConfig, "RedisRangeQueryConfig": RedisRangeQueryConfig, "RedisTextQueryConfig": RedisTextQueryConfig, }) diff --git a/src/google/adk_community/tools/redis/__init__.py b/src/google/adk_community/tools/redis/__init__.py index cd8dd64..9071f14 100644 --- a/src/google/adk_community/tools/redis/__init__.py +++ b/src/google/adk_community/tools/redis/__init__.py @@ -24,7 +24,8 @@ Configuration classes for query parameters: - `RedisVectorQueryConfig`: Configuration for vector search queries -- `RedisHybridQueryConfig`: Configuration for hybrid search queries +- `RedisHybridQueryConfig`: Configuration for native hybrid search (RedisVL >= 0.13.0) +- `RedisAggregatedHybridQueryConfig`: Configuration for client-side hybrid (older versions) - `RedisRangeQueryConfig`: Configuration for range search queries - `RedisTextQueryConfig`: Configuration for text search queries @@ -52,6 +53,7 @@ try: from .base_search_tool import BaseRedisSearchTool from .base_search_tool import VectorizedSearchTool + from .config import RedisAggregatedHybridQueryConfig from .config import RedisHybridQueryConfig from .config import RedisRangeQueryConfig from .config import RedisTextQueryConfig @@ -75,6 +77,7 @@ "RedisTextSearchTool", "RedisVectorQueryConfig", "RedisHybridQueryConfig", + "RedisAggregatedHybridQueryConfig", "RedisRangeQueryConfig", "RedisTextQueryConfig", ] diff --git a/src/google/adk_community/tools/redis/config.py b/src/google/adk_community/tools/redis/config.py index 350aa05..ddade46 100644 --- a/src/google/adk_community/tools/redis/config.py +++ b/src/google/adk_community/tools/redis/config.py @@ -253,15 +253,33 @@ def to_query_kwargs( class RedisHybridQueryConfig(BaseModel): - """Configuration for Redis hybrid search queries. + """Configuration for native Redis hybrid search queries. 
- Hybrid search combines semantic vector similarity with keyword-based - BM25 text matching using Redis's native FT.HYBRID command. + Uses Redis's native FT.HYBRID command for server-side hybrid search + combining semantic vector similarity with keyword-based BM25 text matching. Requirements: - - Redis >= 8.4.0 (for native FT.HYBRID command support) + - RedisVL >= 0.13.0 + - Redis >= 8.4.0 - redis-py >= 7.1.0 + For older Redis/RedisVL versions, use RedisAggregatedHybridQueryConfig instead. + + Example: + ```python + from google.adk_community.tools.redis import ( + RedisHybridSearchTool, + RedisHybridQueryConfig, + ) + + config = RedisHybridQueryConfig( + text_field_name="content", + combination_method="LINEAR", + linear_alpha=0.7, # 70% text, 30% vector + ) + tool = RedisHybridSearchTool(index=index, vectorizer=vectorizer, config=config) + ``` + Attributes: text_field_name: Name of the text field for BM25 search. vector_field_name: Name of the vector field for similarity search. @@ -274,7 +292,7 @@ class RedisHybridQueryConfig(BaseModel): range_epsilon: Epsilon for RANGE search accuracy. yield_vsim_score_as: Field name to yield vector similarity score as. combination_method: Score combination method - "RRF" or "LINEAR". - linear_alpha: Weight of text score when using LINEAR. + linear_alpha: Weight of text score when using LINEAR (0.0-1.0). rrf_window: Window size for RRF combination. rrf_constant: Constant for RRF combination. yield_combined_score_as: Field name to yield combined score as. @@ -320,7 +338,7 @@ def to_query_kwargs( return_fields: Optional[List[str]] = None, filter_expression: Optional[Any] = None, ) -> dict[str, Any]: - """Convert config to HybridQuery kwargs. + """Convert config to native HybridQuery kwargs. Args: text: The query text for BM25 matching. 
@@ -357,3 +375,95 @@ def to_query_kwargs( "text_weights": self.text_weights, } + +class RedisAggregatedHybridQueryConfig(BaseModel): + """Configuration for aggregated (client-side) Redis hybrid search queries. + + .. deprecated:: + This config is for older Redis/RedisVL versions. For newer setups + (RedisVL >= 0.13.0, Redis >= 8.4.0), prefer RedisHybridQueryConfig + which uses native server-side hybrid search for better performance. + + This config uses AggregateHybridQuery which performs client-side hybrid + search using FT.AGGREGATE with weighted score combination. It works with + any Redis version that has RediSearch installed. + + Recommended for: + - Redis < 8.4.0 + - RedisVL < 0.13.0 + - Environments where native FT.HYBRID is not available + + Example: + ```python + from google.adk_community.tools.redis import ( + RedisHybridSearchTool, + RedisAggregatedHybridQueryConfig, + ) + + config = RedisAggregatedHybridQueryConfig( + text_field_name="content", + alpha=0.7, # 70% text, 30% vector + ) + tool = RedisHybridSearchTool(index=index, vectorizer=vectorizer, config=config) + ``` + + Attributes: + text_field_name: Name of the text field for BM25 search. + vector_field_name: Name of the vector field for similarity search. + text_scorer: Text scoring algorithm (default: "BM25STD"). + alpha: Weight for text score (default: 0.7). Higher values favor + text matching over vector similarity. Combined score is: + alpha * text_score + (1 - alpha) * vector_score + num_results: Number of results to return. + dtype: Data type of the vector. + stopwords: Stopwords to remove from query. + dialect: RediSearch query dialect version. + text_weights: Optional field weights for text scoring. 
+ """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + text_field_name: str = Field(default="content") + vector_field_name: str = Field(default="embedding") + text_scorer: str = Field(default="BM25STD") + alpha: float = Field(default=0.7, ge=0.0, le=1.0) + num_results: int = Field(default=10, ge=1) + dtype: str = Field(default="float32") + stopwords: Optional[Union[str, Set[str]]] = Field(default="english") + dialect: int = Field(default=2, ge=1) + text_weights: Optional[Dict[str, float]] = Field(default=None) + + def to_query_kwargs( + self, + text: str, + vector: List[float], + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Any] = None, + ) -> dict[str, Any]: + """Convert config to AggregateHybridQuery kwargs. + + Args: + text: The query text for BM25 matching. + vector: The query vector embedding. + return_fields: Optional list of fields to return. + filter_expression: Optional filter expression to apply. + + Returns: + Dictionary of kwargs suitable for AggregateHybridQuery constructor. 
+ """ + return { + "text": text, + "text_field_name": self.text_field_name, + "vector": vector, + "vector_field_name": self.vector_field_name, + "text_scorer": self.text_scorer, + "alpha": self.alpha, + "dtype": self.dtype, + "num_results": self.num_results, + "return_fields": return_fields, + "stopwords": self.stopwords, + "dialect": self.dialect, + "text_weights": self.text_weights, + "filter_expression": filter_expression, + } + diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py index 621fe9c..5b8826d 100644 --- a/src/google/adk_community/tools/redis/hybrid_search_tool.py +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -16,59 +16,99 @@ from __future__ import annotations +import warnings from typing import Any from typing import List from typing import Optional from typing import Union from google.genai import types +from packaging.version import parse from redisvl.index import AsyncSearchIndex from redisvl.index import SearchIndex -from redisvl.query import HybridQuery from redisvl.utils.vectorize import BaseVectorizer from .base_search_tool import VectorizedSearchTool +from .config import RedisAggregatedHybridQueryConfig from .config import RedisHybridQueryConfig +# Minimum RedisVL version required for native FT.HYBRID support +_MIN_NATIVE_HYBRID_VERSION = "0.13.0" + + +def _get_redisvl_version() -> str: + """Get the installed RedisVL version string. + + Returns: + Version string (e.g., "0.13.0") or "0.0.0" if not available. + """ + try: + import redisvl + + return getattr(redisvl, "__version__", "0.0.0") + except ImportError: + return "0.0.0" + + +def _supports_native_hybrid() -> bool: + """Check if the installed RedisVL version supports native HybridQuery. + + Returns: + True if RedisVL >= 0.13.0, False otherwise. 
+ """ + try: + return parse(_get_redisvl_version()) >= parse(_MIN_NATIVE_HYBRID_VERSION) + except Exception: + return False + class RedisHybridSearchTool(VectorizedSearchTool): """Hybrid search tool combining vector similarity and BM25 text search. This tool performs a hybrid search that combines semantic vector similarity - with keyword-based BM25 text matching using Redis's native FT.HYBRID command. - This is useful when you want to leverage both the semantic understanding of - embeddings and the precision of keyword matching. + with keyword-based BM25 text matching. It automatically detects the installed + RedisVL version and uses the appropriate implementation: - Requirements: - - Redis >= 8.4.0 (for native FT.HYBRID command support) - - redis-py >= 7.1.0 + - **RedisVL >= 0.13.0**: Uses native FT.HYBRID command (server-side fusion) + with RedisHybridQueryConfig. Requires Redis >= 8.4.0. + - **RedisVL < 0.13.0**: Uses AggregateHybridQuery (client-side fusion) + with RedisAggregatedHybridQueryConfig. Works with any Redis version. 
- Example: + Example (native mode - RedisVL >= 0.13.0): ```python - from redisvl.index import SearchIndex - from redisvl.utils.vectorize import HFTextVectorizer from google.adk_community.tools.redis import ( RedisHybridSearchTool, RedisHybridQueryConfig, ) - index = SearchIndex.from_yaml("schema.yaml") - vectorizer = HFTextVectorizer(model="redis/langcache-embed-v2") - - # Using config object (recommended) config = RedisHybridQueryConfig( text_field_name="content", + combination_method="LINEAR", linear_alpha=0.7, # 70% text, 30% vector - num_results=10, ) tool = RedisHybridSearchTool( index=index, vectorizer=vectorizer, config=config, - return_fields=["title", "content"], ) + ``` - agent = Agent(model="gemini-2.5-flash", tools=[tool]) + Example (aggregate mode - older versions): + ```python + from google.adk_community.tools.redis import ( + RedisHybridSearchTool, + RedisAggregatedHybridQueryConfig, + ) + + config = RedisAggregatedHybridQueryConfig( + text_field_name="content", + alpha=0.7, # 70% text, 30% vector + ) + tool = RedisHybridSearchTool( + index=index, + vectorizer=vectorizer, + config=config, + ) ``` """ @@ -77,7 +117,9 @@ def __init__( *, index: Union[SearchIndex, AsyncSearchIndex], vectorizer: BaseVectorizer, - config: Optional[RedisHybridQueryConfig] = None, + config: Optional[ + Union[RedisHybridQueryConfig, RedisAggregatedHybridQueryConfig] + ] = None, return_fields: Optional[List[str]] = None, filter_expression: Optional[Any] = None, name: str = "redis_hybrid_search", @@ -88,13 +130,21 @@ def __init__( Args: index: The RedisVL SearchIndex or AsyncSearchIndex to query. vectorizer: The vectorizer for embedding queries. - config: Configuration for query parameters. If None, uses defaults. - See RedisHybridQueryConfig for available options including - text_field_name, vector_field_name, linear_alpha, and more. + config: Configuration for query parameters. 
Can be either: + - RedisHybridQueryConfig: For native FT.HYBRID (RedisVL >= 0.13.0) + - RedisAggregatedHybridQueryConfig: For client-side hybrid (older) + If None, auto-detects based on installed RedisVL version. return_fields: Optional list of fields to return in results. filter_expression: Optional filter expression to narrow results. name: The name of the tool (exposed to LLM). description: The description of the tool (exposed to LLM). + + Raises: + ValueError: If RedisHybridQueryConfig is used with RedisVL < 0.13.0. + + Warns: + DeprecationWarning: If RedisAggregatedHybridQueryConfig is used when + native hybrid is available (RedisVL >= 0.13.0). """ super().__init__( name=name, @@ -103,7 +153,37 @@ def __init__( vectorizer=vectorizer, return_fields=return_fields, ) - self._config = config or RedisHybridQueryConfig() + + self._supports_native = _supports_native_hybrid() + + # Auto-detect config if not provided + if config is None: + if self._supports_native: + config = RedisHybridQueryConfig() + else: + config = RedisAggregatedHybridQueryConfig() + + # Validate config compatibility with installed version + self._use_native = isinstance(config, RedisHybridQueryConfig) + + if self._use_native and not self._supports_native: + raise ValueError( + "RedisHybridQueryConfig requires RedisVL >= 0.13.0 and Redis >= 8.4.0. " + f"Installed RedisVL version: {_get_redisvl_version()}. " + "Use RedisAggregatedHybridQueryConfig for older versions." + ) + + if not self._use_native and self._supports_native: + warnings.warn( + "RedisAggregatedHybridQueryConfig is deprecated for RedisVL >= 0.13.0. " + "Consider using RedisHybridQueryConfig for native FT.HYBRID support " + "with better performance. 
RedisAggregatedHybridQueryConfig will " + "continue to work but uses client-side score combination.", + DeprecationWarning, + stacklevel=2, + ) + + self._config = config self._filter_expression = filter_expression def _get_declaration(self) -> types.FunctionDeclaration: @@ -132,8 +212,8 @@ def _get_declaration(self) -> types.FunctionDeclaration: def _build_query( self, query_text: str, embedding: List[float], **kwargs: Any - ) -> HybridQuery: - """Build a HybridQuery for combined vector + text search. + ) -> Any: + """Build a query for combined vector + text search. Args: query_text: The original query text for BM25 matching. @@ -141,7 +221,7 @@ def _build_query( **kwargs: Additional parameters (e.g., num_results). Returns: - A HybridQuery configured for hybrid search. + A HybridQuery or AggregateHybridQuery configured for hybrid search. """ # Allow runtime override of num_results num_results = kwargs.get("num_results", self._config.num_results) @@ -155,4 +235,11 @@ def _build_query( ) query_kwargs["num_results"] = num_results - return HybridQuery(**query_kwargs) + if self._use_native: + from redisvl.query import HybridQuery + + return HybridQuery(**query_kwargs) + else: + from redisvl.query import AggregateHybridQuery + + return AggregateHybridQuery(**query_kwargs) diff --git a/tests/unittests/tools/redis/test_hybrid_search_tool.py b/tests/unittests/tools/redis/test_hybrid_search_tool.py index 08de179..6d1c6b9 100644 --- a/tests/unittests/tools/redis/test_hybrid_search_tool.py +++ b/tests/unittests/tools/redis/test_hybrid_search_tool.py @@ -14,8 +14,10 @@ """Tests for RedisHybridSearchTool.""" +import warnings from unittest.mock import AsyncMock from unittest.mock import MagicMock +from unittest.mock import patch import pytest @@ -23,11 +25,15 @@ pytest.importorskip("redisvl") from redisvl.index import SearchIndex -from redisvl.query import HybridQuery from redisvl.utils.vectorize import BaseVectorizer +from google.adk_community.tools.redis import 
RedisAggregatedHybridQueryConfig from google.adk_community.tools.redis import RedisHybridQueryConfig from google.adk_community.tools.redis import RedisHybridSearchTool +from google.adk_community.tools.redis.hybrid_search_tool import ( + _get_redisvl_version, + _supports_native_hybrid, +) @pytest.fixture @@ -51,135 +57,241 @@ def mock_index(): return index -@pytest.fixture -def hybrid_search_tool(mock_index, mock_vectorizer): - """Create RedisHybridSearchTool instance for testing.""" - config = RedisHybridQueryConfig( - text_field_name="content", - num_results=5, - ) - return RedisHybridSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, - config=config, - return_fields=["title", "content"], - ) +def _native_hybrid_available(): + """Check if native HybridQuery dependencies are available.""" + try: + from redis.commands.search.hybrid_query import CombineResultsMethod + from redis.commands.search.hybrid_query import HybridPostProcessingConfig + return True + except (ImportError, ModuleNotFoundError): + return False -class TestRedisHybridSearchToolInit: - """Tests for RedisHybridSearchTool initialization.""" - def test_default_parameters(self, mock_index, mock_vectorizer): - """Test default parameter values with default config.""" - tool = RedisHybridSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, - ) - # Config defaults - assert tool._config.text_field_name == "content" - assert tool._config.vector_field_name == "embedding" - assert tool._config.text_scorer == "BM25STD" - assert tool._config.combination_method is None - assert tool._config.linear_alpha == 0.3 - assert tool._config.rrf_window == 20 - assert tool._config.rrf_constant == 60 - assert tool._config.num_results == 10 - assert tool._config.dtype == "float32" - assert tool._config.stopwords == "english" - # Tool-level defaults - assert tool._filter_expression is None - - def test_custom_parameters_via_config(self, mock_index, mock_vectorizer): - """Test custom parameter values via 
config object.""" +class TestVersionDetection: + """Tests for version detection functions.""" + + def test_get_redisvl_version(self): + """Test version string retrieval.""" + version = _get_redisvl_version() + assert isinstance(version, str) + # Should be a valid version string like "0.13.0" or "0.0.0" + assert len(version.split(".")) >= 2 + + def test_supports_native_hybrid(self): + """Test native hybrid support detection.""" + result = _supports_native_hybrid() + assert isinstance(result, bool) + + +class TestRedisHybridQueryConfig: + """Tests for RedisHybridQueryConfig (native mode).""" + + def test_default_values(self): + """Test default config values.""" + config = RedisHybridQueryConfig() + assert config.text_field_name == "content" + assert config.vector_field_name == "embedding" + assert config.text_scorer == "BM25STD" + assert config.combination_method is None + assert config.linear_alpha == 0.3 + assert config.rrf_window == 20 + assert config.rrf_constant == 60 + assert config.num_results == 10 + assert config.dtype == "float32" + assert config.stopwords == "english" + + def test_to_query_kwargs(self): + """Test conversion to query kwargs.""" config = RedisHybridQueryConfig( - text_field_name="description", - vector_field_name="vec", - text_scorer="TFIDF", + text_field_name="content", combination_method="LINEAR", linear_alpha=0.7, - rrf_window=30, - rrf_constant=80, - num_results=20, - dtype="float64", - stopwords={"the", "a", "an"}, ) - tool = RedisHybridSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, - config=config, - return_fields=["title", "url"], + kwargs = config.to_query_kwargs( + text="test query", + vector=[0.1] * 384, + return_fields=["title"], ) - assert tool._config.text_field_name == "description" - assert tool._config.vector_field_name == "vec" - assert tool._config.text_scorer == "TFIDF" - assert tool._config.combination_method == "LINEAR" - assert tool._config.linear_alpha == 0.7 - assert tool._config.rrf_window == 30 - 
assert tool._config.rrf_constant == 80 - assert tool._config.num_results == 20 - assert tool._return_fields == ["title", "url"] - assert tool._config.dtype == "float64" - assert tool._config.stopwords == {"the", "a", "an"} + assert kwargs["text"] == "test query" + assert kwargs["text_field_name"] == "content" + assert kwargs["combination_method"] == "LINEAR" + assert kwargs["linear_alpha"] == 0.7 + assert kwargs["return_fields"] == ["title"] - def test_custom_name_and_description(self, mock_index, mock_vectorizer): - """Test custom tool name and description.""" - tool = RedisHybridSearchTool( - index=mock_index, - vectorizer=mock_vectorizer, - name="custom_hybrid", - description="Custom hybrid search", + +class TestRedisAggregatedHybridQueryConfig: + """Tests for RedisAggregatedHybridQueryConfig (legacy mode).""" + + def test_default_values(self): + """Test default config values.""" + config = RedisAggregatedHybridQueryConfig() + assert config.text_field_name == "content" + assert config.vector_field_name == "embedding" + assert config.text_scorer == "BM25STD" + assert config.alpha == 0.7 + assert config.num_results == 10 + assert config.dtype == "float32" + assert config.dialect == 2 + + def test_to_query_kwargs(self): + """Test conversion to query kwargs.""" + config = RedisAggregatedHybridQueryConfig( + text_field_name="content", + alpha=0.8, + ) + kwargs = config.to_query_kwargs( + text="test query", + vector=[0.1] * 384, + return_fields=["title"], ) - assert tool.name == "custom_hybrid" - assert tool.description == "Custom hybrid search" + assert kwargs["text"] == "test query" + assert kwargs["text_field_name"] == "content" + assert kwargs["alpha"] == 0.8 + assert kwargs["return_fields"] == ["title"] + assert "linear_alpha" not in kwargs + assert "combination_method" not in kwargs -def _hybrid_query_available(): - """Check if HybridQuery dependencies are available.""" - try: - from redis.commands.search.hybrid_query import CombineResultsMethod - from 
redis.commands.search.hybrid_query import HybridPostProcessingConfig +class TestRedisHybridSearchToolInit: + """Tests for RedisHybridSearchTool initialization.""" - return True - except (ImportError, ModuleNotFoundError): - return False + def test_auto_detect_config_native(self, mock_index, mock_vectorizer): + """Test auto-detection uses native config when supported.""" + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid", + return_value=True, + ): + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + ) + assert isinstance(tool._config, RedisHybridQueryConfig) + assert tool._use_native is True + + def test_auto_detect_config_aggregate(self, mock_index, mock_vectorizer): + """Test auto-detection uses aggregate config when native not supported.""" + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid", + return_value=False, + ): + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + ) + assert isinstance(tool._config, RedisAggregatedHybridQueryConfig) + assert tool._use_native is False + + def test_native_config_on_old_version_raises( + self, mock_index, mock_vectorizer + ): + """Test that using native config on old version raises ValueError.""" + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid", + return_value=False, + ): + with pytest.raises(ValueError, match="RedisHybridQueryConfig requires"): + RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=RedisHybridQueryConfig(), + ) + + def test_aggregate_config_on_new_version_warns( + self, mock_index, mock_vectorizer + ): + """Test that using aggregate config on new version emits deprecation warning.""" + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid", + return_value=True, + ): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + tool 
= RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=RedisAggregatedHybridQueryConfig(), + ) + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "deprecated" in str(w[0].message).lower() + assert tool._use_native is False + + def test_custom_name_and_description(self, mock_index, mock_vectorizer): + """Test custom tool name and description.""" + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid", + return_value=True, + ): + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + name="custom_hybrid", + description="Custom hybrid search", + ) + assert tool.name == "custom_hybrid" + assert tool.description == "Custom hybrid search" class TestRedisHybridSearchToolBuildQuery: """Tests for _build_query method.""" @pytest.mark.skipif( - not _hybrid_query_available(), + not _native_hybrid_available(), reason="HybridQuery requires redis-py>=7.1.0 and Redis>=8.4.0", ) - def test_build_query_basic(self, hybrid_search_tool): - """Test basic query building.""" - embedding = [0.1] * 384 - query = hybrid_search_tool._build_query("test query", embedding) + def test_build_query_native(self, mock_index, mock_vectorizer): + """Test query building with native HybridQuery.""" + from redisvl.query import HybridQuery - assert isinstance(query, HybridQuery) + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid", + return_value=True, + ): + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=RedisHybridQueryConfig(num_results=5), + ) + embedding = [0.1] * 384 + query = tool._build_query("test query", embedding) + assert isinstance(query, HybridQuery) - @pytest.mark.skipif( - not _hybrid_query_available(), - reason="HybridQuery requires redis-py>=7.1.0 and Redis>=8.4.0", - ) - def test_build_query_with_num_results_override(self, hybrid_search_tool): - """Test query building with 
num_results override.""" - embedding = [0.1] * 384 - query = hybrid_search_tool._build_query( - "test query", embedding, num_results=15 - ) + def test_build_query_aggregate(self, mock_index, mock_vectorizer): + """Test query building with AggregateHybridQuery.""" + from redisvl.query import AggregateHybridQuery - assert query._num_results == 15 + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid", + return_value=False, + ): + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=RedisAggregatedHybridQueryConfig(num_results=5), + ) + embedding = [0.1] * 384 + query = tool._build_query("test query", embedding) + assert isinstance(query, AggregateHybridQuery) class TestRedisHybridSearchToolDeclaration: """Tests for _get_declaration method.""" - def test_get_declaration(self, hybrid_search_tool): + def test_get_declaration(self, mock_index, mock_vectorizer): """Test function declaration generation.""" - declaration = hybrid_search_tool._get_declaration() + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._supports_native_hybrid", + return_value=True, + ): + tool = RedisHybridSearchTool( + index=mock_index, + vectorizer=mock_vectorizer, + config=RedisHybridQueryConfig(num_results=5), + ) + declaration = tool._get_declaration() - assert declaration.name == "redis_hybrid_search" - assert "query" in declaration.parameters.properties - assert "num_results" in declaration.parameters.properties - assert "query" in declaration.parameters.required + assert declaration.name == "redis_hybrid_search" + assert "query" in declaration.parameters.properties + assert "num_results" in declaration.parameters.properties + assert "query" in declaration.parameters.required From 25eec898345ed65f54ed90826e5da2254db4a10c Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Mon, 12 Jan 2026 14:14:06 +0200 Subject: [PATCH 14/16] chore(deps): remove upper bound from redis dependency --- pyproject.toml | 
2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6172336..d4f70ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK "google-adk", # Google ADK "httpx>=0.27.0, <1.0.0", # For OpenMemory service - "redis>=5.0.0, <6.0.0", # Redis for session storage + "redis>=5.0.0", # Redis for session storage # go/keep-sorted end "orjson>=3.11.3", ] From 17f2f0f3b3786076a99a2892016ca29155b3ecd8 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Mon, 12 Jan 2026 14:24:38 +0200 Subject: [PATCH 15/16] style: run autoformat (isort + pyink) --- src/google/adk_community/tools/redis/config.py | 5 +++-- .../tools/redis/hybrid_search_tool.py | 16 ++++++++-------- .../tools/redis/test_hybrid_search_tool.py | 8 +++----- .../tools/redis/test_vector_search_tool.py | 4 +++- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/google/adk_community/tools/redis/config.py b/src/google/adk_community/tools/redis/config.py index ddade46..465b9da 100644 --- a/src/google/adk_community/tools/redis/config.py +++ b/src/google/adk_community/tools/redis/config.py @@ -34,7 +34,9 @@ from pydantic import Field # Type alias for sort specification -SortSpec = Optional[Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]]] +SortSpec = Optional[ + Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]] +] class RedisVectorQueryConfig(BaseModel): @@ -466,4 +468,3 @@ def to_query_kwargs( "text_weights": self.text_weights, "filter_expression": filter_expression, } - diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py index 5b8826d..fba2039 100644 --- a/src/google/adk_community/tools/redis/hybrid_search_tool.py +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -16,11 +16,11 @@ from __future__ import annotations -import warnings from typing import Any from 
typing import List from typing import Optional from typing import Union +import warnings from google.genai import types from packaging.version import parse @@ -168,17 +168,17 @@ def __init__( if self._use_native and not self._supports_native: raise ValueError( - "RedisHybridQueryConfig requires RedisVL >= 0.13.0 and Redis >= 8.4.0. " - f"Installed RedisVL version: {_get_redisvl_version()}. " - "Use RedisAggregatedHybridQueryConfig for older versions." + "RedisHybridQueryConfig requires RedisVL >= 0.13.0 and Redis >=" + f" 8.4.0. Installed RedisVL version: {_get_redisvl_version()}. Use" + " RedisAggregatedHybridQueryConfig for older versions." ) if not self._use_native and self._supports_native: warnings.warn( - "RedisAggregatedHybridQueryConfig is deprecated for RedisVL >= 0.13.0. " - "Consider using RedisHybridQueryConfig for native FT.HYBRID support " - "with better performance. RedisAggregatedHybridQueryConfig will " - "continue to work but uses client-side score combination.", + "RedisAggregatedHybridQueryConfig is deprecated for RedisVL >=" + " 0.13.0. Consider using RedisHybridQueryConfig for native FT.HYBRID" + " support with better performance. 
RedisAggregatedHybridQueryConfig" + " will continue to work but uses client-side score combination.", DeprecationWarning, stacklevel=2, ) diff --git a/tests/unittests/tools/redis/test_hybrid_search_tool.py b/tests/unittests/tools/redis/test_hybrid_search_tool.py index 6d1c6b9..03ab9e3 100644 --- a/tests/unittests/tools/redis/test_hybrid_search_tool.py +++ b/tests/unittests/tools/redis/test_hybrid_search_tool.py @@ -14,10 +14,10 @@ """Tests for RedisHybridSearchTool.""" -import warnings from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch +import warnings import pytest @@ -30,10 +30,8 @@ from google.adk_community.tools.redis import RedisAggregatedHybridQueryConfig from google.adk_community.tools.redis import RedisHybridQueryConfig from google.adk_community.tools.redis import RedisHybridSearchTool -from google.adk_community.tools.redis.hybrid_search_tool import ( - _get_redisvl_version, - _supports_native_hybrid, -) +from google.adk_community.tools.redis.hybrid_search_tool import _get_redisvl_version +from google.adk_community.tools.redis.hybrid_search_tool import _supports_native_hybrid @pytest.fixture diff --git a/tests/unittests/tools/redis/test_vector_search_tool.py b/tests/unittests/tools/redis/test_vector_search_tool.py index 05900d9..a71e9fd 100644 --- a/tests/unittests/tools/redis/test_vector_search_tool.py +++ b/tests/unittests/tools/redis/test_vector_search_tool.py @@ -204,7 +204,9 @@ def test_build_query_passes_core_parameters( assert query._dialect == 3 assert query._in_order is True - def test_stores_optional_parameters_in_config(self, mock_index, mock_vectorizer): + def test_stores_optional_parameters_in_config( + self, mock_index, mock_vectorizer + ): """Test that optional parameters are stored correctly in config.""" config = RedisVectorQueryConfig( hybrid_policy="ADHOC_BF", From fa3b9fd6314a34649c663224fec2c3e17f205e4d Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Tue, 20 Jan 2026 11:03:02 
-0500 Subject: [PATCH 16/16] feat(redis): check both redisvl and Redis server versions for native hybrid support - Add _MIN_REDIS_SERVER_VERSION constant (8.4.0) for FT.HYBRID command - Add _get_redis_server_version() to retrieve Redis server version from index client - Update _supports_native_hybrid() to check both redisvl >= 0.13.0 AND Redis >= 8.4.0 - Add proper logging: debug for version mismatches, warning for parsing errors - Handle edge cases: None client, connection errors, version parsing failures - Update error messages to include both version numbers - Add comprehensive tests for Redis server version detection --- .../tools/redis/hybrid_search_tool.py | 116 ++++++++++++++++-- .../tools/redis/test_hybrid_search_tool.py | 84 ++++++++++++- 2 files changed, 184 insertions(+), 16 deletions(-) diff --git a/src/google/adk_community/tools/redis/hybrid_search_tool.py b/src/google/adk_community/tools/redis/hybrid_search_tool.py index fba2039..28ebbf5 100644 --- a/src/google/adk_community/tools/redis/hybrid_search_tool.py +++ b/src/google/adk_community/tools/redis/hybrid_search_tool.py @@ -16,6 +16,7 @@ from __future__ import annotations +import logging from typing import Any from typing import List from typing import Optional @@ -34,6 +35,10 @@ # Minimum RedisVL version required for native FT.HYBRID support _MIN_NATIVE_HYBRID_VERSION = "0.13.0" +# Minimum Redis server version required for native FT.HYBRID support +_MIN_REDIS_SERVER_VERSION = "8.4.0" + +logger = logging.getLogger(__name__) def _get_redisvl_version() -> str: @@ -45,22 +50,112 @@ def _get_redisvl_version() -> str: try: import redisvl - return getattr(redisvl, "__version__", "0.0.0") + version = getattr(redisvl, "__version__", None) + if version is None: + logger.debug("redisvl.__version__ not found, assuming 0.0.0") + return "0.0.0" + return version except ImportError: + # This shouldn't normally happen due to module-level guard in __init__.py + logger.debug("redisvl not importable, assuming version 
0.0.0") + return "0.0.0" + + +def _get_redis_server_version( + index: Union[SearchIndex, AsyncSearchIndex], +) -> str: + """Get the Redis server version from the index's client. + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex. + + Returns: + Version string (e.g., "8.4.0") or "0.0.0" if not available. + """ + try: + # For sync index, use _redis_client to trigger lazy connection if needed + if isinstance(index, SearchIndex): + client = index._redis_client + else: + # For async index, we can only use the public property + # which may be None if not yet connected + client = index.client + + if client is None: + logger.warning( + "Redis client not available. For AsyncSearchIndex, ensure the " + "index has been used (e.g., await index.create()) before " + "creating RedisHybridSearchTool, or native hybrid support " + "detection will be skipped." + ) + return "0.0.0" + + info = client.info("server") + return info.get("redis_version", "0.0.0") + except Exception as e: + logger.warning( + "Could not determine Redis server version: %s. " + "Native hybrid search may not be available.", + e, + ) return "0.0.0" -def _supports_native_hybrid() -> bool: - """Check if the installed RedisVL version supports native HybridQuery. +def _supports_native_hybrid( + index: Union[SearchIndex, AsyncSearchIndex], +) -> bool: + """Check if native HybridQuery is supported. + + Native hybrid search requires both: + - RedisVL >= 0.13.0 + - Redis server >= 8.4.0 + + Args: + index: The RedisVL SearchIndex or AsyncSearchIndex to check. Returns: - True if RedisVL >= 0.13.0, False otherwise. + True if both version requirements are met, False otherwise. 
""" + # Check redisvl version + redisvl_version = _get_redisvl_version() try: - return parse(_get_redisvl_version()) >= parse(_MIN_NATIVE_HYBRID_VERSION) - except Exception: + if parse(redisvl_version) < parse(_MIN_NATIVE_HYBRID_VERSION): + logger.debug( + "Native hybrid not supported: RedisVL %s < %s", + redisvl_version, + _MIN_NATIVE_HYBRID_VERSION, + ) + return False + except Exception as e: + logger.warning( + "Could not parse redisvl version '%s': %s. " + "Native hybrid search will be disabled.", + redisvl_version, + e, + ) return False + # Check Redis server version + redis_version = _get_redis_server_version(index) + try: + if parse(redis_version) < parse(_MIN_REDIS_SERVER_VERSION): + logger.debug( + "Native hybrid not supported: Redis server %s < %s", + redis_version, + _MIN_REDIS_SERVER_VERSION, + ) + return False + except Exception as e: + logger.warning( + "Could not parse Redis server version '%s': %s. " + "Native hybrid search will be disabled.", + redis_version, + e, + ) + return False + + return True + class RedisHybridSearchTool(VectorizedSearchTool): """Hybrid search tool combining vector similarity and BM25 text search. @@ -154,7 +249,7 @@ def __init__( return_fields=return_fields, ) - self._supports_native = _supports_native_hybrid() + self._supports_native = _supports_native_hybrid(index) # Auto-detect config if not provided if config is None: @@ -168,9 +263,10 @@ def __init__( if self._use_native and not self._supports_native: raise ValueError( - "RedisHybridQueryConfig requires RedisVL >= 0.13.0 and Redis >=" - f" 8.4.0. Installed RedisVL version: {_get_redisvl_version()}. Use" - " RedisAggregatedHybridQueryConfig for older versions." + "RedisHybridQueryConfig requires RedisVL >= 0.13.0 and Redis >= " + f"8.4.0. Installed RedisVL version: {_get_redisvl_version()}, " + f"Redis server version: {_get_redis_server_version(index)}. " + "Use RedisAggregatedHybridQueryConfig for older versions." 
) if not self._use_native and self._supports_native: diff --git a/tests/unittests/tools/redis/test_hybrid_search_tool.py b/tests/unittests/tools/redis/test_hybrid_search_tool.py index 03ab9e3..606b162 100644 --- a/tests/unittests/tools/redis/test_hybrid_search_tool.py +++ b/tests/unittests/tools/redis/test_hybrid_search_tool.py @@ -30,6 +30,7 @@ from google.adk_community.tools.redis import RedisAggregatedHybridQueryConfig from google.adk_community.tools.redis import RedisHybridQueryConfig from google.adk_community.tools.redis import RedisHybridSearchTool +from google.adk_community.tools.redis.hybrid_search_tool import _get_redis_server_version from google.adk_community.tools.redis.hybrid_search_tool import _get_redisvl_version from google.adk_community.tools.redis.hybrid_search_tool import _supports_native_hybrid @@ -44,14 +45,24 @@ def mock_vectorizer(): @pytest.fixture -def mock_index(): - """Mock RedisVL SearchIndex.""" +def mock_redis_client(): + """Mock Redis client with info method.""" + client = MagicMock() + client.info = MagicMock(return_value={"redis_version": "8.4.0"}) + return client + + +@pytest.fixture +def mock_index(mock_redis_client): + """Mock RedisVL SearchIndex with Redis client.""" index = MagicMock(spec=SearchIndex) index.query = MagicMock( return_value=[ {"title": "Test Doc", "content": "Test content", "score": 0.9} ] ) + # Mock _redis_client property for sync index + type(index)._redis_client = property(lambda self: mock_redis_client) return index @@ -76,10 +87,71 @@ def test_get_redisvl_version(self): # Should be a valid version string like "0.13.0" or "0.0.0" assert len(version.split(".")) >= 2 - def test_supports_native_hybrid(self): - """Test native hybrid support detection.""" - result = _supports_native_hybrid() - assert isinstance(result, bool) + def test_get_redis_server_version(self, mock_index): + """Test Redis server version retrieval.""" + version = _get_redis_server_version(mock_index) + assert version == "8.4.0" + + def 
test_get_redis_server_version_client_none(self): + """Test Redis server version when client is None.""" + index = MagicMock(spec=SearchIndex) + type(index)._redis_client = property(lambda self: None) + version = _get_redis_server_version(index) + assert version == "0.0.0" + + def test_get_redis_server_version_exception(self, mock_index): + """Test Redis server version when info() raises exception.""" + mock_index._redis_client.info.side_effect = Exception("Connection error") + version = _get_redis_server_version(mock_index) + assert version == "0.0.0" + + def test_supports_native_hybrid_both_versions_ok(self, mock_index): + """Test native hybrid support when both versions meet requirements.""" + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._get_redisvl_version", + return_value="0.13.0", + ): + result = _supports_native_hybrid(mock_index) + assert result is True + + def test_supports_native_hybrid_redisvl_too_old(self, mock_index): + """Test native hybrid not supported when redisvl version is too old.""" + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._get_redisvl_version", + return_value="0.12.0", + ): + result = _supports_native_hybrid(mock_index) + assert result is False + + def test_supports_native_hybrid_redis_server_too_old(self): + """Test native hybrid not supported when Redis server version is too old.""" + # Create index with old Redis server version + mock_client = MagicMock() + mock_client.info = MagicMock(return_value={"redis_version": "7.2.0"}) + index = MagicMock(spec=SearchIndex) + type(index)._redis_client = property(lambda self: mock_client) + + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._get_redisvl_version", + return_value="0.13.0", + ): + result = _supports_native_hybrid(index) + assert result is False + + def test_supports_native_hybrid_both_versions_too_old(self): + """Test native hybrid not supported when both versions are too old.""" + # Create index with old Redis server 
version + mock_client = MagicMock() + mock_client.info = MagicMock(return_value={"redis_version": "7.2.0"}) + index = MagicMock(spec=SearchIndex) + type(index)._redis_client = property(lambda self: mock_client) + + with patch( + "google.adk_community.tools.redis.hybrid_search_tool._get_redisvl_version", + return_value="0.12.0", + ): + result = _supports_native_hybrid(index) + assert result is False class TestRedisHybridQueryConfig: