diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
index 99a8fc8ff6d5..0b27d14486b8 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json
@@ -1,4 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run",
+  "pr": "36654",
   "modification": 14
 }
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py
deleted file mode 100644
index f79db470bca4..000000000000
--- a/sdks/python/apache_beam/ml/rag/ingestion/spanner.py
+++ /dev/null
@@ -1,646 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Cloud Spanner vector store writer for RAG pipelines.
-
-This module provides a writer for storing embeddings and associated metadata
-in Google Cloud Spanner. It supports flexible schema configuration with the
-ability to flatten metadata fields into dedicated columns.
-
-Example usage:
-
-  Default schema (id, embedding, content, metadata):
-    >>> config = SpannerVectorWriterConfig(
-    ...     project_id="my-project",
-    ...     instance_id="my-instance",
-    ...     database_id="my-db",
-    ...     table_name="embeddings"
-    ... )
-
-  Flattened metadata fields:
-    >>> specs = (
-    ...     SpannerColumnSpecsBuilder()
-    ...     .with_id_spec()
-    ...     .with_embedding_spec()
-    ...     .with_content_spec()
-    ...     .add_metadata_field("source", str)
-    ...     .add_metadata_field("page_number", int, default=0)
-    ...     .with_metadata_spec()
-    ...     .build()
-    ... )
-    >>> config = SpannerVectorWriterConfig(
-    ...     project_id="my-project",
-    ...     instance_id="my-instance",
-    ...     database_id="my-db",
-    ...     table_name="embeddings",
-    ...     column_specs=specs
-    ... )
-
-Spanner schema example:
-
-    CREATE TABLE embeddings (
-      id STRING(1024) NOT NULL,
-      embedding ARRAY<FLOAT64>(vector_length=>768),
-      content STRING(MAX),
-      source STRING(MAX),
-      page_number INT64,
-      metadata JSON
-    ) PRIMARY KEY (id)
-"""
-
-import functools
-import json
-from dataclasses import dataclass
-from typing import Any
-from typing import Callable
-from typing import List
-from typing import Literal
-from typing import NamedTuple
-from typing import Optional
-from typing import Type
-
-import apache_beam as beam
-from apache_beam.coders import registry
-from apache_beam.coders.row_coder import RowCoder
-from apache_beam.io.gcp import spanner
-from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
-from apache_beam.ml.rag.types import Chunk
-
-
-@dataclass
-class SpannerColumnSpec:
-  """Column specification for Spanner vector writes.
-
-  Defines how to extract and format values from Chunks for insertion into
-  Spanner table columns. Each spec maps to one column in the target table.
-
-  Attributes:
-    column_name: Name of the Spanner table column
-    python_type: Python type for the NamedTuple field (required for RowCoder)
-    value_fn: Function to extract value from a Chunk
-
-  Examples:
-    String column:
-      >>> SpannerColumnSpec(
-      ...     column_name="id",
-      ...     python_type=str,
-      ...     value_fn=lambda chunk: chunk.id
-      ... )
-
-    Array column with conversion:
-      >>> SpannerColumnSpec(
-      ...     column_name="embedding",
-      ...     python_type=List[float],
-      ...     value_fn=lambda chunk: chunk.embedding.dense_embedding
-      ... )
-  """
-  column_name: str
-  python_type: Type
-  value_fn: Callable[[Chunk], Any]
-
-
-def _extract_and_convert(extract_fn, convert_fn, chunk):
-  if convert_fn:
-    return convert_fn(extract_fn(chunk))
-  return extract_fn(chunk)
-
-
-class SpannerColumnSpecsBuilder:
-  """Builder for creating Spanner column specifications.
-
-  Provides a fluent API for defining table schemas and how to populate them
-  from Chunk objects. Supports standard Chunk fields (id, embedding, content,
-  metadata) and flattening metadata fields into dedicated columns.
-
-  Example:
-    >>> specs = (
-    ...     SpannerColumnSpecsBuilder()
-    ...     .with_id_spec()
-    ...     .with_embedding_spec()
-    ...     .with_content_spec()
-    ...     .add_metadata_field("source", str)
-    ...     .with_metadata_spec()
-    ...     .build()
-    ... )
-  """
-  def __init__(self):
-    self._specs: List[SpannerColumnSpec] = []
-
-  @staticmethod
-  def with_defaults() -> 'SpannerColumnSpecsBuilder':
-    """Create builder with default schema.
-
-    Default schema includes:
-    - id (STRING): Chunk ID
-    - embedding (ARRAY<FLOAT64>): Dense embedding vector
-    - content (STRING): Chunk content text
-    - metadata (JSON): Full metadata as JSON
-
-    Returns:
-      Builder with default column specifications
-    """
-    return (
-        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
-        with_content_spec().with_metadata_spec())
-
-  def with_id_spec(
-      self,
-      column_name: str = "id",
-      python_type: Type = str,
-      convert_fn: Optional[Callable[[str], Any]] = None
-  ) -> 'SpannerColumnSpecsBuilder':
-    """Add ID column specification.
-
-    Args:
-      column_name: Column name (default: "id")
-      python_type: Python type (default: str)
-      convert_fn: Optional converter (e.g., to cast to int)
-
-    Returns:
-      Self for method chaining
-
-    Examples:
-      Default string ID:
-        >>> builder.with_id_spec()
-
-      Integer ID with conversion:
-        >>> builder.with_id_spec(
-        ...     python_type=int,
-        ...     convert_fn=lambda id: int(id.split('_')[1])
-        ... )
-    """
-
-    self._specs.append(
-        SpannerColumnSpec(
-            column_name=column_name,
-            python_type=python_type,
-            value_fn=functools.partial(
-                _extract_and_convert, lambda chunk: chunk.id, convert_fn)))
-    return self
-
-  def with_embedding_spec(
-      self,
-      column_name: str = "embedding",
-      convert_fn: Optional[Callable[[List[float]], List[float]]] = None
-  ) -> 'SpannerColumnSpecsBuilder':
-    """Add embedding array column (ARRAY<FLOAT64> or ARRAY<FLOAT32>).
-
-    Args:
-      column_name: Column name (default: "embedding")
-      convert_fn: Optional converter (e.g., normalize, quantize)
-
-    Returns:
-      Self for method chaining
-
-    Examples:
-      Default embedding:
-        >>> builder.with_embedding_spec()
-
-      Normalized embedding:
-        >>> def normalize(vec):
-        ...   norm = (sum(x**2 for x in vec) ** 0.5) or 1.0
-        ...   return [x/norm for x in vec]
-        >>> builder.with_embedding_spec(convert_fn=normalize)
-
-      Rounded precision:
-        >>> builder.with_embedding_spec(
-        ...     convert_fn=lambda vec: [round(x, 4) for x in vec]
-        ...
) - """ - def extract_fn(chunk: Chunk) -> List[float]: - if chunk.embedding is None or chunk.embedding.dense_embedding is None: - raise ValueError(f'Chunk must contain embedding: {chunk}') - return chunk.embedding.dense_embedding - - self._specs.append( - SpannerColumnSpec( - column_name=column_name, - python_type=List[float], - value_fn=functools.partial( - _extract_and_convert, extract_fn, convert_fn))) - return self - - def with_content_spec( - self, - column_name: str = "content", - python_type: Type = str, - convert_fn: Optional[Callable[[str], Any]] = None - ) -> 'SpannerColumnSpecsBuilder': - """Add content column. - - Args: - column_name: Column name (default: "content") - python_type: Python type (default: str) - convert_fn: Optional converter - - Returns: - Self for method chaining - - Examples: - Default text content: - >>> builder.with_content_spec() - - Content length as integer: - >>> builder.with_content_spec( - ... column_name="content_length", - ... python_type=int, - ... convert_fn=lambda text: len(text.split()) - ... ) - - Truncated content: - >>> builder.with_content_spec( - ... convert_fn=lambda text: text[:1000] - ... ) - """ - def extract_fn(chunk: Chunk) -> str: - if chunk.content.text is None: - raise ValueError(f'Chunk must contain content: {chunk}') - return chunk.content.text - - self._specs.append( - SpannerColumnSpec( - column_name=column_name, - python_type=python_type, - value_fn=functools.partial( - _extract_and_convert, extract_fn, convert_fn))) - return self - - def with_metadata_spec( - self, column_name: str = "metadata") -> 'SpannerColumnSpecsBuilder': - """Add metadata JSON column. - - Stores the full metadata dictionary as a JSON string in Spanner. - - Args: - column_name: Column name (default: "metadata") - - Returns: - Self for method chaining - - Note: - Metadata is automatically converted to JSON string using json.dumps() - """ - value_fn = lambda chunk: json.dumps(chunk.metadata) - self._specs.append( - SpannerColumnSpec( - column_name=column_name, python_type=str, value_fn=value_fn)) - return self - - def add_metadata_field( - self, - field: str, - python_type: Type, - column_name: Optional[str] = None, - convert_fn: Optional[Callable[[Any], Any]] = None, - default: Any = None) -> 'SpannerColumnSpecsBuilder': - """Flatten a metadata field into its own column. - - Extracts a specific field from chunk.metadata and stores it in a - dedicated table column. - - Args: - field: Key in chunk.metadata to extract - python_type: Python type (must be explicitly specified) - column_name: Column name (default: same as field) - convert_fn: Optional converter for type casting/transformation - default: Default value if field is missing from metadata - - Returns: - Self for method chaining - - Examples: - String field: - >>> builder.add_metadata_field("source", str) - - Integer with default: - >>> builder.add_metadata_field( - ... "page_number", - ... int, - ... default=0 - ... ) - - Float with conversion: - >>> builder.add_metadata_field( - ... "confidence", - ... float, - ... convert_fn=lambda x: round(float(x), 2), - ... default=0.0 - ... ) - - List of strings: - >>> builder.add_metadata_field( - ... "tags", - ... List[str], - ... default=[] - ... ) - - Timestamp with conversion: - >>> builder.add_metadata_field( - ... "created_at", - ... str, - ... convert_fn=lambda ts: ts.isoformat() - ... 
) - """ - name = column_name or field - - def value_fn(chunk: Chunk) -> Any: - return chunk.metadata.get(field, default) - - self._specs.append( - SpannerColumnSpec( - column_name=name, - python_type=python_type, - value_fn=functools.partial( - _extract_and_convert, value_fn, convert_fn))) - return self - - def add_column( - self, - column_name: str, - python_type: Type, - value_fn: Callable[[Chunk], Any]) -> 'SpannerColumnSpecsBuilder': - """Add a custom column with full control. - - Args: - column_name: Column name - python_type: Python type (required) - value_fn: Value extraction function - - Returns: - Self for method chaining - - Examples: - Boolean flag: - >>> builder.add_column( - ... column_name="has_code", - ... python_type=bool, - ... value_fn=lambda chunk: "```" in chunk.content.text - ... ) - - Computed value: - >>> builder.add_column( - ... column_name="word_count", - ... python_type=int, - ... value_fn=lambda chunk: len(chunk.content.text.split()) - ... ) - """ - self._specs.append( - SpannerColumnSpec( - column_name=column_name, python_type=python_type, - value_fn=value_fn)) - return self - - def build(self) -> List[SpannerColumnSpec]: - """Build the final list of column specifications. - - Returns: - List of SpannerColumnSpec objects - """ - return self._specs.copy() - - -class _SpannerSchemaBuilder: - """Internal: Builds NamedTuple schema and registers RowCoder. - - Creates a NamedTuple type from column specifications and registers it - with Beam's RowCoder for serialization. - """ - def __init__(self, table_name: str, column_specs: List[SpannerColumnSpec]): - """Initialize schema builder. - - Args: - table_name: Table name (used in NamedTuple type name) - column_specs: List of column specifications - - Raises: - ValueError: If duplicate column names are found - """ - self.table_name = table_name - self.column_specs = column_specs - - # Validate no duplicates - names = [col.column_name for col in column_specs] - duplicates = set(name for name in names if names.count(name) > 1) - if duplicates: - raise ValueError(f"Duplicate column names: {duplicates}") - - # Create NamedTuple type - fields = [(col.column_name, col.python_type) for col in column_specs] - type_name = f"SpannerVectorRecord_{table_name}" - self.record_type = NamedTuple(type_name, fields) # type: ignore - - # Register coder - registry.register_coder(self.record_type, RowCoder) - - def create_converter(self) -> Callable[[Chunk], NamedTuple]: - """Create converter function from Chunk to NamedTuple record. - - Returns: - Function that converts a Chunk to a NamedTuple record - """ - def convert(chunk: Chunk) -> self.record_type: # type: ignore - values = { - col.column_name: col.value_fn(chunk) - for col in self.column_specs - } - return self.record_type(**values) # type: ignore - - return convert - - -class SpannerVectorWriterConfig(VectorDatabaseWriteConfig): - """Configuration for writing vectors to Cloud Spanner. - - Supports flexible schema configuration through column specifications and - provides control over Spanner-specific write parameters. - - Examples: - Default schema: - >>> config = SpannerVectorWriterConfig( - ... project_id="my-project", - ... instance_id="my-instance", - ... database_id="my-db", - ... table_name="embeddings" - ... ) - - Custom schema with flattened metadata: - >>> specs = ( - ... SpannerColumnSpecsBuilder() - ... .with_id_spec() - ... .with_embedding_spec() - ... .with_content_spec() - ... .add_metadata_field("source", str) - ... 
.add_metadata_field("page_number", int, default=0) - ... .with_metadata_spec() - ... .build() - ... ) - >>> config = SpannerVectorWriterConfig( - ... project_id="my-project", - ... instance_id="my-instance", - ... database_id="my-db", - ... table_name="embeddings", - ... column_specs=specs - ... ) - - With emulator: - >>> config = SpannerVectorWriterConfig( - ... project_id="test-project", - ... instance_id="test-instance", - ... database_id="test-db", - ... table_name="embeddings", - ... emulator_host="http://localhost:9010" - ... ) - """ - def __init__( - self, - project_id: str, - instance_id: str, - database_id: str, - table_name: str, - *, - # Schema configuration - column_specs: Optional[List[SpannerColumnSpec]] = None, - # Write operation type - write_mode: Literal["INSERT", "UPDATE", "REPLACE", - "INSERT_OR_UPDATE"] = "INSERT_OR_UPDATE", - # Batching configuration - max_batch_size_bytes: Optional[int] = None, - max_number_mutations: Optional[int] = None, - max_number_rows: Optional[int] = None, - grouping_factor: Optional[int] = None, - # Networking - host: Optional[str] = None, - emulator_host: Optional[str] = None, - expansion_service: Optional[str] = None, - # Retry/deadline configuration - commit_deadline: Optional[int] = None, - max_cumulative_backoff: Optional[int] = None, - # Error handling - failure_mode: Optional[ - spanner.FailureMode] = spanner.FailureMode.REPORT_FAILURES, - high_priority: bool = False, - # Additional Spanner arguments - **spanner_kwargs): - """Initialize Spanner vector writer configuration. - - Args: - project_id: GCP project ID - instance_id: Spanner instance ID - database_id: Spanner database ID - table_name: Target table name - column_specs: Schema configuration using SpannerColumnSpecsBuilder. - If None, uses default schema (id, embedding, content, metadata) - write_mode: Spanner write operation type: - - INSERT: Fail if row exists - - UPDATE: Fail if row doesn't exist - - REPLACE: Delete then insert - - INSERT_OR_UPDATE: Insert or update if exists (default) - max_batch_size_bytes: Maximum bytes per mutation batch (default: 1MB) - max_number_mutations: Maximum cell mutations per batch (default: 5000) - max_number_rows: Maximum rows per batch (default: 500) - grouping_factor: Multiple of max mutation for sorting (default: 1000) - host: Spanner host URL (usually not needed) - emulator_host: Spanner emulator host (e.g., "http://localhost:9010") - expansion_service: Java expansion service address (host:port) - commit_deadline: Commit API deadline in seconds (default: 15) - max_cumulative_backoff: Max retry backoff seconds (default: 900) - failure_mode: Error handling strategy: - - FAIL_FAST: Throw exception for any failure - - REPORT_FAILURES: Continue processing (default) - high_priority: Use high priority for operations (default: False) - **spanner_kwargs: Additional keyword arguments to pass to the - underlying Spanner write transform. Use this to pass any - Spanner-specific parameters not explicitly exposed by this config. 
- """ - self.project_id = project_id - self.instance_id = instance_id - self.database_id = database_id - self.table_name = table_name - self.write_mode = write_mode - self.max_batch_size_bytes = max_batch_size_bytes - self.max_number_mutations = max_number_mutations - self.max_number_rows = max_number_rows - self.grouping_factor = grouping_factor - self.host = host - self.emulator_host = emulator_host - self.expansion_service = expansion_service - self.commit_deadline = commit_deadline - self.max_cumulative_backoff = max_cumulative_backoff - self.failure_mode = failure_mode - self.high_priority = high_priority - self.spanner_kwargs = spanner_kwargs - - # Use defaults if not provided - specs = column_specs or SpannerColumnSpecsBuilder.with_defaults().build() - - # Create schema builder (NamedTuple + RowCoder registration) - self.schema_builder = _SpannerSchemaBuilder(table_name, specs) - - def create_write_transform(self) -> beam.PTransform: - """Create the Spanner write PTransform. - - Returns: - PTransform for writing to Spanner - """ - return _WriteToSpannerVectorDatabase(self) - - -class _WriteToSpannerVectorDatabase(beam.PTransform): - """Internal: PTransform for writing to Spanner vector database.""" - def __init__(self, config: SpannerVectorWriterConfig): - """Initialize write transform. - - Args: - config: Spanner writer configuration - """ - self.config = config - self.schema_builder = config.schema_builder - - def expand(self, pcoll: beam.PCollection[Chunk]): - """Expand the transform. - - Args: - pcoll: PCollection of Chunks to write - """ - # Select appropriate Spanner write transform based on write_mode - write_transform_class = { - "INSERT": spanner.SpannerInsert, - "UPDATE": spanner.SpannerUpdate, - "REPLACE": spanner.SpannerReplace, - "INSERT_OR_UPDATE": spanner.SpannerInsertOrUpdate, - }[self.config.write_mode] - - return ( - pcoll - | "Convert to Records" >> beam.Map( - self.schema_builder.create_converter()).with_output_types( - self.schema_builder.record_type) - | "Write to Spanner" >> write_transform_class( - project_id=self.config.project_id, - instance_id=self.config.instance_id, - database_id=self.config.database_id, - table=self.config.table_name, - max_batch_size_bytes=self.config.max_batch_size_bytes, - max_number_mutations=self.config.max_number_mutations, - max_number_rows=self.config.max_number_rows, - grouping_factor=self.config.grouping_factor, - host=self.config.host, - emulator_host=self.config.emulator_host, - commit_deadline=self.config.commit_deadline, - max_cumulative_backoff=self.config.max_cumulative_backoff, - failure_mode=self.config.failure_mode, - expansion_service=self.config.expansion_service, - high_priority=self.config.high_priority, - **self.config.spanner_kwargs)) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py deleted file mode 100644 index ab9a982a81f7..000000000000 --- a/sdks/python/apache_beam/ml/rag/ingestion/spanner_it_test.py +++ /dev/null @@ -1,601 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""Integration tests for Spanner vector store writer."""
-
-import logging
-import os
-import time
-import unittest
-import uuid
-
-import pytest
-
-import apache_beam as beam
-from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
-from apache_beam.ml.rag.types import Chunk
-from apache_beam.ml.rag.types import Content
-from apache_beam.ml.rag.types import Embedding
-from apache_beam.testing.test_pipeline import TestPipeline
-
-# pylint: disable=wrong-import-order, wrong-import-position
-try:
-  from google.cloud import spanner
-except ImportError:
-  spanner = None
-
-try:
-  from testcontainers.core.container import DockerContainer
-except ImportError:
-  DockerContainer = None
-# pylint: enable=wrong-import-order, wrong-import-position
-
-
-def retry(fn, retries, err_msg, *args, **kwargs):
-  """Retry a function, waiting one second between attempts."""
-  for _ in range(retries):
-    try:
-      return fn(*args, **kwargs)
-    except:  # pylint: disable=bare-except
-      time.sleep(1)
-  logging.error(err_msg)
-  raise RuntimeError(err_msg)
-
-
-class SpannerEmulatorHelper:
-  """Helper for managing Spanner emulator lifecycle."""
-  def __init__(self, project_id: str, instance_id: str, table_name: str):
-    self.project_id = project_id
-    self.instance_id = instance_id
-    self.table_name = table_name
-    self.host = None
-
-    # Start emulator
-    self.emulator = DockerContainer(
-        'gcr.io/cloud-spanner-emulator/emulator:latest').with_exposed_ports(
-            9010, 9020)
-    retry(self.emulator.start, 3, 'Could not start spanner emulator.')
-    time.sleep(3)
-
-    self.host = f'{self.emulator.get_container_host_ip()}:' \
-                f'{self.emulator.get_exposed_port(9010)}'
-    os.environ['SPANNER_EMULATOR_HOST'] = self.host
-
-    # Create client and instance
-    self.client = spanner.Client(project_id)
-    self.instance = self.client.instance(instance_id)
-    self.create_instance()
-
-  def create_instance(self):
-    """Create Spanner instance in emulator."""
-    self.instance.create().result(120)
-
-  def create_database(self, database_id: str):
-    """Create database with default vector table schema."""
-    database = self.instance.database(
-        database_id,
-        ddl_statements=[
-            f'''
-            CREATE TABLE {self.table_name} (
-              id STRING(1024) NOT NULL,
-              embedding ARRAY<FLOAT64>(vector_length=>3),
-              content STRING(MAX),
-              metadata JSON
-            ) PRIMARY KEY (id)'''
-        ])
-    database.create().result(120)
-
-  def read_data(self, database_id: str):
-    """Read all data from the table."""
-    database = self.instance.database(database_id)
-    with database.snapshot() as snapshot:
-      results = snapshot.execute_sql(
-          f'SELECT * FROM {self.table_name} ORDER BY id')
-      return list(results) if results else []
-
-  def drop_database(self, database_id: str):
-    """Drop the database."""
-    database = self.instance.database(database_id)
-    database.drop()
-
-  def shutdown(self):
-    """Stop the emulator."""
-    if self.emulator:
-      try:
-        self.emulator.stop()
-      except:  # pylint: disable=bare-except
-        logging.error('Could not stop Spanner emulator.')
-
-  def get_emulator_host(self) -> str:
-    """Get the emulator host URL."""
-    return f'http://{self.host}'
-
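For orientation, the helper above owns the entire emulator lifecycle that the test class below relies on. A minimal standalone sketch of that lifecycle, assuming Docker is running locally, the google-cloud-spanner and testcontainers packages are installed, and using only methods defined on the helper (the 'scratch_db' name is illustrative):

  helper = SpannerEmulatorHelper('test-project', 'test-instance', 'embeddings')
  try:
    # Creates the default vector table (id, embedding, content, metadata)
    helper.create_database('scratch_db')
    # Host URL in the form SpannerVectorWriterConfig expects as emulator_host
    print(helper.get_emulator_host())
    # Freshly created table contains no rows yet
    print(helper.read_data('scratch_db'))
    helper.drop_database('scratch_db')
  finally:
    # Always stop the emulator container
    helper.shutdown()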
-@pytest.mark.uses_gcp_java_expansion_service
-@unittest.skipUnless(
-    os.environ.get('EXPANSION_JARS'),
-    "EXPANSION_JARS environment var is not provided, "
-    "indicating that jars have not been built")
-@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
-@unittest.skipIf(
-    DockerContainer is None, 'testcontainers package is not installed.')
-class SpannerVectorWriterTest(unittest.TestCase):
-  """Integration tests for Spanner vector writer."""
-  @classmethod
-  def setUpClass(cls):
-    """Set up Spanner emulator for all tests."""
-    cls.project_id = 'test-project'
-    cls.instance_id = 'test-instance'
-    cls.table_name = 'embeddings'
-
-    cls.spanner_helper = SpannerEmulatorHelper(
-        cls.project_id, cls.instance_id, cls.table_name)
-
-  @classmethod
-  def tearDownClass(cls):
-    """Tear down Spanner emulator."""
-    cls.spanner_helper.shutdown()
-
-  def setUp(self):
-    """Create a unique database for each test."""
-    self.database_id = f'test_db_{uuid.uuid4().hex}'[:30]
-    self.spanner_helper.create_database(self.database_id)
-
-  def tearDown(self):
-    """Drop the test database."""
-    self.spanner_helper.drop_database(self.database_id)
-
-  def test_write_default_schema(self):
-    """Test writing with default schema (id, embedding, content, metadata)."""
-    # Create test chunks
-    chunks = [
-        Chunk(
-            id='doc1',
-            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
-            content=Content(text='First document'),
-            metadata={
-                'source': 'test', 'page': 1
-            }),
-        Chunk(
-            id='doc2',
-            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
-            content=Content(text='Second document'),
-            metadata={
-                'source': 'test', 'page': 2
-            }),
-    ]
-
-    # Create config with default schema
-    config = SpannerVectorWriterConfig(
-        project_id=self.project_id,
-        instance_id=self.instance_id,
-        database_id=self.database_id,
-        table_name=self.table_name,
-        emulator_host=self.spanner_helper.get_emulator_host(),
-    )
-
-    # Write chunks
-    with TestPipeline() as p:
-      p.not_use_test_runner_api = True
-      _ = (p | beam.Create(chunks) | config.create_write_transform())
-
-    # Verify data was written
-    results = self.spanner_helper.read_data(self.database_id)
-    self.assertEqual(len(results), 2)
-
-    # Check first row
-    row1 = results[0]
-    self.assertEqual(row1[0], 'doc1')  # id
-    self.assertEqual(list(row1[1]), [1.0, 2.0, 3.0])  # embedding
-    self.assertEqual(row1[2], 'First document')  # content
-    # metadata is JSON
-    metadata1 = row1[3]
-    self.assertEqual(metadata1['source'], 'test')
-    self.assertEqual(metadata1['page'], 1)
-
-    # Check second row
-    row2 = results[1]
-    self.assertEqual(row2[0], 'doc2')
-    self.assertEqual(list(row2[1]), [4.0, 5.0, 6.0])
-    self.assertEqual(row2[2], 'Second document')
-
-  def test_write_flattened_metadata(self):
-    """Test writing with flattened metadata fields."""
-    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
-
-    # Create custom database with flattened columns
-    self.spanner_helper.drop_database(self.database_id)
-    database = self.spanner_helper.instance.database(
-        self.database_id,
-        ddl_statements=[
-            f'''
-            CREATE TABLE {self.table_name} (
-              id STRING(1024) NOT NULL,
-              embedding ARRAY<FLOAT64>(vector_length=>3),
-              content STRING(MAX),
-              source STRING(MAX),
-              page_number INT64,
-              metadata JSON
-            ) PRIMARY KEY (id)'''
-        ])
-    database.create().result(120)
-
-    # Create test chunks
-    chunks = [
-        Chunk(
-            id='doc1',
-            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
-            content=Content(text='First document'),
-            metadata={
-                'source': 'book.pdf', 'page': 10, 'author': 'John'
-            }),
-        Chunk(
-            id='doc2',
-            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
-            content=Content(text='Second document'),
-            metadata={
-                'source': 'article.txt', 'page': 5, 'author': 'Jane'
-            }),
-    ]
-
-    # Create config with flattened metadata
-    specs = (
-        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
-        with_content_spec().add_metadata_field(
-            'source', str, column_name='source').add_metadata_field(
-                'page', int,
-                column_name='page_number').with_metadata_spec().build())
-
-    config = SpannerVectorWriterConfig(
-        project_id=self.project_id,
-        instance_id=self.instance_id,
-        database_id=self.database_id,
-        table_name=self.table_name,
-        column_specs=specs,
-        emulator_host=self.spanner_helper.get_emulator_host(),
-    )
-
-    # Write chunks
-    with TestPipeline() as p:
-      p.not_use_test_runner_api = True
-      _ = (p | beam.Create(chunks) | config.create_write_transform())
-
-    # Verify data
-    database = self.spanner_helper.instance.database(self.database_id)
-    with database.snapshot() as snapshot:
-      results = snapshot.execute_sql(
-          f'SELECT id, embedding, content, source, page_number, metadata '
-          f'FROM {self.table_name} ORDER BY id')
-      rows = list(results)
-
-    self.assertEqual(len(rows), 2)
-
-    # Check first row
-    self.assertEqual(rows[0][0], 'doc1')
-    self.assertEqual(list(rows[0][1]), [1.0, 2.0, 3.0])
-    self.assertEqual(rows[0][2], 'First document')
-    self.assertEqual(rows[0][3], 'book.pdf')  # flattened source
-    self.assertEqual(rows[0][4], 10)  # flattened page_number
-
-    metadata1 = rows[0][5]
-    self.assertEqual(metadata1['author'], 'John')
-
-  def test_write_minimal_schema(self):
-    """Test writing with minimal schema (only id and embedding)."""
-    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
-
-    # Create custom database with minimal schema
-    self.spanner_helper.drop_database(self.database_id)
-    database = self.spanner_helper.instance.database(
-        self.database_id,
-        ddl_statements=[
-            f'''
-            CREATE TABLE {self.table_name} (
-              id STRING(1024) NOT NULL,
-              embedding ARRAY<FLOAT64>(vector_length=>3)
-            ) PRIMARY KEY (id)'''
-        ])
-    database.create().result(120)
-
-    # Create test chunks
-    chunks = [
-        Chunk(
-            id='doc1',
-            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
-            content=Content(text='First document'),
-            metadata={'source': 'test'}),
-        Chunk(
-            id='doc2',
-            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
-            content=Content(text='Second document'),
-            metadata={'source': 'test'}),
-    ]
-
-    # Create config with minimal schema
-    specs = (
-        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().build(
-        ))
-
-    config = SpannerVectorWriterConfig(
-        project_id=self.project_id,
-        instance_id=self.instance_id,
-        database_id=self.database_id,
-        table_name=self.table_name,
-        column_specs=specs,
-        emulator_host=self.spanner_helper.get_emulator_host(),
-    )
-
-    # Write chunks
-    with TestPipeline() as p:
-      p.not_use_test_runner_api = True
-      _ = (p | beam.Create(chunks) | config.create_write_transform())
-
-    # Verify data
-    results = self.spanner_helper.read_data(self.database_id)
-    self.assertEqual(len(results), 2)
-    self.assertEqual(results[0][0], 'doc1')
-    self.assertEqual(list(results[0][1]), [1.0, 2.0, 3.0])
-
-  def test_write_with_converter(self):
-    """Test writing with custom converter function."""
-    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
-
-    # Create test chunks with embeddings that need normalization
-    chunks = [
-        Chunk(
-            id='doc1',
-            embedding=Embedding(dense_embedding=[3.0, 4.0, 0.0]),
            content=Content(text='First document'),
-            metadata={'source': 'test'}),
-    ]
-
-    # Define normalizer
-    def normalize(vec):
-      norm = (sum(x**2 for x in vec)**0.5) or 1.0
-      return [x / norm for x in vec]
-
-    # Create config with normalized embeddings
-    specs = (
-        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
-            convert_fn=normalize).with_content_spec().with_metadata_spec().
-        build())
-
-    config = SpannerVectorWriterConfig(
-        project_id=self.project_id,
-        instance_id=self.instance_id,
-        database_id=self.database_id,
-        table_name=self.table_name,
-        column_specs=specs,
-        emulator_host=self.spanner_helper.get_emulator_host(),
-    )
-
-    # Write chunks
-    with TestPipeline() as p:
-      p.not_use_test_runner_api = True
-      _ = (p | beam.Create(chunks) | config.create_write_transform())
-
-    # Verify data - embedding should be normalized
-    results = self.spanner_helper.read_data(self.database_id)
-    self.assertEqual(len(results), 1)
-
-    embedding = list(results[0][1])
-    # Original was [3.0, 4.0, 0.0], normalized should be [0.6, 0.8, 0.0]
-    self.assertAlmostEqual(embedding[0], 0.6, places=5)
-    self.assertAlmostEqual(embedding[1], 0.8, places=5)
-    self.assertAlmostEqual(embedding[2], 0.0, places=5)
-
-    # Check norm is 1.0
-    norm = sum(x**2 for x in embedding)**0.5
-    self.assertAlmostEqual(norm, 1.0, places=5)
-
-  def test_write_update_mode(self):
-    """Test writing with UPDATE mode."""
-    # First insert data
-    chunks_insert = [
-        Chunk(
-            id='doc1',
-            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
-            content=Content(text='Original content'),
-            metadata={'version': 1}),
-    ]
-
-    config_insert = SpannerVectorWriterConfig(
-        project_id=self.project_id,
-        instance_id=self.instance_id,
-        database_id=self.database_id,
-        table_name=self.table_name,
-        write_mode='INSERT',
-        emulator_host=self.spanner_helper.get_emulator_host(),
-    )
-
-    with TestPipeline() as p:
-      p.not_use_test_runner_api = True
-      _ = (
-          p
-          | beam.Create(chunks_insert)
-          | config_insert.create_write_transform())
-
-    # Update existing row
-    chunks_update = [
-        Chunk(
-            id='doc1',
-            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
-            content=Content(text='Updated content'),
-            metadata={'version': 2}),
-    ]
-
-    config_update = SpannerVectorWriterConfig(
-        project_id=self.project_id,
-        instance_id=self.instance_id,
-        database_id=self.database_id,
-        table_name=self.table_name,
-        write_mode='UPDATE',
-        emulator_host=self.spanner_helper.get_emulator_host(),
-    )
-
-    with TestPipeline() as p:
-      p.not_use_test_runner_api = True
-      _ = (
-          p
-          | beam.Create(chunks_update)
-          | config_update.create_write_transform())
-
-    # Verify update succeeded
-    results = self.spanner_helper.read_data(self.database_id)
-    self.assertEqual(len(results), 1)
-    self.assertEqual(results[0][0], 'doc1')
-    self.assertEqual(list(results[0][1]), [4.0, 5.0, 6.0])
-    self.assertEqual(results[0][2], 'Updated content')
-
-    metadata = results[0][3]
-    self.assertEqual(metadata['version'], 2)
-
-  def test_write_custom_column(self):
-    """Test writing with custom computed column."""
-    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
-
-    # Create custom database with computed column
-    self.spanner_helper.drop_database(self.database_id)
-    database = self.spanner_helper.instance.database(
-        self.database_id,
-        ddl_statements=[
-            f'''
-            CREATE TABLE {self.table_name} (
-              id STRING(1024) NOT NULL,
-              embedding ARRAY<FLOAT64>(vector_length=>3),
-              content STRING(MAX),
-              word_count INT64,
-              metadata JSON
-            ) PRIMARY KEY (id)'''
-        ])
-    database.create().result(120)
-
-    # Create test chunks
-    chunks = [
-        Chunk(
-            id='doc1',
-            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
-            content=Content(text='Hello world test'),
-            metadata={}),
-        Chunk(
-            id='doc2',
-            embedding=Embedding(dense_embedding=[4.0, 5.0, 6.0]),
-            content=Content(text='This is a longer test document'),
-            metadata={}),
-    ]
-
-    # Create config with custom word_count column
-    specs = (
-        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec(
-        ).with_content_spec().add_column(
-            column_name='word_count',
-            python_type=int,
-            value_fn=lambda chunk: len(chunk.content.text.split())).
-        with_metadata_spec().build())
-
-    config = SpannerVectorWriterConfig(
-        project_id=self.project_id,
-        instance_id=self.instance_id,
-        database_id=self.database_id,
-        table_name=self.table_name,
-        column_specs=specs,
-        emulator_host=self.spanner_helper.get_emulator_host(),
-    )
-
-    # Write chunks
-    with TestPipeline() as p:
-      p.not_use_test_runner_api = True
-      _ = (p | beam.Create(chunks) | config.create_write_transform())
-
-    # Verify data
-    database = self.spanner_helper.instance.database(self.database_id)
-    with database.snapshot() as snapshot:
-      results = snapshot.execute_sql(
-          f'SELECT id, word_count FROM {self.table_name} ORDER BY id')
-      rows = list(results)
-
-    self.assertEqual(len(rows), 2)
-    self.assertEqual(rows[0][1], 3)  # "Hello world test" = 3 words
-    self.assertEqual(rows[1][1], 6)  # 6 words
-
-  def test_write_with_timestamp(self):
-    """Test writing with timestamp columns."""
-    from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
-
-    # Create database with timestamp column
-    self.spanner_helper.drop_database(self.database_id)
-    database = self.spanner_helper.instance.database(
-        self.database_id,
-        ddl_statements=[
-            f'''
-            CREATE TABLE {self.table_name} (
-              id STRING(1024) NOT NULL,
-              embedding ARRAY<FLOAT64>(vector_length=>3),
-              content STRING(MAX),
-              created_at TIMESTAMP,
-              metadata JSON
-            ) PRIMARY KEY (id)'''
-        ])
-    database.create().result(120)
-
-    # Create chunks with timestamp
-    timestamp_str = "2025-10-28T09:45:00.123456Z"
-    chunks = [
-        Chunk(
-            id='doc1',
-            embedding=Embedding(dense_embedding=[1.0, 2.0, 3.0]),
-            content=Content(text='Document with timestamp'),
-            metadata={'created_at': timestamp_str}),
-    ]
-
-    # Create config with timestamp field
-    specs = (
-        SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
-        with_content_spec().add_metadata_field(
-            'created_at', str,
-            column_name='created_at').with_metadata_spec().build())
-
-    config = SpannerVectorWriterConfig(
-        project_id=self.project_id,
-        instance_id=self.instance_id,
-        database_id=self.database_id,
-        table_name=self.table_name,
-        column_specs=specs,
-        emulator_host=self.spanner_helper.get_emulator_host(),
-    )
-
-    # Write chunks
-    with TestPipeline() as p:
-      p.not_use_test_runner_api = True
-      _ = (p | beam.Create(chunks) | config.create_write_transform())
-
-    # Verify timestamp was written
-    database = self.spanner_helper.instance.database(self.database_id)
-    with database.snapshot() as snapshot:
-      results = snapshot.execute_sql(
-          f'SELECT id, created_at FROM {self.table_name}')
-      rows = list(results)
-
-    self.assertEqual(len(rows), 1)
-    self.assertEqual(rows[0][0], 'doc1')
-    # Timestamp is returned as datetime object by Spanner client
-    self.assertIsNotNone(rows[0][1])
-
-
-if __name__ == '__main__':
-  logging.getLogger().setLevel(logging.INFO)
-  unittest.main()
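Taken together, the removed writer exposes a small public surface. A minimal end-to-end sketch of the write path these tests exercise, assuming a checkout that still contains this module, a reachable Spanner database under the placeholder IDs below, and the Java expansion service that the cross-language Spanner transform requires:

  import apache_beam as beam
  from apache_beam.ml.rag.ingestion.spanner import SpannerColumnSpecsBuilder
  from apache_beam.ml.rag.ingestion.spanner import SpannerVectorWriterConfig
  from apache_beam.ml.rag.types import Chunk, Content, Embedding

  # Flatten 'source' into its own column; keep remaining metadata as JSON.
  specs = (
      SpannerColumnSpecsBuilder().with_id_spec().with_embedding_spec().
      with_content_spec().add_metadata_field('source', str).
      with_metadata_spec().build())

  config = SpannerVectorWriterConfig(
      project_id='my-project',    # placeholder
      instance_id='my-instance',  # placeholder
      database_id='my-db',        # placeholder
      table_name='embeddings',
      column_specs=specs)

  chunks = [
      Chunk(
          id='doc1',
          embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]),
          content=Content(text='First document'),
          metadata={'source': 'book.pdf'}),
  ]

  with beam.Pipeline() as p:
    _ = (p | beam.Create(chunks) | config.create_write_transform())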