From f5788abd6071eea692a6a7ee5b25a3d1b9687203 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Wed, 5 Nov 2025 06:22:54 +0000 Subject: [PATCH 1/6] sdks/python: add milvus sink integration --- .../ml/rag/ingestion/milvus_search.py | 359 ++++++++++ .../ml/rag/ingestion/milvus_search_it_test.py | 642 ++++++++++++++++++ .../ml/rag/ingestion/milvus_search_test.py | 123 ++++ .../ml/rag/ingestion/postgres_common.py | 90 ++- 4 files changed, 1186 insertions(+), 28 deletions(-) create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py new file mode 100644 index 000000000000..e019a03d7514 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py @@ -0,0 +1,359 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional + +from pymilvus import MilvusClient +from pymilvus.exceptions import MilvusException + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.utils import MilvusHelpers +from apache_beam.ml.rag.utils import retry_with_backoff +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs +from apache_beam.transforms import DoFn + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class MilvusWriteConfig: + """Configuration parameters for writing data to Milvus collections. + + This class defines the parameters needed to write data to a Milvus collection, + including collection targeting, batching behavior, and operation timeouts. + + Args: + collection_name: Name of the target Milvus collection to write data to. + Must be a non-empty string. + partition_name: Name of the specific partition within the collection to + write to. If empty, writes to the default partition. + timeout: Maximum time in seconds to wait for write operations to complete. + If None, uses the client's default timeout. 
+ write_config: Configuration for write operations including batch size and + other write-specific settings. + kwargs: Additional keyword arguments for write operations. Enables forward + compatibility with future Milvus client parameters. + """ + collection_name: str + partition_name: str = "" + timeout: Optional[float] = None + write_config: WriteConfig = field(default_factory=WriteConfig) + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.collection_name: + raise ValueError("Collection name must be provided") + + @property + def write_batch_size(self): + """Returns the batch size for write operations. + + Returns: + The configured batch size, or DEFAULT_WRITE_BATCH_SIZE if not specified. + """ + return self.write_config.write_batch_size or DEFAULT_WRITE_BATCH_SIZE + + +@dataclass +class MilvusVectorWriterConfig(VectorDatabaseWriteConfig): + """Configuration for writing vector data to Milvus collections. + + This class extends VectorDatabaseWriteConfig to provide Milvus-specific + configuration for ingesting vector embeddings and associated metadata. + It defines how Apache Beam chunks are converted to Milvus records and + handles the write operation parameters. + + The configuration includes connection parameters, write settings, and + column specifications that determine how chunk data is mapped to Milvus + fields. + + Args: + connection_params: Configuration for connecting to the Milvus server, + including URI, credentials, and connection options. + write_config: Configuration for write operations including collection name, + partition, batch size, and timeouts. + column_specs: List of column specifications defining how chunk fields are + mapped to Milvus collection fields. Defaults to standard RAG fields + (id, embedding, sparse_embedding, content, metadata). + + Example: + config = MilvusVectorWriterConfig( + connection_params=MilvusConnectionParameters( + uri="http://localhost:19530"), + write_config=MilvusWriteConfig(collection_name="my_collection"), + column_specs=MilvusVectorWriterConfig.default_column_specs()) + """ + connection_params: MilvusConnectionParameters + write_config: MilvusWriteConfig + column_specs: List[ColumnSpec] = field( + default_factory=lambda: MilvusVectorWriterConfig.default_column_specs()) + + def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]: + """Creates a function to convert Apache Beam Chunks to Milvus records. + + Returns: + A function that takes a Chunk and returns a dictionary representing + a Milvus record with fields mapped according to column_specs. + """ + def convert(chunk: Chunk) -> Dict[str, Any]: + result = {} + for col in self.column_specs: + result[col.column_name] = col.value_fn(chunk) + return result + + return convert + + def create_write_transform(self) -> beam.PTransform: + """Creates the Apache Beam transform for writing to Milvus. + + Returns: + A PTransform that can be applied to a PCollection of Chunks to write + them to the configured Milvus collection. + """ + return _WriteToMilvusVectorDatabase(self) + + @staticmethod + def default_column_specs() -> List[ColumnSpec]: + """Returns default column specifications for RAG use cases. + + Creates column mappings for standard RAG fields: id, dense embedding, + sparse embedding, content text, and metadata. These specifications + define how Chunk fields are converted to Milvus-compatible formats. + + Returns: + List of ColumnSpec objects defining the default field mappings. 
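+
+    Example:
+      A chunk with id 1, text "hello", metadata {"source": "a"}, dense
+      embedding [0.1, 0.2] and sparse embedding ([5, 9], [0.3, 0.7]) is
+      converted to a record roughly of the form:
+      {"id": 1, "content": "hello", "metadata": {"source": "a"},
+      "embedding": [0.1, 0.2], "sparse_embedding": {5: 0.3, 9: 0.7}}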
+ """ + column_specs = ColumnSpecsBuilder() + return column_specs\ + .with_id_spec()\ + .with_embedding_spec(convert_fn=lambda values: list(values))\ + .with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)\ + .with_content_spec()\ + .with_metadata_spec(convert_fn=lambda values: dict(values))\ + .build() + + +class _WriteToMilvusVectorDatabase(beam.PTransform): + """Apache Beam PTransform for writing vector data to Milvus. + + This transform handles the conversion of Apache Beam Chunks to Milvus records + and coordinates the write operations. It applies the configured converter + function and uses a DoFn for batched writes to optimize performance. + + Args: + config: MilvusVectorWriterConfig containing all necessary parameters for + the write operation. + """ + def __init__(self, config: MilvusVectorWriterConfig): + self.config = config + + def expand(self, pcoll: beam.PCollection[Chunk]): + """Expands the PTransform to convert chunks and write to Milvus. + + Args: + pcoll: PCollection of Chunk objects to write to Milvus. + + Returns: + PCollection of the same Chunk objects after writing to Milvus. + """ + return ( + pcoll + | "Convert to Records" >> beam.Map(self.config.create_converter()) + | beam.ParDo( + _WriteMilvusFn( + self.config.connection_params, self.config.write_config))) + + +class _WriteMilvusFn(DoFn): + """DoFn that handles batched writes to Milvus. + + This DoFn accumulates records in batches and flushes them to Milvus when + the batch size is reached or when the bundle finishes. This approach + optimizes performance by reducing the number of individual write operations. + + Args: + connection_params: Configuration for connecting to the Milvus server. + write_config: Configuration for write operations including batch size + and collection details. + """ + def __init__( + self, + connection_params: MilvusConnectionParameters, + write_config: MilvusWriteConfig): + self._connection_params = connection_params + self._write_config = write_config + self.batch = [] + + def process(self, element, *args, **kwargs): + """Processes individual records, batching them for efficient writes. + + Args: + element: A dictionary representing a Milvus record to write. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Yields: + The original element after adding it to the batch. + """ + _ = args, kwargs # Unused parameters + self.batch.append(element) + if len(self.batch) >= self._write_config.write_batch_size: + self._flush() + yield element + + def finish_bundle(self): + """Called when a bundle finishes processing. + + Flushes any remaining records in the batch to ensure all data is written. + """ + self._flush() + + def _flush(self): + """Flushes the current batch of records to Milvus. + + Creates a MilvusSink connection and writes all batched records, + then clears the batch for the next set of records. + """ + if len(self.batch) == 0: + return + with _MilvusSink(self._connection_params, self._write_config) as sink: + sink.write(self.batch) + self.batch = [] + + def display_data(self): + """Returns display data for monitoring and debugging. + + Returns: + Dictionary containing database, collection, and batch size information + for display in the Apache Beam monitoring UI. 
+ """ + res = super().display_data() + res["database"] = self._connection_params.db_name + res["collection"] = self._write_config.collection_name + res["batch_size"] = self._write_config.write_batch_size + return res + + +class _MilvusSink: + """Low-level sink for writing data directly to Milvus. + + This class handles the direct interaction with the Milvus client for + upsert operations. It manages the connection lifecycle and provides + context manager support for proper resource cleanup. + + Args: + connection_params: Configuration for connecting to the Milvus server. + write_config: Configuration for write operations including collection + and partition targeting. + """ + def __init__( + self, + connection_params: MilvusConnectionParameters, + write_config: MilvusWriteConfig): + self._connection_params = connection_params + self._write_config = write_config + self._client = None + + def write(self, documents): + """Writes a batch of documents to the Milvus collection. + + Performs an upsert operation to insert new documents or update existing + ones based on primary key. After the upsert, flushes the collection to + ensure data persistence. + + Args: + documents: List of dictionaries representing Milvus records to write. + Each dictionary should contain fields matching the collection schema. + """ + if not self._client: + self._client = MilvusClient( + **unpack_dataclass_with_kwargs(self._connection_params)) + + try: + resp = self._client.upsert( + collection_name=self._write_config.collection_name, + partition_name=self._write_config.partition_name, + data=documents, + timeout=self._write_config.timeout, + **self._write_config.kwargs) + + # Try to flush, but handle connection issues gracefully. + try: + self._client.flush(self._write_config.collection_name) + except Exception as e: + # If flush fails due to connection issues, log but don't fail the write. + _LOGGER.warning( + "Flush operation failed, but upsert was successful: %s", e) + + _LOGGER.debug( + "Upserted into Milvus: upsert_count=%d, cost=%d", + resp.get("upsert_count", 0), + resp.get("cost", 0)) + except Exception as e: + _LOGGER.error("Failed to write to Milvus: %s", e) + raise + + def __enter__(self): + """Enters the context manager and establishes Milvus connection. + + Returns: + Self, enabling use in 'with' statements. + """ + if not self._client: + connection_params = unpack_dataclass_with_kwargs(self._connection_params) + + # Extract retry parameters from connection_params. + max_retries = connection_params.pop('max_retries', 3) + retry_delay = connection_params.pop('retry_delay', 1.0) + retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0) + + def create_client(): + return MilvusClient(**connection_params) + + self._client = retry_with_backoff( + create_client, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff_factor=retry_backoff_factor, + operation_name="Milvus connection", + exception_types=(MilvusException, )) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exits the context manager and closes the Milvus connection. + + Args: + exc_type: Exception type if an exception was raised. + exc_val: Exception value if an exception was raised. + exc_tb: Exception traceback if an exception was raised. 
+ """ + _ = exc_type, exc_val, exc_tb # Unused parameters + if self._client: + self._client.close() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py new file mode 100644 index 000000000000..2c966640dde1 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py @@ -0,0 +1,642 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import platform +import unittest +import uuid +from typing import Callable +from typing import cast + +import pytest +from pymilvus import CollectionSchema +from pymilvus import DataType +from pymilvus import FieldSchema +from pymilvus import MilvusClient +from pymilvus.exceptions import MilvusException +from pymilvus.milvus_client import IndexParams + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.test_utils import MilvusTestHelpers +from apache_beam.ml.rag.test_utils import VectorDBContainerInfo +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.utils import retry_with_backoff +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs +from apache_beam.testing.test_pipeline import TestPipeline + +try: + from apache_beam.ml.rag.ingestion.milvus_search import ( + MilvusWriteConfig, MilvusVectorWriterConfig) +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + + +def _construct_index_params(): + index_params = IndexParams() + + # Dense vector index for dense embeddings. + index_params.add_index( + field_name="embedding", + index_name="embedding_ivf_flat", + index_type="IVF_FLAT", + metric_type="COSINE", + params={"nlist": 1}) + + # Sparse vector index for sparse embeddings. 
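+  # IP (inner product) is the similarity metric used for sparse vectors here,
+  # and the TAAT_NAIVE algorithm keeps the index build simple for this tiny
+  # test corpus.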
+ index_params.add_index( + field_name="sparse_embedding", + index_name="sparse_embedding_inverted_index", + index_type="SPARSE_INVERTED_INDEX", + metric_type="IP", + params={"inverted_index_algo": "TAAT_NAIVE"}) + + return index_params + + +MILVUS_INGESTION_IT_CONFIG = { + "fields": [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=1000), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="sparse_embedding", dtype=DataType.SPARSE_FLOAT_VECTOR) + ], + "index": _construct_index_params, + "corpus": [ + Chunk( + id=1, # type: ignore[arg-type] + content=Content(text="Test document one"), + metadata={"source": "test1"}, + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + ), + Chunk( + id=2, # type: ignore[arg-type] + content=Content(text="Test document two"), + metadata={"source": "test2"}, + embedding=Embedding( + dense_embedding=[0.2, 0.3, 0.4], + sparse_embedding=([2, 3], [0.3, 0.1]), + ), + ), + Chunk( + id=3, # type: ignore[arg-type] + content=Content(text="Test document three"), + metadata={"source": "test3"}, + embedding=Embedding( + dense_embedding=[0.3, 0.4, 0.5], + sparse_embedding=([3, 4], [0.4, 0.2]), + ), + ) + ] +} + + +def create_collection_with_partition( + client: MilvusClient, + collection_name: str, + partition_name: str = '', + fields=None): + + if fields is None: + fields = MILVUS_INGESTION_IT_CONFIG["fields"] + + # Configure schema. + schema = CollectionSchema(fields=fields) + + # Configure index. + index_function: Callable[[], IndexParams] = cast( + Callable[[], IndexParams], MILVUS_INGESTION_IT_CONFIG["index"]) + + # Create collection with schema. + client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_function()) + + # Create partition within the collection. + client.create_partition( + collection_name=collection_name, partition_name=partition_name) + + msg = f"Expected collection '{collection_name}' to be created." + assert client.has_collection(collection_name), msg + + msg = f"Expected partition '{partition_name}' to be created." + assert client.has_partition(collection_name, partition_name), msg + + # Release the collection from memory. We don't need that on pure writing. + client.release_collection(collection_name) + + +def drop_collection(client: MilvusClient, collection_name: str): + try: + client.drop_collection(collection_name) + assert not client.has_collection(collection_name) + except Exception: + # Silently ignore connection errors during cleanup. + pass + + +@pytest.mark.require_docker_in_docker +@unittest.skipUnless( + platform.system() == "Linux", + "Test runs only on Linux due to lack of support, as yet, for nested " + "virtualization in CI environments on Windows/macOS. 
Many CI providers run " + "tests in virtualized environments, and nested virtualization " + "(Docker inside a VM) is either unavailable or has several issues on " + "non-Linux platforms.") +class TestMilvusVectorWriterConfig(unittest.TestCase): + """Integration tests for Milvus vector database ingestion functionality""" + + _db: VectorDBContainerInfo + + @classmethod + def setUpClass(cls): + cls._db = MilvusTestHelpers.start_db_container() + cls._connection_config = MilvusConnectionParameters( + uri=cls._db.uri, + user=cls._db.user, + password=cls._db.password, + db_name=cls._db.id, + token=cls._db.token) + + @classmethod + def tearDownClass(cls): + MilvusTestHelpers.stop_db_container(cls._db) + cls._db = None + + def setUp(self): + self.write_test_pipeline = TestPipeline() + self.write_test_pipeline.not_use_test_runner_api = True + self._collection_name = f"test_collection_{self._testMethodName}" + self._partition_name = f"test_partition_{self._testMethodName}" + config = unpack_dataclass_with_kwargs(self._connection_config) + config["alias"] = f"milvus_conn_{uuid.uuid4().hex[:8]}" + + # Use retry_with_backoff for test client connection. + def create_client(): + return MilvusClient(**config) + + self._test_client = retry_with_backoff( + create_client, + max_retries=3, + retry_delay=1.0, + operation_name="Test Milvus client connection", + exception_types=(MilvusException, )) + + create_collection_with_partition( + self._test_client, self._collection_name, self._partition_name) + + def tearDown(self): + drop_collection(self._test_client, self._collection_name) + self._test_client.close() + + def test_invalid_write_on_non_existent_collection(self): + non_existent_collection = "nonexistent_collection" + + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=non_existent_collection, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + ) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn("can't find collection", str(context.exception).lower()) + + def test_invalid_write_on_non_existent_partition(self): + non_existent_partition = "nonexistent_partition" + + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=non_existent_partition, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn("partition not found", str(context.exception).lower()) + + def test_invalid_write_on_missing_primary_key_in_entity(self): + test_chunks = [ + Chunk( + content=Content(text="Test content without ID"), + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + metadata={"source": "test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=1)) + + # Deliberately remove id primary key from the entity. 
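+    # The collection schema declares 'id' with auto_id=False, so an upsert
+    # without the primary key should be rejected by Milvus; the assertion
+    # below checks for exactly that error.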
+ specs = MilvusVectorWriterConfig.default_column_specs() + for i, spec in enumerate(specs): + if spec.column_name == "id": + del specs[i] + break + + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + column_specs=specs) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn( + "insert missed an field `id` to collection", + str(context.exception).lower()) + + def test_write_on_auto_id_primary_key(self): + auto_id_collection = f"auto_id_collection_{self._testMethodName}" + auto_id_partition = f"auto_id_partition_{self._testMethodName}" + auto_id_fields = [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=1000), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="sparse_embedding", dtype=DataType.SPARSE_FLOAT_VECTOR) + ] + + # Create collection with an auto id field. + create_collection_with_partition( + client=self._test_client, + collection_name=auto_id_collection, + partition_name=auto_id_partition, + fields=auto_id_fields) + + test_chunks = [ + Chunk( + id=1, + content=Content(text="Test content without ID"), + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + metadata={"source": "test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=auto_id_collection, + partition_name=auto_id_partition, + write_config=WriteConfig(write_batch_size=1)) + + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + self._test_client.flush(auto_id_collection) + self._test_client.load_collection(auto_id_collection) + result = self._test_client.query( + collection_name=auto_id_collection, + partition_names=[auto_id_partition], + limit=3) + + # Test there is only one item in the result and the ID is not equal to one. + self.assertEqual(len(result), len(test_chunks)) + result_item = dict(result[0]) + self.assertNotEqual(result_item["id"], 1) + + def test_write_on_existent_collection_with_default_schema(self): + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=3)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify data was written successfully. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), len(test_chunks)) + + # Verify each chunk was written correctly. 
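+    # Query results may come back in any order, so key them by primary id
+    # before comparing field by field.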
+ result_by_id = {item["id"]: item for item in result} + for chunk in test_chunks: + self.assertIn(chunk.id, result_by_id) + result_item = result_by_id[chunk.id] + self.assertEqual( + result_item["content"], + chunk.content.text + if hasattr(chunk.content, 'text') else chunk.content) + self.assertEqual(result_item["metadata"], chunk.metadata) + + # Verify embedding is present and has correct length. + expected_embedding = chunk.embedding.dense_embedding + actual_embedding = result_item["embedding"] + self.assertIsNotNone(actual_embedding) + self.assertEqual(len(actual_embedding), len(expected_embedding)) + + def test_write_with_custom_column_specifications(self): + from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec + from apache_beam.ml.rag.utils import MilvusHelpers + + custom_column_specs = [ + ColumnSpec("id", int, lambda chunk: int(chunk.id) if chunk.id else 0), + ColumnSpec( + "content", + str, lambda chunk: ( + chunk.content.text + if hasattr(chunk.content, 'text') else chunk.content)), + ColumnSpec("metadata", dict, lambda chunk: chunk.metadata or {}), + ColumnSpec( + "embedding", + list, lambda chunk: chunk.embedding.dense_embedding or []), + ColumnSpec( + "sparse_embedding", + dict, lambda chunk: ( + MilvusHelpers.sparse_embedding( + chunk.embedding.sparse_embedding) if chunk.embedding and + chunk.embedding.sparse_embedding else {})) + ] + + test_chunks = [ + Chunk( + id=10, + content=Content(text="Custom column spec test"), + embedding=Embedding( + dense_embedding=[0.8, 0.9, 1.0], + sparse_embedding=([1, 3, 5], [0.8, 0.9, 1.0])), + metadata={"custom": "spec_test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + column_specs=custom_column_specs) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify data was written successfully. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + filter="id == 10", + limit=1) + + self.assertEqual(len(result), 1) + result_item = result[0] + + # Verify custom column specs worked correctly. + self.assertEqual(result_item["id"], 10) + self.assertEqual(result_item["content"], "Custom column spec test") + self.assertEqual(result_item["metadata"], {"custom": "spec_test"}) + + # Verify embedding is present and has correct length. + expected_embedding = [0.8, 0.9, 1.0] + actual_embedding = result_item["embedding"] + self.assertIsNotNone(actual_embedding) + self.assertEqual(len(actual_embedding), len(expected_embedding)) + + # Verify sparse embedding was converted correctly - check keys are present. 
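+    # Only the sparse index keys are asserted; the stored values are not
+    # compared exactly, since they may not round-trip bit-for-bit through
+    # Milvus.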
+ expected_sparse_keys = {1, 3, 5} + actual_sparse = result_item["sparse_embedding"] + self.assertIsNotNone(actual_sparse) + self.assertEqual(set(actual_sparse.keys()), expected_sparse_keys) + + def test_write_with_batching(self): + test_chunks = [ + Chunk( + id=i, + content=Content(text=f"Batch test document {i}"), + embedding=Embedding( + dense_embedding=[0.1 * i, 0.2 * i, 0.3 * i], + sparse_embedding=([i, i + 1], [0.1 * i, 0.2 * i])), + metadata={"batch_id": i}) for i in range(1, 8) # 7 chunks + ] + + # Set small batch size to force batching (7 chunks with batch size 3). + batch_write_config = WriteConfig(write_batch_size=3) + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=batch_write_config) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify all data was written successfully. + # Flush to persist all data to disk, then load collection for querying. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), len(test_chunks)) + + # Verify each batch was written correctly. + result_by_id = {item["id"]: item for item in result} + for chunk in test_chunks: + self.assertIn(chunk.id, result_by_id) + result_item = result_by_id[chunk.id] + + # Verify content and metadata. + self.assertEqual(result_item["content"], chunk.content.text) + self.assertEqual(result_item["metadata"], chunk.metadata) + + # Verify embeddings are present and have correct length. + expected_embedding = chunk.embedding.dense_embedding + actual_embedding = result_item["embedding"] + self.assertIsNotNone(actual_embedding) + self.assertEqual(len(actual_embedding), len(expected_embedding)) + + # Verify sparse embedding keys are present. + expected_sparse_keys = {chunk.id, chunk.id + 1} + actual_sparse = result_item["sparse_embedding"] + self.assertIsNotNone(actual_sparse) + self.assertEqual(set(actual_sparse.keys()), expected_sparse_keys) + + def test_idempotent_write(self): + # Step 1: Insert initial data that doesn't exist. + initial_chunks = [ + Chunk( + id=100, + content=Content(text="Initial document"), + embedding=Embedding( + dense_embedding=[1.0, 2.0, 3.0], + sparse_embedding=([100, 101], [1.0, 2.0])), + metadata={"version": 1}), + Chunk( + id=200, + content=Content(text="Another initial document"), + embedding=Embedding( + dense_embedding=[2.0, 3.0, 4.0], + sparse_embedding=([200, 201], [2.0, 3.0])), + metadata={"version": 1}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=2)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + # Insert initial data. + with TestPipeline() as p: + p.not_use_test_runner_api = True + _ = ( + p | "Create initial" >> beam.Create(initial_chunks) + | "Write initial" >> config.create_write_transform()) + + # Verify initial data was inserted (not existed before). 
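+    # Flush to persist the upserted data, then load the collection so it can
+    # be queried.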
+ self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), 2) + result_by_id = {item["id"]: item for item in result} + + # Verify initial state. + self.assertEqual(result_by_id[100]["content"], "Initial document") + self.assertEqual(result_by_id[100]["metadata"], {"version": 1}) + self.assertEqual(result_by_id[200]["content"], "Another initial document") + self.assertEqual(result_by_id[200]["metadata"], {"version": 1}) + + # Step 2: Update existing data (same IDs, different content). + updated_chunks = [ + Chunk( + id=100, + content=Content(text="Updated document"), + embedding=Embedding( + dense_embedding=[1.1, 2.1, 3.1], + sparse_embedding=([100, 102], [1.1, 2.1])), + metadata={"version": 2}), + Chunk( + id=200, + content=Content(text="Another updated document"), + embedding=Embedding( + dense_embedding=[2.1, 3.1, 4.1], + sparse_embedding=([200, 202], [2.1, 3.1])), + metadata={"version": 2}) + ] + + # Perform first update. + with TestPipeline() as p: + p.not_use_test_runner_api = True + _ = ( + p | "Create update1" >> beam.Create(updated_chunks) + | "Write update1" >> config.create_write_transform()) + + # Verify update worked. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), 2) # Still only 2 records. + result_by_id = {item["id"]: item for item in result} + + # Verify updated state. + self.assertEqual(result_by_id[100]["content"], "Updated document") + self.assertEqual(result_by_id[100]["metadata"], {"version": 2}) + self.assertEqual(result_by_id[200]["content"], "Another updated document") + self.assertEqual(result_by_id[200]["metadata"], {"version": 2}) + + # Step 3: Repeat the same update operation 3 more times (idempotence test). + for i in range(3): + with TestPipeline() as p: + p.not_use_test_runner_api = True + _ = ( + p | f"Create repeat{i+2}" >> beam.Create(updated_chunks) + | f"Write repeat{i+2}" >> config.create_write_transform()) + + # Verify state hasn't changed after repeated updates. + self._test_client.flush(self._collection_name) + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + # Still only 2 records. + self.assertEqual(len(result), 2) + result_by_id = {item["id"]: item for item in result} + + # Final state should remain unchanged. + self.assertEqual(result_by_id[100]["content"], "Updated document") + self.assertEqual(result_by_id[100]["metadata"], {"version": 2}) + self.assertEqual(result_by_id[200]["content"], "Another updated document") + self.assertEqual(result_by_id[200]["metadata"], {"version": 2}) + + # Verify embeddings are still correct. 
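+    # Presence and dimensionality are checked rather than exact float values.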
+ self.assertIsNotNone(result_by_id[100]["embedding"]) + self.assertEqual(len(result_by_id[100]["embedding"]), 3) + self.assertIsNotNone(result_by_id[200]["embedding"]) + self.assertEqual(len(result_by_id[200]["embedding"]), 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py new file mode 100644 index 000000000000..ea80f2a8afcb --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from parameterized import parameterized + +try: + from apache_beam.ml.rag.ingestion.milvus_search import ( + MilvusWriteConfig, MilvusVectorWriterConfig) + from apache_beam.ml.rag.utils import MilvusConnectionParameters +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + + +class TestMilvusWriteConfig(unittest.TestCase): + """Unit tests for MilvusWriteConfig validation errors.""" + def test_empty_collection_name_raises_error(self): + """Test that empty collection name raises ValueError.""" + with self.assertRaises(ValueError) as context: + MilvusWriteConfig(collection_name="") + + self.assertIn("Collection name must be provided", str(context.exception)) + + def test_none_collection_name_raises_error(self): + """Test that None collection name raises ValueError.""" + with self.assertRaises(ValueError) as context: + MilvusWriteConfig(collection_name=None) + + self.assertIn("Collection name must be provided", str(context.exception)) + + +class TestMilvusVectorWriterConfig(unittest.TestCase): + """Unit tests for MilvusVectorWriterConfig validation and functionality.""" + def test_valid_config_creation(self): + """Test creation of valid MilvusVectorWriterConfig.""" + connection_params = MilvusConnectionParameters(uri="http://localhost:19530") + write_config = MilvusWriteConfig(collection_name="test_collection") + + config = MilvusVectorWriterConfig( + connection_params=connection_params, write_config=write_config) + + self.assertEqual(config.connection_params, connection_params) + self.assertEqual(config.write_config, write_config) + self.assertIsNotNone(config.column_specs) + + def test_create_converter_returns_callable(self): + """Test that create_converter returns a callable function.""" + connection_params = MilvusConnectionParameters(uri="http://localhost:19530") + write_config = MilvusWriteConfig(collection_name="test_collection") + + config = MilvusVectorWriterConfig( + connection_params=connection_params, write_config=write_config) + + converter = config.create_converter() + self.assertTrue(callable(converter)) + + def test_create_write_transform_returns_ptransform(self): + 
"""Test that create_write_transform returns a PTransform.""" + connection_params = MilvusConnectionParameters(uri="http://localhost:19530") + write_config = MilvusWriteConfig(collection_name="test_collection") + + config = MilvusVectorWriterConfig( + connection_params=connection_params, write_config=write_config) + + transform = config.create_write_transform() + self.assertIsNotNone(transform) + + def test_default_column_specs_has_expected_fields(self): + """Test that default column specs include expected fields.""" + column_specs = MilvusVectorWriterConfig.default_column_specs() + + self.assertIsInstance(column_specs, list) + self.assertGreater(len(column_specs), 0) + + column_names = [spec.column_name for spec in column_specs] + expected_fields = [ + "id", "embedding", "sparse_embedding", "content", "metadata" + ] + + for field in expected_fields: + self.assertIn(field, column_names) + + @parameterized.expand([ + # Invalid connection parameters - empty URI. + ( + lambda: ( + MilvusConnectionParameters(uri=""), MilvusWriteConfig( + collection_name="test_collection")), + "URI must be provided"), + # Invalid write config - empty collection name. + ( + lambda: ( + MilvusConnectionParameters(uri="http://localhost:19530"), + MilvusWriteConfig(collection_name="")), + "Collection name must be provided"), + ]) + def test_invalid_configuration_parameters( + self, create_params, expected_error_msg): + """Test validation errors for invalid configuration parameters.""" + with self.assertRaises(ValueError) as context: + connection_params, write_config = create_params() + MilvusVectorWriterConfig( + connection_params=connection_params, write_config=write_config) + + self.assertIn(expected_error_msg, str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py index eca740a4e9c3..cecebbb4455b 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py @@ -22,6 +22,7 @@ from typing import List from typing import Literal from typing import Optional +from typing import Tuple from typing import Type from typing import Union @@ -30,16 +31,16 @@ def chunk_embedding_fn(chunk: Chunk) -> str: """Convert embedding to PostgreSQL array string. - + Formats dense embedding as a PostgreSQL-compatible array string. Example: [1.0, 2.0] -> '{1.0,2.0}' - + Args: chunk: Input Chunk object. - + Returns: str: PostgreSQL array string representation of the embedding. - + Raises: ValueError: If chunk has no dense embedding. """ @@ -51,7 +52,7 @@ def chunk_embedding_fn(chunk: Chunk) -> str: @dataclass class ColumnSpec: """Specification for mapping Chunk fields to SQL columns for insertion. - + Defines how to extract and format values from Chunks into database columns, handling the full pipeline from Python value to SQL insertion. @@ -71,7 +72,7 @@ class ColumnSpec: Common examples: - "::float[]" for vector arrays - "::jsonb" for JSON data - + Examples: Basic text column (uses standard JDBC type mapping): >>> ColumnSpec.text( @@ -83,7 +84,7 @@ class ColumnSpec: Vector column with explicit array casting: >>> ColumnSpec.vector( ... column_name="embedding", - ... value_fn=lambda chunk: '{' + + ... value_fn=lambda chunk: '{' + ... ','.join(map(str, chunk.embedding.dense_embedding)) + '}' ... 
) # Results in: INSERT INTO table (embedding) VALUES (?::float[]) @@ -168,17 +169,17 @@ def with_id_spec( convert_fn: Optional[Callable[[str], Any]] = None, sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': """Add ID :class:`.ColumnSpec` with optional type and conversion. - + Args: column_name: Name for the ID column (defaults to "id") python_type: Python type for the column (defaults to str) convert_fn: Optional function to convert the chunk ID If None, uses ID as-is sql_typecast: Optional SQL type cast - + Returns: Self for method chaining - + Example: >>> builder.with_id_spec( ... column_name="doc_id", @@ -205,17 +206,17 @@ def with_content_spec( convert_fn: Optional[Callable[[str], Any]] = None, sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': """Add content :class:`.ColumnSpec` with optional type and conversion. - + Args: column_name: Name for the content column (defaults to "content") python_type: Python type for the column (defaults to str) convert_fn: Optional function to convert the content text If None, uses content text as-is sql_typecast: Optional SQL type cast - + Returns: Self for method chaining - + Example: >>> builder.with_content_spec( ... column_name="content_length", @@ -244,17 +245,17 @@ def with_metadata_spec( convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None, sql_typecast: Optional[str] = "::jsonb") -> 'ColumnSpecsBuilder': """Add metadata :class:`.ColumnSpec` with optional type and conversion. - + Args: column_name: Name for the metadata column (defaults to "metadata") python_type: Python type for the column (defaults to str) convert_fn: Optional function to convert the metadata dictionary If None and python_type is str, converts to JSON string sql_typecast: Optional SQL type cast (defaults to "::jsonb") - + Returns: Self for method chaining - + Example: >>> builder.with_metadata_spec( ... column_name="meta_tags", @@ -283,19 +284,19 @@ def with_embedding_spec( convert_fn: Optional[Callable[[List[float]], Any]] = None ) -> 'ColumnSpecsBuilder': """Add embedding :class:`.ColumnSpec` with optional conversion. - + Args: column_name: Name for the embedding column (defaults to "embedding") convert_fn: Optional function to convert the dense embedding values If None, uses default PostgreSQL array format - + Returns: Self for method chaining - + Example: >>> builder.with_embedding_spec( ... column_name="embedding_vector", - ... convert_fn=lambda values: '{' + ','.join(f"{x:.4f}" + ... convert_fn=lambda values: '{' + ','.join(f"{x:.4f}" ... for x in values) + '}' ... ) """ @@ -311,6 +312,39 @@ def value_fn(chunk: Chunk) -> Any: ColumnSpec.vector(column_name=column_name, value_fn=value_fn)) return self + def with_sparse_embedding_spec( + self, + column_name: str = "sparse_embedding", + conv_fn: Optional[Callable[[Tuple[List[int], List[float]]], Any]] = None + ) -> 'ColumnSpecsBuilder': + """Add sparse embedding :class:`.ColumnSpec` with optional conversion. + Args: + column_name: Name for the sparse embedding column + (defaults to "sparse_embedding") + conv_fn: Optional function to convert the sparse embedding tuple + If None, converts to PostgreSQL-compatible JSON format + Returns: + Self for method chaining + Example: + >>> builder.with_sparse_embedding_spec( + ... column_name="sparse_vector", + ... convert_fn=lambda sparse: dict(zip(sparse[0], sparse[1])) + ... 
) + """ + def value_fn(chunk: Chunk) -> Any: + if chunk.embedding is None or chunk.embedding.sparse_embedding is None: + raise ValueError(f'Expected chunk to contain sparse embedding. {chunk}') + sparse_embedding = chunk.embedding.sparse_embedding + if conv_fn: + return conv_fn(sparse_embedding) + # Default: convert to dict format for JSON storage. + indices, values = sparse_embedding + return json.dumps(dict(zip(indices, values))) + + self._specs.append( + ColumnSpec.jsonb(column_name=column_name, value_fn=value_fn)) + return self + def add_metadata_field( self, field: str, @@ -330,7 +364,7 @@ def add_metadata_field( desired type. If None, value is used as-is default: Default value if field is missing from metadata sql_typecast: Optional SQL type cast (e.g. "::timestamp") - + Returns: Self for chaining @@ -385,17 +419,17 @@ def value_fn(chunk: Chunk) -> Any: def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder': """Add a custom :class:`.ColumnSpec` to the builder. - + Use this method when you need complete control over the :class:`.ColumnSpec` , including custom value extraction and type handling. - + Args: spec: A :class:`.ColumnSpec` instance defining the column name, type, value extraction, and optional SQL type casting. - + Returns: Self for method chaining - + Examples: Custom text column from chunk metadata: @@ -430,12 +464,12 @@ class ConflictResolution: IGNORE: Skips conflicting records. update_fields: Optional list of fields to update on conflict. If None, all non-conflict fields are updated. - + Examples: Simple primary key: >>> ConflictResolution("id") - + Composite key with specific update fields: >>> ConflictResolution( @@ -443,7 +477,7 @@ class ConflictResolution: ... action="UPDATE", ... update_fields=["embedding", "content"] ... ) - + Ignore conflicts: >>> ConflictResolution( From 39925859e5b5eeb4d5b08f311bd40e7f8dd4a430 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Wed, 5 Nov 2025 06:23:09 +0000 Subject: [PATCH 2/6] CHANGES.md: update release notes --- CHANGES.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 5b365e15fdb4..fccf59704f45 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -75,6 +75,9 @@ * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Python examples added for Milvus search enrichment handler on [Beam Website](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment-milvus/) including jupyter notebook example (Python) ([#36176](https://github.com/apache/beam/issues/36176)). +* Milvus sink I/O connector added (Python) ([#36702]( + https://github.com/apache/beam/issues/36702)). Now Beam has full support for + Milvus integration including Milvus enrichment and sink operations. 
## Breaking Changes From 43c5da20fc1ff9463515be1467cb64f97f2b1d69 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Thu, 13 Nov 2025 14:03:50 +0000 Subject: [PATCH 3/6] sdks/python: fix py docs formatting issues --- sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py index cecebbb4455b..93968564f156 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py @@ -318,13 +318,16 @@ def with_sparse_embedding_spec( conv_fn: Optional[Callable[[Tuple[List[int], List[float]]], Any]] = None ) -> 'ColumnSpecsBuilder': """Add sparse embedding :class:`.ColumnSpec` with optional conversion. + Args: column_name: Name for the sparse embedding column (defaults to "sparse_embedding") conv_fn: Optional function to convert the sparse embedding tuple If None, converts to PostgreSQL-compatible JSON format + Returns: Self for method chaining + Example: >>> builder.with_sparse_embedding_spec( ... column_name="sparse_vector", From 46a03c8382637225d55b2f1496bd18e91efb320c Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Thu, 13 Nov 2025 16:43:00 +0000 Subject: [PATCH 4/6] sdks/python: fix linting issues --- .../apache_beam/ml/rag/ingestion/milvus_search_it_test.py | 4 ++-- .../python/apache_beam/ml/rag/ingestion/milvus_search_test.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py index 2c966640dde1..724782cc25ca 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py @@ -42,8 +42,8 @@ from apache_beam.testing.test_pipeline import TestPipeline try: - from apache_beam.ml.rag.ingestion.milvus_search import ( - MilvusWriteConfig, MilvusVectorWriterConfig) + from apache_beam.ml.rag.ingestion.milvus_search import MilvusVectorWriterConfig + from apache_beam.ml.rag.ingestion.milvus_search import MilvusWriteConfig except ImportError as e: raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py index ea80f2a8afcb..80d55ac9382c 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py @@ -19,8 +19,8 @@ from parameterized import parameterized try: - from apache_beam.ml.rag.ingestion.milvus_search import ( - MilvusWriteConfig, MilvusVectorWriterConfig) + from apache_beam.ml.rag.ingestion.milvus_search import MilvusVectorWriterConfig + from apache_beam.ml.rag.ingestion.milvus_search import MilvusWriteConfig from apache_beam.ml.rag.utils import MilvusConnectionParameters except ImportError as e: raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') From 8ac2c425c3acdbc7a6e0ddd37c17a7d47caa816d Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 21 Nov 2025 12:15:11 +0000 Subject: [PATCH 5/6] sdks/python: delegate auto-flushing to milvus backend --- .../ml/rag/ingestion/milvus_search.py | 43 +++++++------------ 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py 
b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py index e019a03d7514..c73aba5f42e4 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py @@ -181,7 +181,7 @@ def expand(self, pcoll: beam.PCollection[Chunk]): pcoll: PCollection of Chunk objects to write to Milvus. Returns: - PCollection of the same Chunk objects after writing to Milvus. + PCollection of dictionaries representing the records written to Milvus. """ return ( pcoll @@ -292,33 +292,20 @@ def write(self, documents): documents: List of dictionaries representing Milvus records to write. Each dictionary should contain fields matching the collection schema. """ - if not self._client: - self._client = MilvusClient( - **unpack_dataclass_with_kwargs(self._connection_params)) - - try: - resp = self._client.upsert( - collection_name=self._write_config.collection_name, - partition_name=self._write_config.partition_name, - data=documents, - timeout=self._write_config.timeout, - **self._write_config.kwargs) - - # Try to flush, but handle connection issues gracefully. - try: - self._client.flush(self._write_config.collection_name) - except Exception as e: - # If flush fails due to connection issues, log but don't fail the write. - _LOGGER.warning( - "Flush operation failed, but upsert was successful: %s", e) - - _LOGGER.debug( - "Upserted into Milvus: upsert_count=%d, cost=%d", - resp.get("upsert_count", 0), - resp.get("cost", 0)) - except Exception as e: - _LOGGER.error("Failed to write to Milvus: %s", e) - raise + self._client = MilvusClient( + **unpack_dataclass_with_kwargs(self._connection_params)) + + resp = self._client.upsert( + collection_name=self._write_config.collection_name, + partition_name=self._write_config.partition_name, + data=documents, + timeout=self._write_config.timeout, + **self._write_config.kwargs) + + _LOGGER.debug( + "Upserted into Milvus: upsert_count=%d, cost=%d", + resp.get("upsert_count", 0), + resp.get("cost", 0)) def __enter__(self): """Enters the context manager and establishes Milvus connection. From d0b68deda927052c3104ac46bc17b9df981f7b40 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 21 Nov 2025 12:15:50 +0000 Subject: [PATCH 6/6] sdks/python: address gemini comments --- CHANGES.md | 5 ++--- .../ml/rag/ingestion/milvus_search_it_test.py | 11 ++--------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 281e46c71898..e6f9cf13ff91 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -75,9 +75,8 @@ * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Python examples added for Milvus search enrichment handler on [Beam Website](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment-milvus/) including jupyter notebook example (Python) ([#36176](https://github.com/apache/beam/issues/36176)). -* Milvus sink I/O connector added (Python) ([#36702]( - https://github.com/apache/beam/issues/36702)). Now Beam has full support for - Milvus integration including Milvus enrichment and sink operations. +* Milvus sink I/O connector added (Python) ([#36702](https://github.com/apache/beam/issues/36702)). +Now Beam has full support for Milvus integration including Milvus enrichment and sink operations. 
## Breaking Changes diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py index 724782cc25ca..38b497e8fa71 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py @@ -370,10 +370,7 @@ def test_write_on_existent_collection_with_default_schema(self): for chunk in test_chunks: self.assertIn(chunk.id, result_by_id) result_item = result_by_id[chunk.id] - self.assertEqual( - result_item["content"], - chunk.content.text - if hasattr(chunk.content, 'text') else chunk.content) + self.assertEqual(result_item["content"], chunk.content.text) self.assertEqual(result_item["metadata"], chunk.metadata) # Verify embedding is present and has correct length. @@ -388,11 +385,7 @@ def test_write_with_custom_column_specifications(self): custom_column_specs = [ ColumnSpec("id", int, lambda chunk: int(chunk.id) if chunk.id else 0), - ColumnSpec( - "content", - str, lambda chunk: ( - chunk.content.text - if hasattr(chunk.content, 'text') else chunk.content)), + ColumnSpec("content", str, lambda chunk: chunk.content.text), ColumnSpec("metadata", dict, lambda chunk: chunk.metadata or {}), ColumnSpec( "embedding",