From 1e7e5cae0eeec6365c38572f81a3600d048c720b Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Sun, 8 Feb 2026 14:19:09 +0000 Subject: [PATCH 1/3] feat: Add Milvus vector store integration for RAG Add Milvus as a vector database option for knowledge base RAG workflows. This follows the existing Spanner integration pattern with BaseToolset, providing data ingestion, similarity search, and a ready-to-use toolset for LLM agents. - MilvusVectorStoreSettings / MilvusToolSettings (Pydantic config) - MilvusVectorStore: setup, add_contents, search, close - MilvusTool: FunctionTool subclass injecting vector_store param - MilvusToolset: BaseToolset providing similarity_search tool - Registered MILVUS_TOOLSET / MILVUS_VECTOR_STORE as experimental features - Added pymilvus>=2.5.0 to test, extensions, and new milvus extras - 30 unit tests (all passing) - Sample milvus_rag_agent under contributing/samples/ --- .../samples/milvus_rag_agent/README.md | 86 ++++++ .../samples/milvus_rag_agent/__init__.py | 15 + .../samples/milvus_rag_agent/agent.py | 85 ++++++ pyproject.toml | 6 + src/google/adk/features/_feature_registry.py | 8 + src/google/adk/tools/milvus/__init__.py | 29 ++ src/google/adk/tools/milvus/milvus_tool.py | 55 ++++ src/google/adk/tools/milvus/milvus_toolset.py | 108 +++++++ .../adk/tools/milvus/milvus_vector_store.py | 277 +++++++++++++++++ src/google/adk/tools/milvus/search_tool.py | 52 ++++ src/google/adk/tools/milvus/settings.py | 98 ++++++ tests/unittests/tools/milvus/__init__.py | 13 + .../tools/milvus/test_milvus_toolset.py | 137 +++++++++ .../tools/milvus/test_milvus_vector_store.py | 285 ++++++++++++++++++ .../tools/milvus/test_search_tool.py | 76 +++++ tests/unittests/tools/milvus/test_settings.py | 105 +++++++ 16 files changed, 1435 insertions(+) create mode 100644 contributing/samples/milvus_rag_agent/README.md create mode 100644 contributing/samples/milvus_rag_agent/__init__.py create mode 100644 contributing/samples/milvus_rag_agent/agent.py create mode 100644 src/google/adk/tools/milvus/__init__.py create mode 100644 src/google/adk/tools/milvus/milvus_tool.py create mode 100644 src/google/adk/tools/milvus/milvus_toolset.py create mode 100644 src/google/adk/tools/milvus/milvus_vector_store.py create mode 100644 src/google/adk/tools/milvus/search_tool.py create mode 100644 src/google/adk/tools/milvus/settings.py create mode 100644 tests/unittests/tools/milvus/__init__.py create mode 100644 tests/unittests/tools/milvus/test_milvus_toolset.py create mode 100644 tests/unittests/tools/milvus/test_milvus_vector_store.py create mode 100644 tests/unittests/tools/milvus/test_search_tool.py create mode 100644 tests/unittests/tools/milvus/test_settings.py diff --git a/contributing/samples/milvus_rag_agent/README.md b/contributing/samples/milvus_rag_agent/README.md new file mode 100644 index 0000000000..a6437c6c73 --- /dev/null +++ b/contributing/samples/milvus_rag_agent/README.md @@ -0,0 +1,86 @@ +# Milvus RAG Agent Sample + +This sample demonstrates how to build a knowledge base agent using +[Milvus](https://milvus.io/) as the vector database for +retrieval-augmented generation (RAG) with Google ADK. + +## Prerequisites + +1. A running Milvus instance, or use Milvus Lite (local file path). +2. A Google GenAI API key for embedding generation. +3. Install dependencies: + +```bash +pip install "google-adk[milvus]" +``` + +## Supported Milvus Backends + +The integration works with all three Milvus deployment modes by changing +the `uri` (and optionally `token`) setting: + +| Backend | `uri` | `token` | +|---------|-------|---------| +| **Milvus Lite** (local, no server needed) | `./milvus.db` | — | +| **Milvus Server** (self-hosted) | `http://localhost:19530` | — | +| **Zilliz Cloud** (fully managed) | `https://in01-xxx.serverless.gcp-us-west1.cloud.zilliz.com` | `your-api-key` | + +> For Milvus Lite, install the extra package: `pip install "pymilvus[milvus_lite]"`. + +## Setup + +### Environment Variables + +```bash +export GOOGLE_API_KEY="your-google-api-key" + +# Pick one of the following: +export MILVUS_URI="./milvus.db" # Milvus Lite +export MILVUS_URI="http://localhost:19530" # Milvus Server +export MILVUS_URI="https://in01-xxx.cloud.zilliz.com" # Zilliz Cloud + +export MILVUS_COLLECTION="knowledge_base" +# Required for Zilliz Cloud only: +# export MILVUS_TOKEN="your-api-key" +``` + +### Data Ingestion + +Before running the agent, you need to populate the Milvus collection +with your knowledge base data: + +```python +from google.adk.tools.milvus.milvus_vector_store import MilvusVectorStore +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +from google.genai import Client + +# Define your embedding function (example using Google GenAI). +genai_client = Client() + +def embedding_fn(texts): + resp = genai_client.models.embed_content( + model="text-embedding-004", contents=texts) + return [list(e.values) for e in resp.embeddings] + +settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + uri="http://localhost:19530", + collection_name="knowledge_base", + dimension=768, + ), +) + +store = MilvusVectorStore(settings=settings, embedding_fn=embedding_fn) +store.setup() +store.add_contents([ + "Your document text here...", + "Another document...", +]) +``` + +## Run + +```bash +adk run contributing/samples/milvus_rag_agent +``` diff --git a/contributing/samples/milvus_rag_agent/__init__.py b/contributing/samples/milvus_rag_agent/__init__.py new file mode 100644 index 0000000000..4015e47d6e --- /dev/null +++ b/contributing/samples/milvus_rag_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/milvus_rag_agent/agent.py b/contributing/samples/milvus_rag_agent/agent.py new file mode 100644 index 0000000000..9c236f287c --- /dev/null +++ b/contributing/samples/milvus_rag_agent/agent.py @@ -0,0 +1,85 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample Milvus RAG Agent. + +This sample demonstrates how to build a knowledge base agent using Milvus +as the vector database for retrieval-augmented generation (RAG). + +Prerequisites: + 1. A running Milvus instance (or use Milvus Lite with a local file path). + 2. A Google GenAI API key set in the GOOGLE_API_KEY environment variable. + 3. Install dependencies: pip install "google-adk[milvus]" + +Usage: + adk run contributing/samples/milvus_rag_agent +""" + +import os + +from dotenv import load_dotenv +from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.milvus.milvus_toolset import MilvusToolset +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +from google.genai import Client + +load_dotenv() + +# --- Embedding function using Google GenAI --- +genai_client = Client() + + +def embedding_fn(texts: list[str]) -> list[list[float]]: + response = genai_client.models.embed_content( + model="text-embedding-004", + contents=texts, + ) + return [list(e.values) for e in response.embeddings] + + +# --- Milvus vector store settings --- +# Replace these with your own Milvus connection and collection settings. +vector_store_settings = MilvusVectorStoreSettings( + # Use a remote Milvus instance or a local Milvus Lite file path. + uri=os.environ.get("MILVUS_URI", "http://localhost:19530"), + token=os.environ.get("MILVUS_TOKEN", None), + collection_name=os.environ.get("MILVUS_COLLECTION", "knowledge_base"), + dimension=768, + metric_type="COSINE", + top_k=5, +) + +# --- Milvus toolset --- +milvus_toolset = MilvusToolset( + milvus_tool_settings=MilvusToolSettings( + vector_store_settings=vector_store_settings, + ), + embedding_fn=embedding_fn, + tool_filter=["similarity_search"], +) + +# --- Agent definition --- +root_agent = LlmAgent( + model="gemini-2.5-flash", + name="milvus_knowledge_agent", + description="Agent that answers questions using a Milvus knowledge base.", + instruction=""" + You are a helpful assistant with access to a knowledge base. + 1. Always use the `similarity_search` tool to find relevant information. + 2. Present the search results naturally in your response. + 3. If no results are found, say you don't know. + """, + tools=[milvus_toolset], +) diff --git a/pyproject.toml b/pyproject.toml index 2564e8baf1..05ee68b01b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,6 +127,7 @@ test = [ "litellm>=1.75.5, <1.80.17", # For LiteLLM tests "llama-index-readers-file>=0.4.0", # For retrieval tests "openai>=1.100.2", # For LiteLLM + "pymilvus>=2.5.0", # For Milvus vector store tests "opentelemetry-instrumentation-google-genai>=0.3b0, <1.0.0", "pypika>=0.50.0", # For crewai->chromadb dependency "pytest-asyncio>=0.25.0", @@ -159,11 +160,16 @@ extensions = [ "litellm>=1.75.5, <1.80.17", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. "llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex. + "pymilvus>=2.5.0", # For Milvus vector store integration. "lxml>=5.3.0", # For load_web_page tool. "pypika>=0.50.0", # For crewai->chromadb dependency "toolbox-adk>=0.5.7, <0.6.0", # For tools.toolbox_toolset.ToolboxToolset ] +milvus = [ + "pymilvus>=2.5.0", +] + otel-gcp = ["opentelemetry-instrumentation-google-genai>=0.6b0, <1.0.0"] toolbox = ["toolbox-adk>=0.5.7, <0.6.0"] diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 154c77caf6..5021e321eb 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -38,6 +38,8 @@ class FeatureName(str, Enum): GOOGLE_CREDENTIALS_CONFIG = "GOOGLE_CREDENTIALS_CONFIG" GOOGLE_TOOL = "GOOGLE_TOOL" JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL" + MILVUS_TOOLSET = "MILVUS_TOOLSET" + MILVUS_VECTOR_STORE = "MILVUS_VECTOR_STORE" PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING" PUBSUB_TOOL_CONFIG = "PUBSUB_TOOL_CONFIG" PUBSUB_TOOLSET = "PUBSUB_TOOLSET" @@ -114,6 +116,12 @@ class FeatureConfig: FeatureName.JSON_SCHEMA_FOR_FUNC_DECL: FeatureConfig( FeatureStage.WIP, default_on=False ), + FeatureName.MILVUS_TOOLSET: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), + FeatureName.MILVUS_VECTOR_STORE: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/tools/milvus/__init__.py b/src/google/adk/tools/milvus/__init__.py new file mode 100644 index 0000000000..9c868b2006 --- /dev/null +++ b/src/google/adk/tools/milvus/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus Tools (Experimental). + +Milvus Tools provide vector similarity search capabilities using Milvus +as the vector database backend. This module offers: + +1. A MilvusToolset for easy integration with ADK agents. +2. A MilvusVectorStore utility for collection management and data ingestion. +3. Support for any third-party embedding function (Google GenAI, OpenAI, etc.). +""" + +from .milvus_toolset import MilvusToolset + +__all__ = [ + "MilvusToolset", +] diff --git a/src/google/adk/tools/milvus/milvus_tool.py b/src/google/adk/tools/milvus/milvus_tool.py new file mode 100644 index 0000000000..9558a382b8 --- /dev/null +++ b/src/google/adk/tools/milvus/milvus_tool.py @@ -0,0 +1,55 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MilvusTool wraps a function and injects the MilvusVectorStore instance.""" + +from __future__ import annotations + +import inspect +from typing import Any +from typing import Callable + +from typing_extensions import override + +from ..function_tool import FunctionTool +from ..tool_context import ToolContext +from .milvus_vector_store import MilvusVectorStore + + +class MilvusTool(FunctionTool): + """A FunctionTool that injects MilvusVectorStore into the wrapped function. + + The ``vector_store`` parameter is hidden from the LLM function + declaration and automatically injected at runtime. + """ + + def __init__( + self, + func: Callable[..., Any], + *, + vector_store: MilvusVectorStore, + ): + super().__init__(func=func) + self._ignore_params.append("vector_store") + self._vector_store = vector_store + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + args_to_call = args.copy() + signature = inspect.signature(self.func) + if "vector_store" in signature.parameters: + args_to_call["vector_store"] = self._vector_store + return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/milvus/milvus_toolset.py b/src/google/adk/tools/milvus/milvus_toolset.py new file mode 100644 index 0000000000..07cca57dd1 --- /dev/null +++ b/src/google/adk/tools/milvus/milvus_toolset.py @@ -0,0 +1,108 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus Toolset for exposing Milvus vector search to agents.""" + +from __future__ import annotations + +from typing import Callable +from typing import List +from typing import Optional +from typing import Union + +from typing_extensions import override + +from . import search_tool +from ...agents.readonly_context import ReadonlyContext +from ...features import experimental +from ...features import FeatureName +from ...tools.base_tool import BaseTool +from ...tools.base_toolset import BaseToolset +from ...tools.base_toolset import ToolPredicate +from .milvus_tool import MilvusTool +from .milvus_vector_store import MilvusVectorStore +from .settings import MilvusToolSettings + +DEFAULT_MILVUS_TOOL_NAME_PREFIX = "milvus" + + +@experimental(FeatureName.MILVUS_TOOLSET) +class MilvusToolset(BaseToolset): + """Milvus Toolset provides tools for vector similarity search in Milvus. + + The tool names are: + - milvus_similarity_search + """ + + def __init__( + self, + *, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + milvus_tool_settings: Optional[MilvusToolSettings] = None, + embedding_fn: Callable[[list[str]], list[list[float]]], + ): + """Initialize the Milvus Toolset. + + Args: + tool_filter: Optional filter to select a subset of tools. + milvus_tool_settings: Milvus tool settings containing vector store + configuration. + embedding_fn: A function that takes a list of texts and returns a + list of embedding vectors. + """ + super().__init__( + tool_filter=tool_filter, + tool_name_prefix=DEFAULT_MILVUS_TOOL_NAME_PREFIX, + ) + self._tool_settings = milvus_tool_settings or MilvusToolSettings() + self._embedding_fn = embedding_fn + self._vector_store: Optional[MilvusVectorStore] = None + + def _get_vector_store(self) -> MilvusVectorStore: + """Lazily creates the MilvusVectorStore instance.""" + if self._vector_store is None: + self._vector_store = MilvusVectorStore( + settings=self._tool_settings, + embedding_fn=self._embedding_fn, + ) + return self._vector_store + + @override + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> List[BaseTool]: + """Get tools from the toolset.""" + all_tools: list[BaseTool] = [] + + if self._tool_settings.vector_store_settings: + vector_store = self._get_vector_store() + all_tools.append( + MilvusTool( + func=search_tool.similarity_search, + vector_store=vector_store, + ) + ) + + return [ + tool + for tool in all_tools + if self._is_tool_selected(tool, readonly_context) + ] + + @override + async def close(self) -> None: + """Closes the Milvus client connection.""" + if self._vector_store is not None: + self._vector_store.close() + self._vector_store = None diff --git a/src/google/adk/tools/milvus/milvus_vector_store.py b/src/google/adk/tools/milvus/milvus_vector_store.py new file mode 100644 index 0000000000..dd305d783b --- /dev/null +++ b/src/google/adk/tools/milvus/milvus_vector_store.py @@ -0,0 +1,277 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus vector store utility class for data ingestion and retrieval.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Callable +from typing import Optional + +from ...features import experimental +from ...features import FeatureName +from .settings import MilvusToolSettings + +logger = logging.getLogger("google_adk." + __name__) + +try: + from pymilvus import DataType + from pymilvus import MilvusClient +except ImportError: + MilvusClient = None + DataType = None + + +@experimental(FeatureName.MILVUS_VECTOR_STORE) +class MilvusVectorStore: + """A utility class for managing a Milvus vector store. + + This class provides methods for setting up a Milvus collection, + adding content with embeddings, and performing similarity search. + """ + + def __init__( + self, + settings: MilvusToolSettings, + embedding_fn: Callable[[list[str]], list[list[float]]], + ): + """Initializes the MilvusVectorStore with settings and embedding function. + + Args: + settings: The Milvus tool settings containing vector store configuration. + embedding_fn: A function that takes a list of texts and returns a list + of embedding vectors. Signature: ``(list[str]) -> list[list[float]]``. + For example, using Google GenAI:: + + from google.genai import Client + client = Client() + def embedding_fn(texts): + resp = client.models.embed_content( + model="text-embedding-004", contents=texts) + return [list(e.values) for e in resp.embeddings] + + Raises: + ValueError: If vector_store_settings is not set. + ImportError: If pymilvus is not installed. + """ + if not settings.vector_store_settings: + raise ValueError("Milvus vector store settings are not set.") + + if MilvusClient is None: + raise ImportError( + "pymilvus package not found. " + 'Please install with: pip install "google-adk[milvus]"' + ) + + self._settings = settings.vector_store_settings + self._embedding_fn = embedding_fn + + self._client = MilvusClient( + uri=self._settings.uri, + token=self._settings.token, + db_name=self._settings.db_name, + ) + + def setup(self) -> None: + """Creates the Milvus collection and index if they do not exist. + + The collection schema includes: + - A primary key field (auto-generated int64 id). + - A text content field (VARCHAR). + - A vector embedding field (FLOAT_VECTOR). + + The vector index is created with the configured metric type and + index type. + """ + if self._client.has_collection(self._settings.collection_name): + logger.info( + "Collection '%s' already exists, skipping setup.", + self._settings.collection_name, + ) + return + + schema = self._client.create_schema(auto_id=True, enable_dynamic_field=True) + schema.add_field( + field_name=self._settings.primary_field, + datatype=DataType.INT64, + is_primary=True, + ) + schema.add_field( + field_name=self._settings.content_field, + datatype=DataType.VARCHAR, + max_length=65535, + ) + schema.add_field( + field_name=self._settings.embedding_field, + datatype=DataType.FLOAT_VECTOR, + dim=self._settings.dimension, + ) + + index_params = self._client.prepare_index_params() + index_params.add_index( + field_name=self._settings.embedding_field, + index_type=self._settings.index_type, + metric_type=self._settings.metric_type, + ) + + self._client.create_collection( + collection_name=self._settings.collection_name, + schema=schema, + index_params=index_params, + ) + + logger.info( + "Created collection '%s' with dimension=%d, metric=%s, index=%s.", + self._settings.collection_name, + self._settings.dimension, + self._settings.metric_type, + self._settings.index_type, + ) + + def add_contents( + self, + contents: list[str], + *, + additional_fields: Optional[list[dict]] = None, + batch_size: int = 200, + ) -> None: + """Adds text contents to the vector store. + + Performs batch embedding generation and insertion into the Milvus + collection. + + Args: + contents: An iterable of text contents to add. + additional_fields: Optional list of dicts with extra field values + for each content row. + batch_size: Maximum number of items per batch. Defaults to 200. + + Raises: + ValueError: If additional_fields length does not match contents length. + """ + if additional_fields and len(additional_fields) != len(contents): + raise ValueError( + "The number of additional_fields must match the number of contents." + ) + + total_rows = 0 + for i in range(0, len(contents), batch_size): + batch_contents = contents[i : i + batch_size] + batch_extra = ( + additional_fields[i : i + batch_size] if additional_fields else None + ) + + logger.debug( + "Embedding batch %d to %d (size: %d)...", + i, + i + len(batch_contents), + len(batch_contents), + ) + embeddings = self._embedding_fn(batch_contents) + + data = [] + for j, (content, embedding) in enumerate(zip(batch_contents, embeddings)): + row = { + self._settings.content_field: content, + self._settings.embedding_field: embedding, + } + if batch_extra and j < len(batch_extra): + row.update(batch_extra[j]) + data.append(row) + + self._client.insert( + collection_name=self._settings.collection_name, + data=data, + ) + total_rows += len(data) + + logger.info( + "Added %d contents to collection '%s'.", + total_rows, + self._settings.collection_name, + ) + + async def add_contents_async( + self, + contents: list[str], + *, + additional_fields: Optional[list[dict]] = None, + batch_size: int = 200, + ) -> None: + """Asynchronously adds text contents to the vector store. + + Args: + contents: An iterable of text contents to add. + additional_fields: Optional list of dicts with extra field values + for each content row. + batch_size: Maximum number of items per batch. Defaults to 200. + """ + await asyncio.to_thread( + self.add_contents, + contents, + additional_fields=additional_fields, + batch_size=batch_size, + ) + + def search( + self, + query: str, + *, + top_k: Optional[int] = None, + filter_expr: Optional[str] = None, + ) -> list[dict]: + """Performs vector similarity search. + + Args: + query: The search query text. It will be embedded using the + configured embedding function. + top_k: Number of results to return. Overrides the default from + settings if provided. + filter_expr: Optional Milvus filter expression for pre-filtering. + + Returns: + A list of dicts, each containing the matched content and any + configured output fields, along with the distance score. + """ + top_k = top_k or self._settings.top_k + + query_embedding = self._embedding_fn([query])[0] + + output_fields = self._settings.output_fields or [ + self._settings.content_field + ] + + search_params = {"metric_type": self._settings.metric_type} + + results = self._client.search( + collection_name=self._settings.collection_name, + data=[query_embedding], + limit=top_k, + output_fields=output_fields, + search_params=search_params, + filter=filter_expr or "", + ) + + if not results or not results[0]: + return [] + + return [ + {**hit["entity"], "distance": hit["distance"]} for hit in results[0] + ] + + def close(self) -> None: + """Closes the Milvus client connection.""" + self._client.close() diff --git a/src/google/adk/tools/milvus/search_tool.py b/src/google/adk/tools/milvus/search_tool.py new file mode 100644 index 0000000000..489f3e9aa8 --- /dev/null +++ b/src/google/adk/tools/milvus/search_tool.py @@ -0,0 +1,52 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus similarity search tool function for agent use.""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from .milvus_vector_store import MilvusVectorStore + +logger = logging.getLogger("google_adk." + __name__) + + +def similarity_search( + query: str, + vector_store: MilvusVectorStore, + filter_expr: Optional[str] = None, +) -> str: + """Search for similar content in Milvus vector store. + + This function is intended to be wrapped by MilvusToolset and exposed + to agents as a tool. + + Args: + query: The search query text. + vector_store: The MilvusVectorStore instance (injected by MilvusToolset). + filter_expr: Optional Milvus filter expression for pre-filtering + results. + + Returns: + Search results formatted as a JSON string. + """ + results = vector_store.search(query=query, filter_expr=filter_expr) + + if not results: + return "No matching results found." + + return json.dumps(results, ensure_ascii=False, default=str) diff --git a/src/google/adk/tools/milvus/settings.py b/src/google/adk/tools/milvus/settings.py new file mode 100644 index 0000000000..7691bea7e2 --- /dev/null +++ b/src/google/adk/tools/milvus/settings.py @@ -0,0 +1,98 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Settings for Milvus vector store and toolset.""" + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel +from pydantic import model_validator + + +class MilvusVectorStoreSettings(BaseModel): + """Settings for Milvus vector store. + + This is used for vector similarity search in a Milvus collection. + Provide the collection and connection settings to use with the + ``similarity_search`` tool. + """ + + uri: str = "http://localhost:19530" + """The Milvus server URI. + + Can be a remote address like ``http://localhost:19530`` or a local + file path for Milvus Lite (e.g. ``./milvus.db``). + """ + + token: Optional[str] = None + """Optional authentication token (e.g. for Zilliz Cloud).""" + + db_name: str = "default" + """The Milvus database name.""" + + collection_name: str + """Required. The name of the Milvus collection.""" + + dimension: int = 768 + """The dimension of the embedding vectors.""" + + metric_type: str = "COSINE" + """The distance metric for similarity search. + + Supported values: ``COSINE``, ``L2``, ``IP``. + """ + + index_type: str = "AUTOINDEX" + """The index type for the vector field. + + Supported values: ``AUTOINDEX``, ``IVF_FLAT``, ``HNSW``, etc. + """ + + content_field: str = "content" + """The name of the text content field in the collection.""" + + embedding_field: str = "embedding" + """The name of the vector embedding field in the collection.""" + + primary_field: str = "id" + """The name of the primary key field in the collection.""" + + top_k: int = 5 + """The default number of results to return from similarity search.""" + + output_fields: Optional[list[str]] = None + """Optional additional fields to return in search results. + + If ``None``, only the ``content_field`` is returned. + """ + + @model_validator(mode="after") + def _validate_settings(self): + """Validate the vector store settings.""" + if self.dimension <= 0: + raise ValueError( + f"Invalid dimension: {self.dimension}. Must be positive." + ) + if self.top_k <= 0: + raise ValueError(f"Invalid top_k: {self.top_k}. Must be positive.") + return self + + +class MilvusToolSettings(BaseModel): + """Settings for Milvus toolset.""" + + vector_store_settings: Optional[MilvusVectorStoreSettings] = None + """Settings for Milvus vector store and vector similarity search.""" diff --git a/tests/unittests/tools/milvus/__init__.py b/tests/unittests/tools/milvus/__init__.py new file mode 100644 index 0000000000..58d482ea38 --- /dev/null +++ b/tests/unittests/tools/milvus/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/milvus/test_milvus_toolset.py b/tests/unittests/tools/milvus/test_milvus_toolset.py new file mode 100644 index 0000000000..b7ef5c540a --- /dev/null +++ b/tests/unittests/tools/milvus/test_milvus_toolset.py @@ -0,0 +1,137 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.tools.milvus.milvus_tool import MilvusTool +from google.adk.tools.milvus.milvus_toolset import MilvusToolset +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +import pytest + +DIMENSION = 4 + + +def _mock_embedding_fn(texts: list[str]) -> list[list[float]]: + return [[0.1] * DIMENSION for _ in texts] + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_returns_search_tool(mock_milvus_client_cls): + """Test that MilvusToolset returns a similarity_search tool.""" + settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + collection_name="test", + dimension=DIMENSION, + ), + ) + toolset = MilvusToolset( + milvus_tool_settings=settings, + embedding_fn=_mock_embedding_fn, + ) + + tools = await toolset.get_tools() + assert len(tools) == 1 + assert isinstance(tools[0], MilvusTool) + assert tools[0].name == "similarity_search" + + +@pytest.mark.asyncio +async def test_toolset_no_vector_store_settings(): + """Test that MilvusToolset returns no tools without vector store settings.""" + toolset = MilvusToolset( + milvus_tool_settings=MilvusToolSettings(), + embedding_fn=_mock_embedding_fn, + ) + + tools = await toolset.get_tools() + assert len(tools) == 0 + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_with_filter(mock_milvus_client_cls): + """Test that MilvusToolset respects tool_filter.""" + settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + collection_name="test", + dimension=DIMENSION, + ), + ) + toolset = MilvusToolset( + milvus_tool_settings=settings, + embedding_fn=_mock_embedding_fn, + tool_filter=["nonexistent_tool"], + ) + + tools = await toolset.get_tools() + assert len(tools) == 0 + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_with_matching_filter(mock_milvus_client_cls): + """Test that tool_filter includes matching tools.""" + settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + collection_name="test", + dimension=DIMENSION, + ), + ) + toolset = MilvusToolset( + milvus_tool_settings=settings, + embedding_fn=_mock_embedding_fn, + tool_filter=["similarity_search"], + ) + + tools = await toolset.get_tools() + assert len(tools) == 1 + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_close(mock_milvus_client_cls): + """Test that close shuts down the vector store.""" + mock_client = mock_milvus_client_cls.return_value + settings = MilvusToolSettings( + vector_store_settings=MilvusVectorStoreSettings( + collection_name="test", + dimension=DIMENSION, + ), + ) + toolset = MilvusToolset( + milvus_tool_settings=settings, + embedding_fn=_mock_embedding_fn, + ) + + # Force vector store creation + await toolset.get_tools() + + await toolset.close() + mock_client.close.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_toolset_close_without_init(mock_milvus_client_cls): + """Test that close works even if vector store was never created.""" + toolset = MilvusToolset( + milvus_tool_settings=MilvusToolSettings(), + embedding_fn=_mock_embedding_fn, + ) + # Should not raise + await toolset.close() diff --git a/tests/unittests/tools/milvus/test_milvus_vector_store.py b/tests/unittests/tools/milvus/test_milvus_vector_store.py new file mode 100644 index 0000000000..7efee5a7be --- /dev/null +++ b/tests/unittests/tools/milvus/test_milvus_vector_store.py @@ -0,0 +1,285 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.tools.milvus.milvus_vector_store import MilvusVectorStore +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +import pytest + +DIMENSION = 4 + + +def _mock_embedding_fn(texts: list[str]) -> list[list[float]]: + """A mock embedding function returning fixed-dimension vectors.""" + return [[0.1] * DIMENSION for _ in texts] + + +@pytest.fixture +def vector_store_settings(): + return MilvusVectorStoreSettings( + collection_name="test_collection", + dimension=DIMENSION, + metric_type="COSINE", + ) + + +@pytest.fixture +def tool_settings(vector_store_settings): + return MilvusToolSettings(vector_store_settings=vector_store_settings) + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_init_success(mock_milvus_client_cls, tool_settings): + """Test successful initialization of MilvusVectorStore.""" + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + mock_milvus_client_cls.assert_called_once_with( + uri="http://localhost:19530", + token=None, + db_name="default", + ) + assert store._embedding_fn is _mock_embedding_fn + + +def test_init_missing_vector_store_settings(): + """Test that missing vector_store_settings raises ValueError.""" + settings = MilvusToolSettings() + with pytest.raises(ValueError, match="not set"): + MilvusVectorStore( + settings=settings, + embedding_fn=_mock_embedding_fn, + ) + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_setup_creates_collection(mock_milvus_client_cls, tool_settings): + """Test that setup creates collection when it doesn't exist.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = False + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + store.setup() + + mock_client.has_collection.assert_called_once_with("test_collection") + mock_client.create_collection.assert_called_once() + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_setup_skips_existing_collection(mock_milvus_client_cls, tool_settings): + """Test that setup skips when collection already exists.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + store.setup() + + mock_client.has_collection.assert_called_once_with("test_collection") + mock_client.create_collection.assert_not_called() + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_add_contents(mock_milvus_client_cls, tool_settings): + """Test adding contents to the vector store.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + contents = ["hello world", "foo bar"] + store.add_contents(contents) + + mock_client.insert.assert_called_once() + call_args = mock_client.insert.call_args + assert call_args.kwargs["collection_name"] == "test_collection" + data = call_args.kwargs["data"] + assert len(data) == 2 + assert data[0]["content"] == "hello world" + assert data[1]["content"] == "foo bar" + assert len(data[0]["embedding"]) == DIMENSION + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_add_contents_with_batching(mock_milvus_client_cls, tool_settings): + """Test that add_contents batches correctly.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + contents = [f"doc_{i}" for i in range(5)] + store.add_contents(contents, batch_size=2) + + # 5 items with batch_size=2 => 3 batches (2, 2, 1) + assert mock_client.insert.call_count == 3 + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_add_contents_with_additional_fields( + mock_milvus_client_cls, tool_settings +): + """Test adding contents with additional fields.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + contents = ["hello"] + additional = [{"title": "greeting"}] + store.add_contents(contents, additional_fields=additional) + + call_args = mock_client.insert.call_args + data = call_args.kwargs["data"] + assert data[0]["title"] == "greeting" + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_add_contents_mismatched_additional_fields( + mock_milvus_client_cls, tool_settings +): + """Test that mismatched additional_fields raises ValueError.""" + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + with pytest.raises(ValueError, match="must match"): + store.add_contents( + ["doc1", "doc2"], + additional_fields=[{"title": "only_one"}], + ) + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_search(mock_milvus_client_cls, tool_settings): + """Test similarity search.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.search.return_value = [[ + {"entity": {"content": "result1"}, "distance": 0.9}, + {"entity": {"content": "result2"}, "distance": 0.8}, + ]] + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + results = store.search("test query") + + mock_client.search.assert_called_once() + call_args = mock_client.search.call_args + assert call_args.kwargs["collection_name"] == "test_collection" + assert call_args.kwargs["limit"] == 5 + assert call_args.kwargs["output_fields"] == ["content"] + + assert len(results) == 2 + assert results[0]["content"] == "result1" + assert results[0]["distance"] == 0.9 + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_search_with_custom_top_k(mock_milvus_client_cls, tool_settings): + """Test search with custom top_k.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.search.return_value = [[]] + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + store.search("test", top_k=10) + + call_args = mock_client.search.call_args + assert call_args.kwargs["limit"] == 10 + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_search_with_filter(mock_milvus_client_cls, tool_settings): + """Test search with filter expression.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.search.return_value = [[]] + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + store.search("test", filter_expr='category == "tech"') + + call_args = mock_client.search.call_args + assert call_args.kwargs["filter"] == 'category == "tech"' + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_search_empty_results(mock_milvus_client_cls, tool_settings): + """Test search returning empty results.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.search.return_value = [[]] + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + results = store.search("test") + assert results == [] + + +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +def test_close(mock_milvus_client_cls, tool_settings): + """Test closing the vector store.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + store.close() + + mock_client.close.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch("google.adk.tools.milvus.milvus_vector_store.MilvusClient") +async def test_add_contents_async(mock_milvus_client_cls, tool_settings): + """Test async add_contents delegates to sync version.""" + mock_client = mock_milvus_client_cls.return_value + + store = MilvusVectorStore( + settings=tool_settings, + embedding_fn=_mock_embedding_fn, + ) + + await store.add_contents_async(["hello", "world"]) + + mock_client.insert.assert_called_once() diff --git a/tests/unittests/tools/milvus/test_search_tool.py b/tests/unittests/tools/milvus/test_search_tool.py new file mode 100644 index 0000000000..97b4d3bd36 --- /dev/null +++ b/tests/unittests/tools/milvus/test_search_tool.py @@ -0,0 +1,76 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from unittest import mock + +from google.adk.tools.milvus.milvus_vector_store import MilvusVectorStore +from google.adk.tools.milvus.search_tool import similarity_search + + +def test_similarity_search_with_results(): + """Test similarity_search returns JSON formatted results.""" + mock_store = mock.create_autospec(MilvusVectorStore, instance=True) + mock_store.search.return_value = [ + {"content": "result1", "distance": 0.9}, + {"content": "result2", "distance": 0.8}, + ] + + result = similarity_search( + query="test query", + vector_store=mock_store, + ) + + mock_store.search.assert_called_once_with( + query="test query", filter_expr=None + ) + parsed = json.loads(result) + assert len(parsed) == 2 + assert parsed[0]["content"] == "result1" + + +def test_similarity_search_no_results(): + """Test similarity_search with no matching results.""" + mock_store = mock.create_autospec(MilvusVectorStore, instance=True) + mock_store.search.return_value = [] + + result = similarity_search( + query="unknown query", + vector_store=mock_store, + ) + + assert result == "No matching results found." + + +def test_similarity_search_with_filter(): + """Test similarity_search with filter expression.""" + mock_store = mock.create_autospec(MilvusVectorStore, instance=True) + mock_store.search.return_value = [ + {"content": "filtered result", "distance": 0.95}, + ] + + result = similarity_search( + query="test", + vector_store=mock_store, + filter_expr='category == "tech"', + ) + + mock_store.search.assert_called_once_with( + query="test", filter_expr='category == "tech"' + ) + parsed = json.loads(result) + assert len(parsed) == 1 + assert parsed[0]["content"] == "filtered result" diff --git a/tests/unittests/tools/milvus/test_settings.py b/tests/unittests/tools/milvus/test_settings.py new file mode 100644 index 0000000000..e0050549e1 --- /dev/null +++ b/tests/unittests/tools/milvus/test_settings.py @@ -0,0 +1,105 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.tools.milvus.settings import MilvusToolSettings +from google.adk.tools.milvus.settings import MilvusVectorStoreSettings +import pytest + + +def test_default_settings(): + """Test MilvusVectorStoreSettings with default values.""" + settings = MilvusVectorStoreSettings(collection_name="test_collection") + assert settings.uri == "http://localhost:19530" + assert settings.token is None + assert settings.db_name == "default" + assert settings.collection_name == "test_collection" + assert settings.dimension == 768 + assert settings.metric_type == "COSINE" + assert settings.index_type == "AUTOINDEX" + assert settings.content_field == "content" + assert settings.embedding_field == "embedding" + assert settings.primary_field == "id" + assert settings.top_k == 5 + assert settings.output_fields is None + + +def test_custom_settings(): + """Test MilvusVectorStoreSettings with custom values.""" + settings = MilvusVectorStoreSettings( + uri="http://milvus:19530", + token="test_token", + db_name="mydb", + collection_name="docs", + dimension=384, + metric_type="L2", + index_type="HNSW", + content_field="text", + embedding_field="vec", + primary_field="pk", + top_k=10, + output_fields=["text", "title"], + ) + assert settings.uri == "http://milvus:19530" + assert settings.token == "test_token" + assert settings.db_name == "mydb" + assert settings.dimension == 384 + assert settings.metric_type == "L2" + assert settings.index_type == "HNSW" + assert settings.content_field == "text" + assert settings.embedding_field == "vec" + assert settings.top_k == 10 + assert settings.output_fields == ["text", "title"] + + +def test_invalid_dimension(): + """Test that invalid dimension raises ValueError.""" + with pytest.raises(ValueError, match="Invalid dimension"): + MilvusVectorStoreSettings( + collection_name="test", + dimension=0, + ) + + +def test_invalid_dimension_negative(): + """Test that negative dimension raises ValueError.""" + with pytest.raises(ValueError, match="Invalid dimension"): + MilvusVectorStoreSettings( + collection_name="test", + dimension=-1, + ) + + +def test_invalid_top_k(): + """Test that invalid top_k raises ValueError.""" + with pytest.raises(ValueError, match="Invalid top_k"): + MilvusVectorStoreSettings( + collection_name="test", + top_k=0, + ) + + +def test_milvus_tool_settings_default(): + """Test MilvusToolSettings with default values.""" + settings = MilvusToolSettings() + assert settings.vector_store_settings is None + + +def test_milvus_tool_settings_with_vector_store(): + """Test MilvusToolSettings with vector store settings.""" + vs_settings = MilvusVectorStoreSettings(collection_name="test") + settings = MilvusToolSettings(vector_store_settings=vs_settings) + assert settings.vector_store_settings is not None + assert settings.vector_store_settings.collection_name == "test" From ea8d187774c4757964e2b1adc27aba1b50c2ac43 Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Mon, 9 Feb 2026 05:47:00 +0000 Subject: [PATCH 2/3] docs: improve similarity_search tool description for LLM Enhance the docstring of the similarity_search function to provide a richer tool description for the LLM. The previous description was too brief ("Search for similar content in Milvus vector store."). The new description explains when to use the tool, how to write effective queries, filter expression syntax, and return format. --- src/google/adk/tools/milvus/search_tool.py | 30 ++++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/google/adk/tools/milvus/search_tool.py b/src/google/adk/tools/milvus/search_tool.py index 489f3e9aa8..08c1d511b7 100644 --- a/src/google/adk/tools/milvus/search_tool.py +++ b/src/google/adk/tools/milvus/search_tool.py @@ -30,20 +30,34 @@ def similarity_search( vector_store: MilvusVectorStore, filter_expr: Optional[str] = None, ) -> str: - """Search for similar content in Milvus vector store. + # fmt: off + """Search the knowledge base for information relevant to the user's query. - This function is intended to be wrapped by MilvusToolset and exposed - to agents as a tool. + Use this tool when you need to look up facts, find relevant documents, + or answer questions that require knowledge from the vector database. + The tool performs semantic similarity search — it finds content whose + meaning is closest to the query, not just keyword matches. Args: - query: The search query text. - vector_store: The MilvusVectorStore instance (injected by MilvusToolset). - filter_expr: Optional Milvus filter expression for pre-filtering - results. + query (str): A natural-language search query describing what + information you are looking for. Be specific and descriptive + for better results. For example, use "How does Milvus handle + vector indexing?" rather than just "Milvus". + filter_expr (str): An optional filter expression to narrow down + search results before ranking by similarity. Uses Milvus + boolean expression syntax, for example: + ``category == "tech"`` or ``year > 2023``. Leave empty if no + filtering is needed. Returns: - Search results formatted as a JSON string. + str: A JSON-formatted string containing the search results. Each + result includes the matched content and a distance score + indicating similarity (lower distance means higher similarity + for L2/EUCLIDEAN, higher score means higher similarity for + COSINE/IP). Returns "No matching results found." if no + relevant content exists. """ + # fmt: on results = vector_store.search(query=query, filter_expr=filter_expr) if not results: From 780cbbc7ef1605eff3c8a6c0dceae9e67a8060d4 Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Tue, 10 Feb 2026 02:05:22 +0000 Subject: [PATCH 3/3] feat: Add MilvusMemoryService for cross-session memory Implement BaseMemoryService backed by Milvus vector database, enabling semantic search across past conversation history. - MilvusMemoryService: stores session events as vector-embedded text, with app_name/user_id scoping and deduplication - Lazy collection setup (auto-creates on first use) - 13 unit tests with mocked MilvusClient - E2E verified: cross-session recall and user isolation work with real Google GenAI embedding + Gemini LLM --- src/google/adk/features/_feature_registry.py | 4 + src/google/adk/memory/__init__.py | 11 + .../adk/memory/milvus_memory_service.py | 349 ++++++++++++++++ .../memory/test_milvus_memory_service.py | 389 ++++++++++++++++++ 4 files changed, 753 insertions(+) create mode 100644 src/google/adk/memory/milvus_memory_service.py create mode 100644 tests/unittests/memory/test_milvus_memory_service.py diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 5021e321eb..1341355501 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -38,6 +38,7 @@ class FeatureName(str, Enum): GOOGLE_CREDENTIALS_CONFIG = "GOOGLE_CREDENTIALS_CONFIG" GOOGLE_TOOL = "GOOGLE_TOOL" JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL" + MILVUS_MEMORY_SERVICE = "MILVUS_MEMORY_SERVICE" MILVUS_TOOLSET = "MILVUS_TOOLSET" MILVUS_VECTOR_STORE = "MILVUS_VECTOR_STORE" PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING" @@ -116,6 +117,9 @@ class FeatureConfig: FeatureName.JSON_SCHEMA_FOR_FUNC_DECL: FeatureConfig( FeatureStage.WIP, default_on=False ), + FeatureName.MILVUS_MEMORY_SERVICE: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.MILVUS_TOOLSET: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index c47fb8ec40..c78a3ced5c 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -35,3 +35,14 @@ ' VertexAiRagMemoryService please install it. If not, you can ignore this' ' warning.' ) + +try: + from .milvus_memory_service import MilvusMemoryService + + __all__.append('MilvusMemoryService') +except ImportError: + logger.debug( + 'pymilvus is not installed. If you want to use the' + ' MilvusMemoryService please install it with:' + ' pip install "google-adk[milvus]". If not, you can ignore this warning.' + ) diff --git a/src/google/adk/memory/milvus_memory_service.py b/src/google/adk/memory/milvus_memory_service.py new file mode 100644 index 0000000000..39cbbd8db3 --- /dev/null +++ b/src/google/adk/memory/milvus_memory_service.py @@ -0,0 +1,349 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Milvus-backed memory service for cross-session conversation memory.""" + +from __future__ import annotations + +import logging +from typing import Callable +from typing import Optional +from typing import TYPE_CHECKING + +from google.genai import types +from typing_extensions import override + +from ..features import experimental +from ..features import FeatureName +from . import _utils +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + +try: + from pymilvus import DataType + from pymilvus import MilvusClient +except ImportError: + MilvusClient = None + DataType = None + + +def _should_skip_event(content: types.Content) -> bool: + """Returns True if the event has no user-readable content.""" + if not content or not content.parts: + return True + for part in content.parts: + if part.text or part.inline_data or part.file_data: + return False + return True + + +def _extract_event_text(content: types.Content) -> str: + """Extracts and joins all text parts from event content.""" + if not content or not content.parts: + return "" + return " ".join(part.text for part in content.parts if part.text) + + +@experimental(FeatureName.MILVUS_MEMORY_SERVICE) +class MilvusMemoryService(BaseMemoryService): + """Memory service backed by Milvus vector database. + + Stores session events as vector-embedded text in a Milvus collection, + enabling semantic search across past conversations. + + Supports all Milvus deployment modes: + - Milvus Lite (local file path, e.g., ``"./memory.db"``) + - Milvus Server (e.g., ``"http://localhost:19530"``) + - Zilliz Cloud (e.g., ``"https://in01-xxx.cloud.zilliz.com"`` + token) + """ + + def __init__( + self, + *, + embedding_fn: Callable[[list[str]], list[list[float]]], + collection_name: str = "adk_memory", + uri: str = "http://localhost:19530", + token: Optional[str] = None, + db_name: str = "default", + dimension: int = 768, + metric_type: str = "COSINE", + top_k: int = 10, + ): + """Initializes the MilvusMemoryService. + + Args: + embedding_fn: A function that takes a list of texts and returns a + list of embedding vectors. + collection_name: The Milvus collection name for storing memories. + uri: The Milvus server URI or local file path. + token: Optional authentication token (e.g., for Zilliz Cloud). + db_name: The Milvus database name. + dimension: The dimension of the embedding vectors. + metric_type: The distance metric (COSINE, L2, or IP). + top_k: Default number of results for memory search. + + Raises: + ImportError: If pymilvus is not installed. + """ + if MilvusClient is None: + raise ImportError( + "pymilvus package not found. " + 'Please install with: pip install "google-adk[milvus]"' + ) + + self._embedding_fn = embedding_fn + self._collection_name = collection_name + self._dimension = dimension + self._metric_type = metric_type + self._top_k = top_k + self._collection_ready = False + + self._client = MilvusClient( + uri=uri, + token=token, + db_name=db_name, + ) + + def _ensure_collection(self) -> None: + """Creates the collection and index if they do not exist (idempotent).""" + if self._collection_ready: + return + + if self._client.has_collection(self._collection_name): + logger.info( + "Memory collection '%s' already exists.", + self._collection_name, + ) + self._collection_ready = True + return + + schema = self._client.create_schema( + auto_id=True, enable_dynamic_field=True + ) + schema.add_field( + field_name="id", + datatype=DataType.INT64, + is_primary=True, + ) + schema.add_field( + field_name="app_name", + datatype=DataType.VARCHAR, + max_length=256, + ) + schema.add_field( + field_name="user_id", + datatype=DataType.VARCHAR, + max_length=256, + ) + schema.add_field( + field_name="session_id", + datatype=DataType.VARCHAR, + max_length=256, + ) + schema.add_field( + field_name="author", + datatype=DataType.VARCHAR, + max_length=256, + ) + schema.add_field( + field_name="content", + datatype=DataType.VARCHAR, + max_length=65535, + ) + schema.add_field( + field_name="timestamp", + datatype=DataType.DOUBLE, + ) + schema.add_field( + field_name="embedding", + datatype=DataType.FLOAT_VECTOR, + dim=self._dimension, + ) + + index_params = self._client.prepare_index_params() + index_params.add_index( + field_name="embedding", + index_type="AUTOINDEX", + metric_type=self._metric_type, + ) + + self._client.create_collection( + collection_name=self._collection_name, + schema=schema, + index_params=index_params, + ) + + logger.info( + "Created memory collection '%s' (dim=%d, metric=%s).", + self._collection_name, + self._dimension, + self._metric_type, + ) + self._collection_ready = True + + @override + async def add_session_to_memory(self, session: Session) -> None: + """Adds session events to Milvus as vector-embedded memories. + + Filters out events with no user-readable content (e.g., function + calls). Deduplicates by checking existing events for the same + session before inserting. + + Args: + session: The session whose events will be stored. + """ + self._ensure_collection() + + # Collect events with meaningful text content. + events_to_store = [] + for event in session.events: + if _should_skip_event(event.content): + continue + text = _extract_event_text(event.content) + if not text.strip(): + continue + events_to_store.append((event, text)) + + if not events_to_store: + logger.info("No events to add to memory for session %s.", session.id) + return + + # Deduplicate: find which timestamps already exist for this session. + existing_timestamps = set() + try: + existing = self._client.query( + collection_name=self._collection_name, + filter=( + f'app_name == "{session.app_name}"' + f' and user_id == "{session.user_id}"' + f' and session_id == "{session.id}"' + ), + output_fields=["timestamp"], + ) + existing_timestamps = {row["timestamp"] for row in existing} + except Exception: + logger.debug( + "Could not query existing events for dedup, inserting all.", + exc_info=True, + ) + + new_events = [ + (event, text) + for event, text in events_to_store + if event.timestamp not in existing_timestamps + ] + + if not new_events: + logger.info( + "All events for session %s already in memory.", session.id + ) + return + + # Embed and insert. + texts = [text for _, text in new_events] + embeddings = self._embedding_fn(texts) + + data = [] + for (event, text), embedding in zip(new_events, embeddings): + data.append({ + "app_name": session.app_name, + "user_id": session.user_id, + "session_id": session.id, + "author": event.author or "", + "content": text, + "timestamp": event.timestamp, + "embedding": embedding, + }) + + self._client.insert( + collection_name=self._collection_name, + data=data, + ) + + logger.info( + "Added %d events from session %s to memory collection '%s'.", + len(data), + session.id, + self._collection_name, + ) + + @override + async def search_memory( + self, + *, + app_name: str, + user_id: str, + query: str, + ) -> SearchMemoryResponse: + """Searches memories by semantic similarity, scoped to the user. + + Args: + app_name: The application name. + user_id: The user ID. + query: The natural-language search query. + + Returns: + A SearchMemoryResponse with matching MemoryEntry objects. + """ + self._ensure_collection() + + query_embedding = self._embedding_fn([query])[0] + + filter_expr = ( + f'app_name == "{app_name}" and user_id == "{user_id}"' + ) + + results = self._client.search( + collection_name=self._collection_name, + data=[query_embedding], + limit=self._top_k, + output_fields=["content", "author", "timestamp"], + search_params={"metric_type": self._metric_type}, + filter=filter_expr, + ) + + if not results or not results[0]: + return SearchMemoryResponse() + + memories = [] + for hit in results[0]: + entity = hit["entity"] + author = entity.get("author", "") + content_text = entity.get("content", "") + timestamp = entity.get("timestamp") + + role = "user" if author == "user" else "model" + memory_entry = MemoryEntry( + content=types.Content( + parts=[types.Part(text=content_text)], + role=role, + ), + author=author, + timestamp=( + _utils.format_timestamp(timestamp) if timestamp else None + ), + ) + memories.append(memory_entry) + + return SearchMemoryResponse(memories=memories) + + def close(self) -> None: + """Closes the Milvus client connection.""" + self._client.close() diff --git a/tests/unittests/memory/test_milvus_memory_service.py b/tests/unittests/memory/test_milvus_memory_service.py new file mode 100644 index 0000000000..3fb84e8693 --- /dev/null +++ b/tests/unittests/memory/test_milvus_memory_service.py @@ -0,0 +1,389 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.events.event import Event +from google.adk.memory.milvus_memory_service import MilvusMemoryService +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +DIMENSION = 4 +APP_NAME = "test-app" +USER_ID = "test-user" + + +def _mock_embedding_fn(texts: list[str]) -> list[list[float]]: + """A mock embedding function returning fixed-dimension vectors.""" + return [[0.1] * DIMENSION for _ in texts] + + +def _make_session( + session_id: str = "session-1", + app_name: str = APP_NAME, + user_id: str = USER_ID, + events: list[Event] | None = None, +) -> Session: + if events is None: + events = [ + Event( + id="event-1", + invocation_id="inv-1", + author="user", + timestamp=1000.0, + content=types.Content( + parts=[types.Part(text="Hello, I like Python.")] + ), + ), + Event( + id="event-2", + invocation_id="inv-2", + author="model", + timestamp=1001.0, + content=types.Content( + parts=[types.Part(text="Python is a great language!")] + ), + ), + ] + return Session( + app_name=app_name, + user_id=user_id, + id=session_id, + last_update_time=2000, + events=events, + ) + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_init_success(mock_milvus_client_cls): + """Test successful initialization.""" + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + uri="http://localhost:19530", + collection_name="test_memory", + dimension=DIMENSION, + ) + mock_milvus_client_cls.assert_called_once_with( + uri="http://localhost:19530", + token=None, + db_name="default", + ) + assert service._embedding_fn is _mock_embedding_fn + assert service._collection_name == "test_memory" + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_ensure_collection_creates_when_missing(mock_milvus_client_cls): + """Test that _ensure_collection creates collection when it doesn't exist.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = False + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + service._ensure_collection() + + mock_client.has_collection.assert_called_once_with("adk_memory") + mock_client.create_collection.assert_called_once() + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_ensure_collection_skips_existing(mock_milvus_client_cls): + """Test that _ensure_collection skips if collection exists.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + service._ensure_collection() + + mock_client.has_collection.assert_called_once_with("adk_memory") + mock_client.create_collection.assert_not_called() + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_ensure_collection_idempotent(mock_milvus_client_cls): + """Test that _ensure_collection only checks once.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = False + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + service._ensure_collection() + service._ensure_collection() + + # has_collection should only be called once due to _collection_ready flag. + mock_client.has_collection.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_to_memory(mock_milvus_client_cls): + """Test adding a session with events to memory.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.query.return_value = [] # No existing events. + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + session = _make_session() + + await service.add_session_to_memory(session) + + mock_client.insert.assert_called_once() + call_args = mock_client.insert.call_args + data = call_args.kwargs["data"] + assert len(data) == 2 + assert data[0]["app_name"] == APP_NAME + assert data[0]["user_id"] == USER_ID + assert data[0]["session_id"] == "session-1" + assert data[0]["author"] == "user" + assert data[0]["content"] == "Hello, I like Python." + assert data[0]["timestamp"] == 1000.0 + assert len(data[0]["embedding"]) == DIMENSION + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_skips_empty_events(mock_milvus_client_cls): + """Test that events without text content are skipped.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.query.return_value = [] + + session = _make_session(events=[ + Event( + id="event-empty", + invocation_id="inv-1", + author="user", + timestamp=1000.0, + # No content at all. + ), + Event( + id="event-func", + invocation_id="inv-2", + author="model", + timestamp=1001.0, + content=types.Content( + parts=[types.Part(function_call=types.FunctionCall( + name="test_fn", args={} + ))] + ), + ), + Event( + id="event-text", + invocation_id="inv-3", + author="user", + timestamp=1002.0, + content=types.Content( + parts=[types.Part(text="Real message")] + ), + ), + ]) + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + + await service.add_session_to_memory(session) + + call_args = mock_client.insert.call_args + data = call_args.kwargs["data"] + assert len(data) == 1 + assert data[0]["content"] == "Real message" + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_deduplication(mock_milvus_client_cls): + """Test that events already in Milvus are not re-inserted.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + # Simulate one event already existing. + mock_client.query.return_value = [{"timestamp": 1000.0}] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + session = _make_session() + + await service.add_session_to_memory(session) + + call_args = mock_client.insert.call_args + data = call_args.kwargs["data"] + # Only the second event (timestamp=1001.0) should be inserted. + assert len(data) == 1 + assert data[0]["timestamp"] == 1001.0 + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_all_duplicated(mock_milvus_client_cls): + """Test that no insert happens when all events already exist.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.query.return_value = [ + {"timestamp": 1000.0}, + {"timestamp": 1001.0}, + ] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + session = _make_session() + + await service.add_session_to_memory(session) + + mock_client.insert.assert_not_called() + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_search_memory(mock_milvus_client_cls): + """Test similarity search returns MemoryEntry objects.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.search.return_value = [[ + { + "entity": { + "content": "I like Python.", + "author": "user", + "timestamp": 1000.0, + }, + "distance": 0.95, + }, + { + "entity": { + "content": "Python is great!", + "author": "model", + "timestamp": 1001.0, + }, + "distance": 0.90, + }, + ]] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + + response = await service.search_memory( + app_name=APP_NAME, + user_id=USER_ID, + query="Python", + ) + + assert len(response.memories) == 2 + assert response.memories[0].author == "user" + assert response.memories[0].content.parts[0].text == "I like Python." + assert response.memories[0].content.role == "user" + assert response.memories[1].author == "model" + assert response.memories[1].content.role == "model" + + # Verify search was called with correct filter. + call_args = mock_client.search.call_args + assert f'app_name == "{APP_NAME}"' in call_args.kwargs["filter"] + assert f'user_id == "{USER_ID}"' in call_args.kwargs["filter"] + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_search_memory_empty_results(mock_milvus_client_cls): + """Test search returning no results.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.search.return_value = [[]] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + + response = await service.search_memory( + app_name=APP_NAME, + user_id=USER_ID, + query="something", + ) + + assert len(response.memories) == 0 + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_search_memory_user_scoping(mock_milvus_client_cls): + """Test that search scopes by app_name and user_id.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + mock_client.search.return_value = [[]] + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + + await service.search_memory( + app_name="app-A", + user_id="user-B", + query="test", + ) + + call_args = mock_client.search.call_args + filter_expr = call_args.kwargs["filter"] + assert 'app_name == "app-A"' in filter_expr + assert 'user_id == "user-B"' in filter_expr + + +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +def test_close(mock_milvus_client_cls): + """Test closing the memory service.""" + mock_client = mock_milvus_client_cls.return_value + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + service.close() + + mock_client.close.assert_called_once() + + +@pytest.mark.asyncio +@mock.patch("google.adk.memory.milvus_memory_service.MilvusClient") +async def test_add_session_no_events(mock_milvus_client_cls): + """Test adding a session with no events does nothing.""" + mock_client = mock_milvus_client_cls.return_value + mock_client.has_collection.return_value = True + + service = MilvusMemoryService( + embedding_fn=_mock_embedding_fn, + dimension=DIMENSION, + ) + session = _make_session(events=[]) + + await service.add_session_to_memory(session) + + mock_client.insert.assert_not_called()