diff --git a/contributing/samples/milvus_rag_agent/README.md b/contributing/samples/milvus_rag_agent/README.md new file mode 100644 index 0000000000..a6437c6c73 --- /dev/null +++ b/contributing/samples/milvus_rag_agent/README.md @@ -0,0 +1,86 @@ +# Milvus RAG Agent Sample + +This sample demonstrates how to build a knowledge base agent using +[Milvus](https://milvus.io/) as the vector database for +retrieval-augmented generation (RAG) with Google ADK. + +## Prerequisites + +1. A running Milvus instance, or use Milvus Lite (local file path). +2. A Google GenAI API key for embedding generation. +3. Install dependencies: + +```bash +pip install "google-adk[milvus]" +``` + +## Supported Milvus Backends + +The integration works with all three Milvus deployment modes by changing +the `uri` (and optionally `token`) setting: + +| Backend | `uri` | `token` | +|---------|-------|---------| +| **Milvus Lite** (local, no server needed) | `./milvus.db` | — | +| **Milvus Server** (self-hosted) | `http://localhost:19530` | — | +| **Zilliz Cloud** (fully managed) | `https://in01-xxx.serverless.gcp-us-west1.cloud.zilliz.com` | `your-api-key` | + +> For Milvus Lite, install the extra package: `pip install "pymilvus[milvus_lite]"`. + +## Setup + +### Environment Variables + +```bash +export GOOGLE_API_KEY="your-google-api-key" + +# Pick one of the following: +export MILVUS_URI="./milvus.db" # Milvus Lite +export MILVUS_URI="http://localhost:19530" # Milvus Server +export MILVUS_URI="https://in01-xxx.cloud.zilliz.com" # Zilliz Cloud + +export MILVUS_COLLECTION="knowledge_base" +# Required for Zilliz Cloud only: +# export MILVUS_TOKEN="your-api-key" +``` + +### Data Ingestion + +Before running the agent, you need to populate the Milvus collection +with your knowledge base data: + +```python +from google.adk.tools.milvus.milvus_vector_store import MilvusVectorStore +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +from google.genai import Client + +# Define your embedding function (example using Google GenAI). +genai_client = Client() + +def embedding_fn(texts): + resp = genai_client.models.embed_content( + model="text-embedding-004", contents=texts) + return [list(e.values) for e in resp.embeddings] + +settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + uri="http://localhost:19530", + collection_name="knowledge_base", + dimension=768, + ), +) + +store = MilvusVectorStore(settings=settings, embedding_fn=embedding_fn) +store.setup() +store.add_contents([ + "Your document text here...", + "Another document...", +]) +``` + +## Run + +```bash +adk run contributing/samples/milvus_rag_agent +``` diff --git a/contributing/samples/milvus_rag_agent/__init__.py b/contributing/samples/milvus_rag_agent/__init__.py new file mode 100644 index 0000000000..4015e47d6e --- /dev/null +++ b/contributing/samples/milvus_rag_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/milvus_rag_agent/agent.py b/contributing/samples/milvus_rag_agent/agent.py new file mode 100644 index 0000000000..9c236f287c --- /dev/null +++ b/contributing/samples/milvus_rag_agent/agent.py @@ -0,0 +1,85 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample Milvus RAG Agent. + +This sample demonstrates how to build a knowledge base agent using Milvus +as the vector database for retrieval-augmented generation (RAG). + +Prerequisites: + 1. A running Milvus instance (or use Milvus Lite with a local file path). + 2. A Google GenAI API key set in the GOOGLE_API_KEY environment variable. + 3. Install dependencies: pip install "google-adk[milvus]" + +Usage: + adk run contributing/samples/milvus_rag_agent +""" + +import os + +from dotenv import load_dotenv +from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.milvus.milvus_toolset import MilvusToolset +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +from google.genai import Client + +load_dotenv() + +# --- Embedding function using Google GenAI --- +genai_client = Client() + + +def embedding_fn(texts: list[str]) -> list[list[float]]: + response = genai_client.models.embed_content( + model="text-embedding-004", + contents=texts, + ) + return [list(e.values) for e in response.embeddings] + + +# --- Milvus vector store settings --- +# Replace these with your own Milvus connection and collection settings. +vector_store_settings = MilvusVectorStoreSettings( + # Use a remote Milvus instance or a local Milvus Lite file path. + uri=os.environ.get("MILVUS_URI", "http://localhost:19530"), + token=os.environ.get("MILVUS_TOKEN", None), + collection_name=os.environ.get("MILVUS_COLLECTION", "knowledge_base"), + dimension=768, + metric_type="COSINE", + top_k=5, +) + +# --- Milvus toolset --- +milvus_toolset = MilvusToolset( + milvus_tool_settings=MilvusToolSettings( + vector_store_settings=vector_store_settings, + ), + embedding_fn=embedding_fn, + tool_filter=["similarity_search"], +) + +# --- Agent definition --- +root_agent = LlmAgent( + model="gemini-2.5-flash", + name="milvus_knowledge_agent", + description="Agent that answers questions using a Milvus knowledge base.", + instruction=""" + You are a helpful assistant with access to a knowledge base. + 1. Always use the `similarity_search` tool to find relevant information. + 2. Present the search results naturally in your response. + 3. If no results are found, say you don't know. + """, + tools=[milvus_toolset], +) diff --git a/pyproject.toml b/pyproject.toml index 2564e8baf1..05ee68b01b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,6 +127,7 @@ test = [ "litellm>=1.75.5, <1.80.17", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "openai>=1.100.2", # For LiteLLM + "pymilvus>=2.5.0", # For Milvus vector store tests "opentelemetry-instrumentation-google-genai>=0.3b0, <1.0.0", "pypika>=0.50.0", # For crewai->chromadb dependency "pytest-asyncio>=0.25.0", @@ -159,11 +160,16 @@ extensions = [ "litellm>=1.75.5, <1.80.17", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. "llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex. + "pymilvus>=2.5.0", # For Milvus vector store integration. "lxml>=5.3.0", # For load_web_page tool. "pypika>=0.50.0", # For crewai->chromadb dependency "toolbox-adk>=0.5.7, <0.6.0", # For tools.toolbox_toolset.ToolboxToolset ] +milvus = [ + "pymilvus>=2.5.0", +] + otel-gcp = ["opentelemetry-instrumentation-google-genai>=0.6b0, <1.0.0"] toolbox = ["toolbox-adk>=0.5.7, <0.6.0"] diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 154c77caf6..1341355501 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -38,6 +38,9 @@ class FeatureName(str, Enum): GOOGLE_CREDENTIALS_CONFIG = "GOOGLE_CREDENTIALS_CONFIG" GOOGLE_TOOL = "GOOGLE_TOOL" JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL" + MILVUS_MEMORY_SERVICE = "MILVUS_MEMORY_SERVICE" + MILVUS_TOOLSET = "MILVUS_TOOLSET" + MILVUS_VECTOR_STORE = "MILVUS_VECTOR_STORE" PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING" PUBSUB_TOOL_CONFIG = "PUBSUB_TOOL_CONFIG" PUBSUB_TOOLSET = "PUBSUB_TOOLSET" @@ -114,6 +117,15 @@ class FeatureConfig: FeatureName.JSON_SCHEMA_FOR_FUNC_DECL: FeatureConfig( FeatureStage.WIP, default_on=False ), + FeatureName.MILVUS_MEMORY_SERVICE: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), + FeatureName.MILVUS_TOOLSET: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), + FeatureName.MILVUS_VECTOR_STORE: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index c47fb8ec40..c78a3ced5c 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -35,3 +35,14 @@ ' VertexAiRagMemoryService please install it. If not, you can ignore this' ' warning.' ) + +try: + from .milvus_memory_service import MilvusMemoryService + + __all__.append('MilvusMemoryService') +except ImportError: + logger.debug( + 'pymilvus is not installed. If you want to use the' + ' MilvusMemoryService please install it with:' + ' pip install "google-adk[milvus]". If not, you can ignore this warning.' + ) diff --git a/src/google/adk/memory/milvus_memory_service.py b/src/google/adk/memory/milvus_memory_service.py new file mode 100644 index 0000000000..39cbbd8db3 --- /dev/null +++ b/src/google/adk/memory/milvus_memory_service.py @@ -0,0 +1,349 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus-backed memory service for cross-session conversation memory.""" + +from __future__ import annotations + +import logging +from typing import Callable +from typing import Optional +from typing import TYPE_CHECKING + +from google.genai import types +from typing_extensions import override + +from ..features import experimental +from ..features import FeatureName +from . import _utils +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + +try: + from pymilvus import DataType + from pymilvus import MilvusClient +except ImportError: + MilvusClient = None + DataType = None + + +def _should_skip_event(content: types.Content) -> bool: + """Returns True if the event has no user-readable content.""" + if not content or not content.parts: + return True + for part in content.parts: + if part.text or part.inline_data or part.file_data: + return False + return True + + +def _extract_event_text(content: types.Content) -> str: + """Extracts and joins all text parts from event content.""" + if not content or not content.parts: + return "" + return " ".join(part.text for part in content.parts if part.text) + + +@experimental(FeatureName.MILVUS_MEMORY_SERVICE) +class MilvusMemoryService(BaseMemoryService): + """Memory service backed by Milvus vector database. + + Stores session events as vector-embedded text in a Milvus collection, + enabling semantic search across past conversations. + + Supports all Milvus deployment modes: + - Milvus Lite (local file path, e.g., ``"./memory.db"``) + - Milvus Server (e.g., ``"http://localhost:19530"``) + - Zilliz Cloud (e.g., ``"https://in01-xxx.cloud.zilliz.com"`` + token) + """ + + def __init__( + self, + *, + embedding_fn: Callable[[list[str]], list[list[float]]], + collection_name: str = "adk_memory", + uri: str = "http://localhost:19530", + token: Optional[str] = None, + db_name: str = "default", + dimension: int = 768, + metric_type: str = "COSINE", + top_k: int = 10, + ): + """Initializes the MilvusMemoryService. + + Args: + embedding_fn: A function that takes a list of texts and returns a + list of embedding vectors. + collection_name: The Milvus collection name for storing memories. + uri: The Milvus server URI or local file path. + token: Optional authentication token (e.g., for Zilliz Cloud). + db_name: The Milvus database name. + dimension: The dimension of the embedding vectors. + metric_type: The distance metric (COSINE, L2, or IP). + top_k: Default number of results for memory search. + + Raises: + ImportError: If pymilvus is not installed. + """ + if MilvusClient is None: + raise ImportError( + "pymilvus package not found. " + 'Please install with: pip install "google-adk[milvus]"' + ) + + self._embedding_fn = embedding_fn + self._collection_name = collection_name + self._dimension = dimension + self._metric_type = metric_type + self._top_k = top_k + self._collection_ready = False + + self._client = MilvusClient( + uri=uri, + token=token, + db_name=db_name, + ) + + def _ensure_collection(self) -> None: + """Creates the collection and index if they do not exist (idempotent).""" + if self._collection_ready: + return + + if self._client.has_collection(self._collection_name): + logger.info( + "Memory collection '%s' already exists.", + self._collection_name, + ) + self._collection_ready = True + return + + schema = self._client.create_schema( + auto_id=True, enable_dynamic_field=True + ) + schema.add_field( + field_name="id", + datatype=DataType.INT64, + is_primary=True, + ) + schema.add_field( + field_name="app_name", + datatype=DataType.VARCHAR, + max_length=256, + ) + schema.add_field( + field_name="user_id", + datatype=DataType.VARCHAR, + max_length=256, + ) + schema.add_field( + field_name="session_id", + datatype=DataType.VARCHAR, + max_length=256, + ) + schema.add_field( + field_name="author", + datatype=DataType.VARCHAR, + max_length=256, + ) + schema.add_field( + field_name="content", + datatype=DataType.VARCHAR, + max_length=65535, + ) + schema.add_field( + field_name="timestamp", + datatype=DataType.DOUBLE, + ) + schema.add_field( + field_name="embedding", + datatype=DataType.FLOAT_VECTOR, + dim=self._dimension, + ) + + index_params = self._client.prepare_index_params() + index_params.add_index( + field_name="embedding", + index_type="AUTOINDEX", + metric_type=self._metric_type, + ) + + self._client.create_collection( + collection_name=self._collection_name, + schema=schema, + index_params=index_params, + ) + + logger.info( + "Created memory collection '%s' (dim=%d, metric=%s).", + self._collection_name, + self._dimension, + self._metric_type, + ) + self._collection_ready = True + + @override + async def add_session_to_memory(self, session: Session) -> None: + """Adds session events to Milvus as vector-embedded memories. + + Filters out events with no user-readable content (e.g., function + calls). Deduplicates by checking existing events for the same + session before inserting. + + Args: + session: The session whose events will be stored. + """ + self._ensure_collection() + + # Collect events with meaningful text content. + events_to_store = [] + for event in session.events: + if _should_skip_event(event.content): + continue + text = _extract_event_text(event.content) + if not text.strip(): + continue + events_to_store.append((event, text)) + + if not events_to_store: + logger.info("No events to add to memory for session %s.", session.id) + return + + # Deduplicate: find which timestamps already exist for this session. + existing_timestamps = set() + try: + existing = self._client.query( + collection_name=self._collection_name, + filter=( + f'app_name == "{session.app_name}"' + f' and user_id == "{session.user_id}"' + f' and session_id == "{session.id}"' + ), + output_fields=["timestamp"], + ) + existing_timestamps = {row["timestamp"] for row in existing} + except Exception: + logger.debug( + "Could not query existing events for dedup, inserting all.", + exc_info=True, + ) + + new_events = [ + (event, text) + for event, text in events_to_store + if event.timestamp not in existing_timestamps + ] + + if not new_events: + logger.info( + "All events for session %s already in memory.", session.id + ) + return + + # Embed and insert. + texts = [text for _, text in new_events] + embeddings = self._embedding_fn(texts) + + data = [] + for (event, text), embedding in zip(new_events, embeddings): + data.append({ + "app_name": session.app_name, + "user_id": session.user_id, + "session_id": session.id, + "author": event.author or "", + "content": text, + "timestamp": event.timestamp, + "embedding": embedding, + }) + + self._client.insert( + collection_name=self._collection_name, + data=data, + ) + + logger.info( + "Added %d events from session %s to memory collection '%s'.", + len(data), + session.id, + self._collection_name, + ) + + @override + async def search_memory( + self, + *, + app_name: str, + user_id: str, + query: str, + ) -> SearchMemoryResponse: + """Searches memories by semantic similarity, scoped to the user. + + Args: + app_name: The application name. + user_id: The user ID. + query: The natural-language search query. + + Returns: + A SearchMemoryResponse with matching MemoryEntry objects. + """ + self._ensure_collection() + + query_embedding = self._embedding_fn([query])[0] + + filter_expr = ( + f'app_name == "{app_name}" and user_id == "{user_id}"' + ) + + results = self._client.search( + collection_name=self._collection_name, + data=[query_embedding], + limit=self._top_k, + output_fields=["content", "author", "timestamp"], + search_params={"metric_type": self._metric_type}, + filter=filter_expr, + ) + + if not results or not results[0]: + return SearchMemoryResponse() + + memories = [] + for hit in results[0]: + entity = hit["entity"] + author = entity.get("author", "") + content_text = entity.get("content", "") + timestamp = entity.get("timestamp") + + role = "user" if author == "user" else "model" + memory_entry = MemoryEntry( + content=types.Content( + parts=[types.Part(text=content_text)], + role=role, + ), + author=author, + timestamp=( + _utils.format_timestamp(timestamp) if timestamp else None + ), + ) + memories.append(memory_entry) + + return SearchMemoryResponse(memories=memories) + + def close(self) -> None: + """Closes the Milvus client connection.""" + self._client.close() diff --git a/src/google/adk/tools/milvus/__init__.py b/src/google/adk/tools/milvus/__init__.py new file mode 100644 index 0000000000..9c868b2006 --- /dev/null +++ b/src/google/adk/tools/milvus/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus Tools (Experimental). + +Milvus Tools provide vector similarity search capabilities using Milvus +as the vector database backend. This module offers: + +1. A MilvusToolset for easy integration with ADK agents. +2. A MilvusVectorStore utility for collection management and data ingestion. +3. Support for any third-party embedding function (Google GenAI, OpenAI, etc.). +""" + +from .milvus_toolset import MilvusToolset + +__all__ = [ + "MilvusToolset", +] diff --git a/src/google/adk/tools/milvus/milvus_tool.py b/src/google/adk/tools/milvus/milvus_tool.py new file mode 100644 index 0000000000..9558a382b8 --- /dev/null +++ b/src/google/adk/tools/milvus/milvus_tool.py @@ -0,0 +1,55 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MilvusTool wraps a function and injects the MilvusVectorStore instance.""" + +from __future__ import annotations + +import inspect +from typing import Any +from typing import Callable + +from typing_extensions import override + +from ..function_tool import FunctionTool +from ..tool_context import ToolContext +from .milvus_vector_store import MilvusVectorStore + + +class MilvusTool(FunctionTool): + """A FunctionTool that injects MilvusVectorStore into the wrapped function. + + The ``vector_store`` parameter is hidden from the LLM function + declaration and automatically injected at runtime. + """ + + def __init__( + self, + func: Callable[..., Any], + *, + vector_store: MilvusVectorStore, + ): + super().__init__(func=func) + self._ignore_params.append("vector_store") + self._vector_store = vector_store + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + args_to_call = args.copy() + signature = inspect.signature(self.func) + if "vector_store" in signature.parameters: + args_to_call["vector_store"] = self._vector_store + return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/milvus/milvus_toolset.py b/src/google/adk/tools/milvus/milvus_toolset.py new file mode 100644 index 0000000000..07cca57dd1 --- /dev/null +++ b/src/google/adk/tools/milvus/milvus_toolset.py @@ -0,0 +1,108 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus Toolset for exposing Milvus vector search to agents.""" + +from __future__ import annotations + +from typing import Callable +from typing import List +from typing import Optional +from typing import Union + +from typing_extensions import override + +from . import search_tool +from ...agents.readonly_context import ReadonlyContext +from ...features import experimental +from ...features import FeatureName +from ...tools.base_tool import BaseTool +from ...tools.base_toolset import BaseToolset +from ...tools.base_toolset import ToolPredicate +from .milvus_tool import MilvusTool +from .milvus_vector_store import MilvusVectorStore +from .settings import MilvusToolSettings + +DEFAULT_MILVUS_TOOL_NAME_PREFIX = "milvus" + + +@experimental(FeatureName.MILVUS_TOOLSET) +class MilvusToolset(BaseToolset): + """Milvus Toolset provides tools for vector similarity search in Milvus. + + The tool names are: + - milvus_similarity_search + """ + + def __init__( + self, + *, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + milvus_tool_settings: Optional[MilvusToolSettings] = None, + embedding_fn: Callable[[list[str]], list[list[float]]], + ): + """Initialize the Milvus Toolset. + + Args: + tool_filter: Optional filter to select a subset of tools. + milvus_tool_settings: Milvus tool settings containing vector store + configuration. + embedding_fn: A function that takes a list of texts and returns a + list of embedding vectors. + """ + super().__init__( + tool_filter=tool_filter, + tool_name_prefix=DEFAULT_MILVUS_TOOL_NAME_PREFIX, + ) + self._tool_settings = milvus_tool_settings or MilvusToolSettings() + self._embedding_fn = embedding_fn + self._vector_store: Optional[MilvusVectorStore] = None + + def _get_vector_store(self) -> MilvusVectorStore: + """Lazily creates the MilvusVectorStore instance.""" + if self._vector_store is None: + self._vector_store = MilvusVectorStore( + settings=self._tool_settings, + embedding_fn=self._embedding_fn, + ) + return self._vector_store + + @override + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> List[BaseTool]: + """Get tools from the toolset.""" + all_tools: list[BaseTool] = [] + + if self._tool_settings.vector_store_settings: + vector_store = self._get_vector_store() + all_tools.append( + MilvusTool( + func=search_tool.similarity_search, + vector_store=vector_store, + ) + ) + + return [ + tool + for tool in all_tools + if self._is_tool_selected(tool, readonly_context) + ] + + @override + async def close(self) -> None: + """Closes the Milvus client connection.""" + if self._vector_store is not None: + self._vector_store.close() + self._vector_store = None diff --git a/src/google/adk/tools/milvus/milvus_vector_store.py b/src/google/adk/tools/milvus/milvus_vector_store.py new file mode 100644 index 0000000000..dd305d783b --- /dev/null +++ b/src/google/adk/tools/milvus/milvus_vector_store.py @@ -0,0 +1,277 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus vector store utility class for data ingestion and retrieval.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Callable +from typing import Optional + +from ...features import experimental +from ...features import FeatureName +from .settings import MilvusToolSettings + +logger = logging.getLogger("google_adk." + __name__) + +try: + from pymilvus import DataType + from pymilvus import MilvusClient +except ImportError: + MilvusClient = None + DataType = None + + +@experimental(FeatureName.MILVUS_VECTOR_STORE) +class MilvusVectorStore: + """A utility class for managing a Milvus vector store. + + This class provides methods for setting up a Milvus collection, + adding content with embeddings, and performing similarity search. + """ + + def __init__( + self, + settings: MilvusToolSettings, + embedding_fn: Callable[[list[str]], list[list[float]]], + ): + """Initializes the MilvusVectorStore with settings and embedding function. + + Args: + settings: The Milvus tool settings containing vector store configuration. + embedding_fn: A function that takes a list of texts and returns a list + of embedding vectors. Signature: ``(list[str]) -> list[list[float]]``. + For example, using Google GenAI:: + + from google.genai import Client + client = Client() + def embedding_fn(texts): + resp = client.models.embed_content( + model="text-embedding-004", contents=texts) + return [list(e.values) for e in resp.embeddings] + + Raises: + ValueError: If vector_store_settings is not set. + ImportError: If pymilvus is not installed. + """ + if not settings.vector_store_settings: + raise ValueError("Milvus vector store settings are not set.") + + if MilvusClient is None: + raise ImportError( + "pymilvus package not found. " + 'Please install with: pip install "google-adk[milvus]"' + ) + + self._settings = settings.vector_store_settings + self._embedding_fn = embedding_fn + + self._client = MilvusClient( + uri=self._settings.uri, + token=self._settings.token, + db_name=self._settings.db_name, + ) + + def setup(self) -> None: + """Creates the Milvus collection and index if they do not exist. + + The collection schema includes: + - A primary key field (auto-generated int64 id). + - A text content field (VARCHAR). + - A vector embedding field (FLOAT_VECTOR). + + The vector index is created with the configured metric type and + index type. + """ + if self._client.has_collection(self._settings.collection_name): + logger.info( + "Collection '%s' already exists, skipping setup.", + self._settings.collection_name, + ) + return + + schema = self._client.create_schema(auto_id=True, enable_dynamic_field=True) + schema.add_field( + field_name=self._settings.primary_field, + datatype=DataType.INT64, + is_primary=True, + ) + schema.add_field( + field_name=self._settings.content_field, + datatype=DataType.VARCHAR, + max_length=65535, + ) + schema.add_field( + field_name=self._settings.embedding_field, + datatype=DataType.FLOAT_VECTOR, + dim=self._settings.dimension, + ) + + index_params = self._client.prepare_index_params() + index_params.add_index( + field_name=self._settings.embedding_field, + index_type=self._settings.index_type, + metric_type=self._settings.metric_type, + ) + + self._client.create_collection( + collection_name=self._settings.collection_name, + schema=schema, + index_params=index_params, + ) + + logger.info( + "Created collection '%s' with dimension=%d, metric=%s, index=%s.", + self._settings.collection_name, + self._settings.dimension, + self._settings.metric_type, + self._settings.index_type, + ) + + def add_contents( + self, + contents: list[str], + *, + additional_fields: Optional[list[dict]] = None, + batch_size: int = 200, + ) -> None: + """Adds text contents to the vector store. + + Performs batch embedding generation and insertion into the Milvus + collection. + + Args: + contents: An iterable of text contents to add. + additional_fields: Optional list of dicts with extra field values + for each content row. + batch_size: Maximum number of items per batch. Defaults to 200. + + Raises: + ValueError: If additional_fields length does not match contents length. + """ + if additional_fields and len(additional_fields) != len(contents): + raise ValueError( + "The number of additional_fields must match the number of contents." + ) + + total_rows = 0 + for i in range(0, len(contents), batch_size): + batch_contents = contents[i : i + batch_size] + batch_extra = ( + additional_fields[i : i + batch_size] if additional_fields else None + ) + + logger.debug( + "Embedding batch %d to %d (size: %d)...", + i, + i + len(batch_contents), + len(batch_contents), + ) + embeddings = self._embedding_fn(batch_contents) + + data = [] + for j, (content, embedding) in enumerate(zip(batch_contents, embeddings)): + row = { + self._settings.content_field: content, + self._settings.embedding_field: embedding, + } + if batch_extra and j < len(batch_extra): + row.update(batch_extra[j]) + data.append(row) + + self._client.insert( + collection_name=self._settings.collection_name, + data=data, + ) + total_rows += len(data) + + logger.info( + "Added %d contents to collection '%s'.", + total_rows, + self._settings.collection_name, + ) + + async def add_contents_async( + self, + contents: list[str], + *, + additional_fields: Optional[list[dict]] = None, + batch_size: int = 200, + ) -> None: + """Asynchronously adds text contents to the vector store. + + Args: + contents: An iterable of text contents to add. + additional_fields: Optional list of dicts with extra field values + for each content row. + batch_size: Maximum number of items per batch. Defaults to 200. + """ + await asyncio.to_thread( + self.add_contents, + contents, + additional_fields=additional_fields, + batch_size=batch_size, + ) + + def search( + self, + query: str, + *, + top_k: Optional[int] = None, + filter_expr: Optional[str] = None, + ) -> list[dict]: + """Performs vector similarity search. + + Args: + query: The search query text. It will be embedded using the + configured embedding function. + top_k: Number of results to return. Overrides the default from + settings if provided. + filter_expr: Optional Milvus filter expression for pre-filtering. + + Returns: + A list of dicts, each containing the matched content and any + configured output fields, along with the distance score. + """ + top_k = top_k or self._settings.top_k + + query_embedding = self._embedding_fn([query])[0] + + output_fields = self._settings.output_fields or [ + self._settings.content_field + ] + + search_params = {"metric_type": self._settings.metric_type} + + results = self._client.search( + collection_name=self._settings.collection_name, + data=[query_embedding], + limit=top_k, + output_fields=output_fields, + search_params=search_params, + filter=filter_expr or "", + ) + + if not results or not results[0]: + return [] + + return [ + {**hit["entity"], "distance": hit["distance"]} for hit in results[0] + ] + + def close(self) -> None: + """Closes the Milvus client connection.""" + self._client.close() diff --git a/src/google/adk/tools/milvus/search_tool.py b/src/google/adk/tools/milvus/search_tool.py new file mode 100644 index 0000000000..08c1d511b7 --- /dev/null +++ b/src/google/adk/tools/milvus/search_tool.py @@ -0,0 +1,66 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus similarity search tool function for agent use.""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from .milvus_vector_store import MilvusVectorStore + +logger = logging.getLogger("google_adk." + __name__) + + +def similarity_search( + query: str, + vector_store: MilvusVectorStore, + filter_expr: Optional[str] = None, +) -> str: + # fmt: off + """Search the knowledge base for information relevant to the user's query. + + Use this tool when you need to look up facts, find relevant documents, + or answer questions that require knowledge from the vector database. + The tool performs semantic similarity search — it finds content whose + meaning is closest to the query, not just keyword matches. + + Args: + query (str): A natural-language search query describing what + information you are looking for. Be specific and descriptive + for better results. For example, use "How does Milvus handle + vector indexing?" rather than just "Milvus". + filter_expr (str): An optional filter expression to narrow down + search results before ranking by similarity. Uses Milvus + boolean expression syntax, for example: + ``category == "tech"`` or ``year > 2023``. Leave empty if no + filtering is needed. + + Returns: + str: A JSON-formatted string containing the search results. Each + result includes the matched content and a distance score + indicating similarity (lower distance means higher similarity + for L2/EUCLIDEAN, higher score means higher similarity for + COSINE/IP). Returns "No matching results found." if no + relevant content exists. + """ + # fmt: on + results = vector_store.search(query=query, filter_expr=filter_expr) + + if not results: + return "No matching results found." + + return json.dumps(results, ensure_ascii=False, default=str) diff --git a/src/google/adk/tools/milvus/settings.py b/src/google/adk/tools/milvus/settings.py new file mode 100644 index 0000000000..7691bea7e2 --- /dev/null +++ b/src/google/adk/tools/milvus/settings.py @@ -0,0 +1,98 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Settings for Milvus vector store and toolset.""" + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel +from pydantic import model_validator + + +class MilvusVectorStoreSettings(BaseModel): + """Settings for Milvus vector store. + + This is used for vector similarity search in a Milvus collection. + Provide the collection and connection settings to use with the + ``similarity_search`` tool. + """ + + uri: str = "http://localhost:19530" + """The Milvus server URI. + + Can be a remote address like ``http://localhost:19530`` or a local + file path for Milvus Lite (e.g. ``./milvus.db``). + """ + + token: Optional[str] = None + """Optional authentication token (e.g. for Zilliz Cloud).""" + + db_name: str = "default" + """The Milvus database name.""" + + collection_name: str + """Required. The name of the Milvus collection.""" + + dimension: int = 768 + """The dimension of the embedding vectors.""" + + metric_type: str = "COSINE" + """The distance metric for similarity search. + + Supported values: ``COSINE``, ``L2``, ``IP``. + """ + + index_type: str = "AUTOINDEX" + """The index type for the vector field. + + Supported values: ``AUTOINDEX``, ``IVF_FLAT``, ``HNSW``, etc. + """ + + content_field: str = "content" + """The name of the text content field in the collection.""" + + embedding_field: str = "embedding" + """The name of the vector embedding field in the collection.""" + + primary_field: str = "id" + """The name of the primary key field in the collection.""" + + top_k: int = 5 + """The default number of results to return from similarity search.""" + + output_fields: Optional[list[str]] = None + """Optional additional fields to return in search results. + + If ``None``, only the ``content_field`` is returned. + """ + + @model_validator(mode="after") + def _validate_settings(self): + """Validate the vector store settings.""" + if self.dimension <= 0: + raise ValueError( + f"Invalid dimension: {self.dimension}. Must be positive." + ) + if self.top_k <= 0: + raise ValueError(f"Invalid top_k: {self.top_k}. Must be positive.") + return self + + +class MilvusToolSettings(BaseModel): + """Settings for Milvus toolset.""" + + vector_store_settings: Optional[MilvusVectorStoreSettings] = None + """Settings for Milvus vector store and vector similarity search.""" diff --git a/tests/unittests/memory/test_milvus_memory_service.py b/tests/unittests/memory/test_milvus_memory_service.py new file mode 100644 index 0000000000..3fb84e8693 --- /dev/null +++ b/tests/unittests/memory/test_milvus_memory_service.py @@ -0,0 +1,389 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.events.event import Event +from google.adk.memory.milvus_memory_service import MilvusMemoryService +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +DIMENSION = 4 +APP_NAME = "test-app" +USER_ID = "test-user" + + +def _mock_embedding_fn(texts: list[str]) -> list[list[float]]: + """A mock embedding function returning fixed-dimension vectors.""" + return [[0.1] * DIMENSION for _ in texts] + + +def _make_session( + session_id: str = "session-1", + app_name: str = APP_NAME, + user_id: str = USER_ID, + events: list[Event] | None = None, +) -> Session: + if events is None: + events = [ + Event( + id="event-1", + invocation_id="inv-1", + author="user", + timestamp=1000.0, + content=types.Content( + parts=[types.Part(text="Hello, I like Python.")] + ), + ), + Event( + id="event-2", + invocation_id="inv-2", + author="model", + timestamp=1001.0, + content=types.Content( + parts=[types.Part(text="Python is a great language!")] + ), + ), + ] + return Session( + app_name=app_name, + user_id=user_id, + id=session_id, + last_update_time=2000, + events=events, + ) + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_init_success(mock_milvus_client_cls): + """Test successful initialization.""" + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + uri="http://localhost:19530", + collection_name="test_memory", + dimension=DIMENSION, + ) + mock_milvus_client_cls.assert_called_once_with( + uri="http://localhost:19530", + token=None, + db_name="default", + ) + assert service._embedding_fn is _mock_embedding_fn + assert service._collection_name == "test_memory" + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_ensure_collection_creates_when_missing(mock_milvus_client_cls): + """Test that _ensure_collection creates collection when it doesn't exist.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = False + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + service._ensure_collection() + + mock_client.has_collection.assert_called_once_with("adk_memory") + mock_client.create_collection.assert_called_once() + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_ensure_collection_skips_existing(mock_milvus_client_cls): + """Test that _ensure_collection skips if collection exists.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + service._ensure_collection() + + mock_client.has_collection.assert_called_once_with("adk_memory") + mock_client.create_collection.assert_not_called() + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_ensure_collection_idempotent(mock_milvus_client_cls): + """Test that _ensure_collection only checks once.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = False + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + service._ensure_collection() + service._ensure_collection() + + # has_collection should only be called once due to _collection_ready flag. + mock_client.has_collection.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_to_memory(mock_milvus_client_cls): + """Test adding a session with events to memory.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.query.return_value = [] # No existing events. + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + session = _make_session() + + await service.add_session_to_memory(session) + + mock_client.insert.assert_called_once() + call_args = mock_client.insert.call_args + data = call_args.kwargs["data"] + assert len(data) == 2 + assert data[0]["app_name"] == APP_NAME + assert data[0]["user_id"] == USER_ID + assert data[0]["session_id"] == "session-1" + assert data[0]["author"] == "user" + assert data[0]["content"] == "Hello, I like Python." + assert data[0]["timestamp"] == 1000.0 + assert len(data[0]["embedding"]) == DIMENSION + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_skips_empty_events(mock_milvus_client_cls): + """Test that events without text content are skipped.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.query.return_value = [] + + session = _make_session(events=[ + Event( + id="event-empty", + invocation_id="inv-1", + author="user", + timestamp=1000.0, + # No content at all. + ), + Event( + id="event-func", + invocation_id="inv-2", + author="model", + timestamp=1001.0, + content=types.Content( + parts=[types.Part(function_call=types.FunctionCall( + name="test_fn", args={} + ))] + ), + ), + Event( + id="event-text", + invocation_id="inv-3", + author="user", + timestamp=1002.0, + content=types.Content( + parts=[types.Part(text="Real message")] + ), + ), + ]) + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + + await service.add_session_to_memory(session) + + call_args = mock_client.insert.call_args + data = call_args.kwargs["data"] + assert len(data) == 1 + assert data[0]["content"] == "Real message" + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_deduplication(mock_milvus_client_cls): + """Test that events already in Milvus are not re-inserted.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + # Simulate one event already existing. + mock_client.query.return_value = [{"timestamp": 1000.0}] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + session = _make_session() + + await service.add_session_to_memory(session) + + call_args = mock_client.insert.call_args + data = call_args.kwargs["data"] + # Only the second event (timestamp=1001.0) should be inserted. + assert len(data) == 1 + assert data[0]["timestamp"] == 1001.0 + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_all_duplicated(mock_milvus_client_cls): + """Test that no insert happens when all events already exist.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.query.return_value = [ + {"timestamp": 1000.0}, + {"timestamp": 1001.0}, + ] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + session = _make_session() + + await service.add_session_to_memory(session) + + mock_client.insert.assert_not_called() + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_search_memory(mock_milvus_client_cls): + """Test similarity search returns MemoryEntry objects.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.search.return_value = [[ + { + "entity": { + "content": "I like Python.", + "author": "user", + "timestamp": 1000.0, + }, + "distance": 0.95, + }, + { + "entity": { + "content": "Python is great!", + "author": "model", + "timestamp": 1001.0, + }, + "distance": 0.90, + }, + ]] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + + response = await service.search_memory( + app_name=APP_NAME, + user_id=USER_ID, + query="Python", + ) + + assert len(response.memories) == 2 + assert response.memories[0].author == "user" + assert response.memories[0].content.parts[0].text == "I like Python." + assert response.memories[0].content.role == "user" + assert response.memories[1].author == "model" + assert response.memories[1].content.role == "model" + + # Verify search was called with correct filter. + call_args = mock_client.search.call_args + assert f'app_name == "{APP_NAME}"' in call_args.kwargs["filter"] + assert f'user_id == "{USER_ID}"' in call_args.kwargs["filter"] + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_search_memory_empty_results(mock_milvus_client_cls): + """Test search returning no results.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.search.return_value = [[]] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + + response = await service.search_memory( + app_name=APP_NAME, + user_id=USER_ID, + query="something", + ) + + assert len(response.memories) == 0 + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_search_memory_user_scoping(mock_milvus_client_cls): + """Test that search scopes by app_name and user_id.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.search.return_value = [[]] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + + await service.search_memory( + app_name="app-A", + user_id="user-B", + query="test", + ) + + call_args = mock_client.search.call_args + filter_expr = call_args.kwargs["filter"] + assert 'app_name == "app-A"' in filter_expr + assert 'user_id == "user-B"' in filter_expr + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_close(mock_milvus_client_cls): + """Test closing the memory service.""" + mock_client = mock_milvus_client_cls.return_value + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + service.close() + + mock_client.close.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_no_events(mock_milvus_client_cls): + """Test adding a session with no events does nothing.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + session = _make_session(events=[]) + + await service.add_session_to_memory(session) + + mock_client.insert.assert_not_called() diff --git a/tests/unittests/tools/milvus/__init__.py b/tests/unittests/tools/milvus/__init__.py new file mode 100644 index 0000000000..58d482ea38 --- /dev/null +++ b/tests/unittests/tools/milvus/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/milvus/test_milvus_toolset.py b/tests/unittests/tools/milvus/test_milvus_toolset.py new file mode 100644 index 0000000000..b7ef5c540a --- /dev/null +++ b/tests/unittests/tools/milvus/test_milvus_toolset.py @@ -0,0 +1,137 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.tools.milvus.milvus_tool import MilvusTool +from google.adk.tools.milvus.milvus_toolset import MilvusToolset +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +import pytest + +DIMENSION = 4 + + +def _mock_embedding_fn(texts: list[str]) -> list[list[float]]: + return [[0.1] * DIMENSION for _ in texts] + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_returns_search_tool(mock_milvus_client_cls): + """Test that MilvusToolset returns a similarity_search tool.""" + settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + collection_name="test", + dimension=DIMENSION, + ), + ) + toolset = MilvusToolset( + milvus_tool_settings=settings, + embedding_fn=_mock_embedding_fn, + ) + + tools = await toolset.get_tools() + assert len(tools) == 1 + assert isinstance(tools[0], MilvusTool) + assert tools[0].name == "similarity_search" + + +@pytest.mark.asyncio +async def test_toolset_no_vector_store_settings(): + """Test that MilvusToolset returns no tools without vector store settings.""" + toolset = MilvusToolset( + milvus_tool_settings=MilvusToolSettings(), + embedding_fn=_mock_embedding_fn, + ) + + tools = await toolset.get_tools() + assert len(tools) == 0 + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_with_filter(mock_milvus_client_cls): + """Test that MilvusToolset respects tool_filter.""" + settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + collection_name="test", + dimension=DIMENSION, + ), + ) + toolset = MilvusToolset( + milvus_tool_settings=settings, + embedding_fn=_mock_embedding_fn, + tool_filter=["nonexistent_tool"], + ) + + tools = await toolset.get_tools() + assert len(tools) == 0 + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_with_matching_filter(mock_milvus_client_cls): + """Test that tool_filter includes matching tools.""" + settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + collection_name="test", + dimension=DIMENSION, + ), + ) + toolset = MilvusToolset( + milvus_tool_settings=settings, + embedding_fn=_mock_embedding_fn, + tool_filter=["similarity_search"], + ) + + tools = await toolset.get_tools() + assert len(tools) == 1 + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_close(mock_milvus_client_cls): + """Test that close shuts down the vector store.""" + mock_client = mock_milvus_client_cls.return_value + settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + collection_name="test", + dimension=DIMENSION, + ), + ) + toolset = MilvusToolset( + milvus_tool_settings=settings, + embedding_fn=_mock_embedding_fn, + ) + + # Force vector store creation + await toolset.get_tools() + + await toolset.close() + mock_client.close.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_close_without_init(mock_milvus_client_cls): + """Test that close works even if vector store was never created.""" + toolset = MilvusToolset( + milvus_tool_settings=MilvusToolSettings(), + embedding_fn=_mock_embedding_fn, + ) + # Should not raise + await toolset.close() diff --git a/tests/unittests/tools/milvus/test_milvus_vector_store.py b/tests/unittests/tools/milvus/test_milvus_vector_store.py new file mode 100644 index 0000000000..7efee5a7be --- /dev/null +++ b/tests/unittests/tools/milvus/test_milvus_vector_store.py @@ -0,0 +1,285 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.tools.milvus.milvus_vector_store import MilvusVectorStore +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +import pytest + +DIMENSION = 4 + + +def _mock_embedding_fn(texts: list[str]) -> list[list[float]]: + """A mock embedding function returning fixed-dimension vectors.""" + return [[0.1] * DIMENSION for _ in texts] + + +@pytest.fixture +def vector_store_settings(): + return MilvusVectorStoreSettings( + collection_name="test_collection", + dimension=DIMENSION, + metric_type="COSINE", + ) + + +@pytest.fixture +def tool_settings(vector_store_settings): + return MilvusToolSettings(vector_store_settings=vector_store_settings) + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_init_success(mock_milvus_client_cls, tool_settings): + """Test successful initialization of MilvusVectorStore.""" + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + mock_milvus_client_cls.assert_called_once_with( + uri="http://localhost:19530", + token=None, + db_name="default", + ) + assert store._embedding_fn is _mock_embedding_fn + + +def test_init_missing_vector_store_settings(): + """Test that missing vector_store_settings raises ValueError.""" + settings = MilvusToolSettings() + with pytest.raises(ValueError, match="not set"): + MilvusVectorStore( + settings=settings, + embedding_fn=_mock_embedding_fn, + ) + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_setup_creates_collection(mock_milvus_client_cls, tool_settings): + """Test that setup creates collection when it doesn't exist.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = False + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + store.setup() + + mock_client.has_collection.assert_called_once_with("test_collection") + mock_client.create_collection.assert_called_once() + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_setup_skips_existing_collection(mock_milvus_client_cls, tool_settings): + """Test that setup skips when collection already exists.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + store.setup() + + mock_client.has_collection.assert_called_once_with("test_collection") + mock_client.create_collection.assert_not_called() + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_add_contents(mock_milvus_client_cls, tool_settings): + """Test adding contents to the vector store.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + contents = ["hello world", "foo bar"] + store.add_contents(contents) + + mock_client.insert.assert_called_once() + call_args = mock_client.insert.call_args + assert call_args.kwargs["collection_name"] == "test_collection" + data = call_args.kwargs["data"] + assert len(data) == 2 + assert data[0]["content"] == "hello world" + assert data[1]["content"] == "foo bar" + assert len(data[0]["embedding"]) == DIMENSION + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_add_contents_with_batching(mock_milvus_client_cls, tool_settings): + """Test that add_contents batches correctly.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + contents = [f"doc_{i}" for i in range(5)] + store.add_contents(contents, batch_size=2) + + # 5 items with batch_size=2 => 3 batches (2, 2, 1) + assert mock_client.insert.call_count == 3 + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_add_contents_with_additional_fields( + mock_milvus_client_cls, tool_settings +): + """Test adding contents with additional fields.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + contents = ["hello"] + additional = [{"title": "greeting"}] + store.add_contents(contents, additional_fields=additional) + + call_args = mock_client.insert.call_args + data = call_args.kwargs["data"] + assert data[0]["title"] == "greeting" + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_add_contents_mismatched_additional_fields( + mock_milvus_client_cls, tool_settings +): + """Test that mismatched additional_fields raises ValueError.""" + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + with pytest.raises(ValueError, match="must match"): + store.add_contents( + ["doc1", "doc2"], + additional_fields=[{"title": "only_one"}], + ) + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_search(mock_milvus_client_cls, tool_settings): + """Test similarity search.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.search.return_value = [[ + {"entity": {"content": "result1"}, "distance": 0.9}, + {"entity": {"content": "result2"}, "distance": 0.8}, + ]] + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + results = store.search("test query") + + mock_client.search.assert_called_once() + call_args = mock_client.search.call_args + assert call_args.kwargs["collection_name"] == "test_collection" + assert call_args.kwargs["limit"] == 5 + assert call_args.kwargs["output_fields"] == ["content"] + + assert len(results) == 2 + assert results[0]["content"] == "result1" + assert results[0]["distance"] == 0.9 + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_search_with_custom_top_k(mock_milvus_client_cls, tool_settings): + """Test search with custom top_k.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.search.return_value = [[]] + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + store.search("test", top_k=10) + + call_args = mock_client.search.call_args + assert call_args.kwargs["limit"] == 10 + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_search_with_filter(mock_milvus_client_cls, tool_settings): + """Test search with filter expression.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.search.return_value = [[]] + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + store.search("test", filter_expr='category == "tech"') + + call_args = mock_client.search.call_args + assert call_args.kwargs["filter"] == 'category == "tech"' + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_search_empty_results(mock_milvus_client_cls, tool_settings): + """Test search returning empty results.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.search.return_value = [[]] + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + results = store.search("test") + assert results == [] + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_close(mock_milvus_client_cls, tool_settings): + """Test closing the vector store.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + store.close() + + mock_client.close.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_add_contents_async(mock_milvus_client_cls, tool_settings): + """Test async add_contents delegates to sync version.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + await store.add_contents_async(["hello", "world"]) + + mock_client.insert.assert_called_once() diff --git a/tests/unittests/tools/milvus/test_search_tool.py b/tests/unittests/tools/milvus/test_search_tool.py new file mode 100644 index 0000000000..97b4d3bd36 --- /dev/null +++ b/tests/unittests/tools/milvus/test_search_tool.py @@ -0,0 +1,76 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from unittest import mock + +from google.adk.tools.milvus.milvus_vector_store import MilvusVectorStore +from google.adk.tools.milvus.search_tool import similarity_search + + +def test_similarity_search_with_results(): + """Test similarity_search returns JSON formatted results.""" + mock_store = mock.create_autospec(MilvusVectorStore, instance=True) + mock_store.search.return_value = [ + {"content": "result1", "distance": 0.9}, + {"content": "result2", "distance": 0.8}, + ] + + result = similarity_search( + query="test query", + vector_store=mock_store, + ) + + mock_store.search.assert_called_once_with( + query="test query", filter_expr=None + ) + parsed = json.loads(result) + assert len(parsed) == 2 + assert parsed[0]["content"] == "result1" + + +def test_similarity_search_no_results(): + """Test similarity_search with no matching results.""" + mock_store = mock.create_autospec(MilvusVectorStore, instance=True) + mock_store.search.return_value = [] + + result = similarity_search( + query="unknown query", + vector_store=mock_store, + ) + + assert result == "No matching results found." + + +def test_similarity_search_with_filter(): + """Test similarity_search with filter expression.""" + mock_store = mock.create_autospec(MilvusVectorStore, instance=True) + mock_store.search.return_value = [ + {"content": "filtered result", "distance": 0.95}, + ] + + result = similarity_search( + query="test", + vector_store=mock_store, + filter_expr='category == "tech"', + ) + + mock_store.search.assert_called_once_with( + query="test", filter_expr='category == "tech"' + ) + parsed = json.loads(result) + assert len(parsed) == 1 + assert parsed[0]["content"] == "filtered result" diff --git a/tests/unittests/tools/milvus/test_settings.py b/tests/unittests/tools/milvus/test_settings.py new file mode 100644 index 0000000000..e0050549e1 --- /dev/null +++ b/tests/unittests/tools/milvus/test_settings.py @@ -0,0 +1,105 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +import pytest + + +def test_default_settings(): + """Test MilvusVectorStoreSettings with default values.""" + settings = MilvusVectorStoreSettings(collection_name="test_collection") + assert settings.uri == "http://localhost:19530" + assert settings.token is None + assert settings.db_name == "default" + assert settings.collection_name == "test_collection" + assert settings.dimension == 768 + assert settings.metric_type == "COSINE" + assert settings.index_type == "AUTOINDEX" + assert settings.content_field == "content" + assert settings.embedding_field == "embedding" + assert settings.primary_field == "id" + assert settings.top_k == 5 + assert settings.output_fields is None + + +def test_custom_settings(): + """Test MilvusVectorStoreSettings with custom values.""" + settings = MilvusVectorStoreSettings( + uri="http://milvus:19530", + token="test_token", + db_name="mydb", + collection_name="docs", + dimension=384, + metric_type="L2", + index_type="HNSW", + content_field="text", + embedding_field="vec", + primary_field="pk", + top_k=10, + output_fields=["text", "title"], + ) + assert settings.uri == "http://milvus:19530" + assert settings.token == "test_token" + assert settings.db_name == "mydb" + assert settings.dimension == 384 + assert settings.metric_type == "L2" + assert settings.index_type == "HNSW" + assert settings.content_field == "text" + assert settings.embedding_field == "vec" + assert settings.top_k == 10 + assert settings.output_fields == ["text", "title"] + + +def test_invalid_dimension(): + """Test that invalid dimension raises ValueError.""" + with pytest.raises(ValueError, match="Invalid dimension"): + MilvusVectorStoreSettings( + collection_name="test", + dimension=0, + ) + + +def test_invalid_dimension_negative(): + """Test that negative dimension raises ValueError.""" + with pytest.raises(ValueError, match="Invalid dimension"): + MilvusVectorStoreSettings( + collection_name="test", + dimension=-1, + ) + + +def test_invalid_top_k(): + """Test that invalid top_k raises ValueError.""" + with pytest.raises(ValueError, match="Invalid top_k"): + MilvusVectorStoreSettings( + collection_name="test", + top_k=0, + ) + + +def test_milvus_tool_settings_default(): + """Test MilvusToolSettings with default values.""" + settings = MilvusToolSettings() + assert settings.vector_store_settings is None + + +def test_milvus_tool_settings_with_vector_store(): + """Test MilvusToolSettings with vector store settings.""" + vs_settings = MilvusVectorStoreSettings(collection_name="test") + settings = MilvusToolSettings(vector_store_settings=vs_settings) + assert settings.vector_store_settings is not None + assert settings.vector_store_settings.collection_name == "test"