diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md index 3ee4db328979..074cf768ce6f 100644 --- a/docs/docs/pypaimon/ray-data.md +++ b/docs/docs/pypaimon/ray-data.md @@ -277,3 +277,48 @@ write_builder = table.new_batch_write_builder().overwrite() # overwrite partition 'dt=2024-01-01' write_builder = table.new_batch_write_builder().overwrite({'dt': '2024-01-01'}) ``` + +## Merge Into + +`merge_into` updates (and optionally inserts) rows of a **data-evolution** table +from a source, like SQL `MERGE INTO`. Matched rows are updated in place by +`_ROW_ID`; only the touched columns are rewritten. Requires `ray >= 2.50` and a +target table with `'data-evolution.enabled'` and `'row-tracking.enabled'` set. + +```python +from pypaimon.ray import merge_into, WhenMatched, WhenNotMatched + +metrics = merge_into( + target="database_name.table_name", + source=ray_dataset, # ray.data.Dataset / pa.Table / pandas / table-name str + catalog_options={"warehouse": "/path/to/warehouse"}, + on=["id"], # or {"target_col": "source_col"} for renamed keys + when_matched=[WhenMatched(update={"score": "s.score"})], # or update="*" + when_not_matched=[WhenNotMatched(insert="*")], # optional +) +print(metrics) # {"num_updated": 3, "num_inserted": 2} +``` + +- `update` / `insert`: `"*"` (all columns from source), or a dict mapping target + columns to `"s."`, `"t."`, or a literal. +- `condition` (optional): a string expression over `s.` / `t.` using + `> < >= <= == != and or not`; only referenced columns are read. Example: + `WhenMatched(update={"score": "s.score"}, condition="s.version > t.version")`. + +**Parameters:** +- `on`: key columns, or `{target_col: source_col}` for renamed keys. +- `num_partitions`: shuffle parallelism for the join and the write; defaults to + `max(16, cluster_cpus * 2)`, raise it for large merges. +- `ray_remote_args`, `concurrency`: scheduling for the insert path. +- `allow_multiple_matches`: if `False` (default), a target row matched by + multiple source rows raises; `True` keeps the first match. + +**Returns:** `{"num_updated", "num_inserted"}`. + +**Notes:** +- Blob columns cannot be updated and are never read into the join. +- Updating a globally-indexed column raises by default; set + `'global-index.column-update-action' = 'DROP_PARTITION_INDEX'` to drop the + affected index instead (rebuild afterwards). +- Cost scales with how many data files the updated rows touch; scattered updates + over a large table rewrite the updated column of many files. diff --git a/paimon-python/dev/requirements-dev.txt b/paimon-python/dev/requirements-dev.txt index d4e9a0645b17..9ef88817f726 100644 --- a/paimon-python/dev/requirements-dev.txt +++ b/paimon-python/dev/requirements-dev.txt @@ -21,8 +21,9 @@ duckdb==1.3.2 flake8==4.0.1 pytest~=7.0 -# Ray: 2.48+ has no wheel for Python 3.8; use 2.10.0 on 3.8, 2.48.0 on 3.9+ -ray>=2.10.0 +# merge_into needs Dataset.join (added in Ray 2.50). Python 3.8 has no 2.50 wheel. +ray>=2.10.0; python_version < "3.9" +ray>=2.50.0; python_version >= "3.9" requests parameterized # Vortex 0.71.0 regresses native predicate pushdown on single-row files. diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py index 2d140b9539b6..06b9b7e86967 100644 --- a/paimon-python/pypaimon/common/options/core_options.py +++ b/paimon-python/pypaimon/common/options/core_options.py @@ -398,6 +398,16 @@ class CoreOptions: ) ) + GLOBAL_INDEX_COLUMN_UPDATE_ACTION: ConfigOption[str] = ( + ConfigOptions.key("global-index.column-update-action") + .string_type() + .default_value("THROW_ERROR") + .with_description( + "Defines the action to take when an update modifies columns that " + "are covered by a global index. THROW_ERROR or DROP_PARTITION_INDEX." + ) + ) + LOCAL_CACHE_ENABLED: ConfigOption[bool] = ( ConfigOptions.key("local-cache.enabled") .boolean_type() @@ -652,6 +662,9 @@ def row_tracking_enabled(self, default=None): def data_evolution_enabled(self, default=None): return self.options.get(CoreOptions.DATA_EVOLUTION_ENABLED, default) + def global_index_column_update_action(self, default=None): + return self.options.get(CoreOptions.GLOBAL_INDEX_COLUMN_UPDATE_ACTION, default) + def deletion_vectors_enabled(self, default=None): return self.options.get(CoreOptions.DELETION_VECTORS_ENABLED, default) diff --git a/paimon-python/pypaimon/manifest/index_manifest_entry.py b/paimon-python/pypaimon/manifest/index_manifest_entry.py index 7a5e7d1a4f53..9ec5f103dba5 100644 --- a/paimon-python/pypaimon/manifest/index_manifest_entry.py +++ b/paimon-python/pypaimon/manifest/index_manifest_entry.py @@ -41,22 +41,3 @@ def __eq__(self, other): def __hash__(self): return hash((self.kind, tuple(self.partition.values), self.bucket, self.index_file)) - - -INDEX_MANIFEST_ENTRY = { - "type": "record", - "name": "IndexManifestEntry", - "fields": [ - {"name": "_VERSION", "type": "int"}, - {"name": "_KIND", "type": "byte"}, - {"name": "_PARTITION", "type": "bytes"}, - {"name": "_BUCKET", "type": "int"}, - {"name": "_INDEX_TYPE", "type": "string"}, - {"name": "_FILE_NAME", "type": "string"}, - {"name": "_FILE_SIZE", "type": "long"}, - {"name": "_ROW_COUNT", "type": "long"}, - {"name": "_DELETIONS_VECTORS_RANGES", "type": {"type": "array", "elementType": "DeletionVectorMeta"}}, - {"name": "_EXTERNAL_PATH", "type": ["null", "string"]}, - {"name": "_GLOBAL_INDEX", "type": "GlobalIndexMeta"} - ] -} diff --git a/paimon-python/pypaimon/manifest/index_manifest_file.py b/paimon-python/pypaimon/manifest/index_manifest_file.py index 4e65e95e0cb1..5312b0975255 100644 --- a/paimon-python/pypaimon/manifest/index_manifest_file.py +++ b/paimon-python/pypaimon/manifest/index_manifest_file.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import uuid from io import BytesIO from typing import List, Optional @@ -24,11 +25,60 @@ from pypaimon.index.deletion_vector_meta import DeletionVectorMeta from pypaimon.index.index_file_meta import IndexFileMeta from pypaimon.manifest.index_manifest_entry import IndexManifestEntry -from pypaimon.table.row.generic_row import GenericRowDeserializer +from pypaimon.table.row.generic_row import (GenericRowDeserializer, + GenericRowSerializer) +from pypaimon.utils.file_store_path_factory import FileStorePathFactory + +_DELETION_VECTOR_META_SCHEMA = { + "type": "record", + "name": "DeletionVectorMeta", + "fields": [ + {"name": "f0", "type": "string"}, + {"name": "f1", "type": "long"}, + {"name": "f2", "type": "int"}, + {"name": "_CARDINALITY", "type": ["null", "long"], "default": None}, + ], +} + +_GLOBAL_INDEX_META_SCHEMA = { + "type": "record", + "name": "GlobalIndexMeta", + "fields": [ + {"name": "_ROW_RANGE_START", "type": "long"}, + {"name": "_ROW_RANGE_END", "type": "long"}, + {"name": "_INDEX_FIELD_ID", "type": "int"}, + {"name": "_EXTRA_FIELD_IDS", + "type": ["null", {"type": "array", "items": "int"}], "default": None}, + {"name": "_INDEX_META", "type": ["null", "bytes"], "default": None}, + ], +} + +INDEX_MANIFEST_ENTRY_SCHEMA = { + "type": "record", + "name": "IndexManifestEntry", + "fields": [ + {"name": "_VERSION", "type": "int"}, + {"name": "_KIND", "type": "int"}, + {"name": "_PARTITION", "type": "bytes"}, + {"name": "_BUCKET", "type": "int"}, + {"name": "_INDEX_TYPE", "type": "string"}, + {"name": "_FILE_NAME", "type": "string"}, + {"name": "_FILE_SIZE", "type": "long"}, + {"name": "_ROW_COUNT", "type": "long"}, + {"name": "_DELETIONS_VECTORS_RANGES", + "type": ["null", {"type": "array", "items": _DELETION_VECTOR_META_SCHEMA}], + "default": None}, + {"name": "_EXTERNAL_PATH", "type": ["null", "string"], "default": None}, + {"name": "_GLOBAL_INDEX", + "type": ["null", _GLOBAL_INDEX_META_SCHEMA], "default": None}, + ], +} + +_INDEX_ENTRY_VERSION = 1 class IndexManifestFile: - """Index manifest file reader for reading index manifest entries.""" + """Index manifest file reader/writer for index manifest entries.""" DELETION_VECTORS_INDEX = "DELETION_VECTORS" @@ -172,5 +222,73 @@ def _parse_global_index_meta(self, global_index_record) -> Optional[GlobalIndexM row_range_start=global_index_record.get('_ROW_RANGE_START', 0), row_range_end=global_index_record.get('_ROW_RANGE_END', 0), index_field_id=global_index_record.get('_INDEX_FIELD_ID', 0), + extra_field_ids=global_index_record.get('_EXTRA_FIELD_IDS'), index_meta=global_index_record.get('_INDEX_META') ) + + def combine( + self, + previous_name: Optional[str], + deletes: List[IndexManifestEntry], + ) -> Optional[str]: + """Apply DELETE entries to the previous index manifest and write a new one. + + Mirrors Java GlobalIndexCombiner: the stored manifest only holds ADD + entries; deleting means dropping the entries whose index file name + appears in *deletes*. Returns the new manifest file name, or + *previous_name* unchanged when there is nothing to delete. + """ + if not deletes: + return previous_name + previous = self.read(previous_name) if previous_name else [] + delete_names = {e.index_file.file_name for e in deletes} + survivors = [e for e in previous if e.index_file.file_name not in delete_names] + return self.write(survivors) + + def write(self, entries: List[IndexManifestEntry]) -> str: + """Serialize *entries* to a new Avro index manifest, return its name.""" + file_name = f"{FileStorePathFactory.INDEX_MANIFEST_PREFIX}{uuid.uuid4()}" + path = f"{self.manifest_path}/{file_name}" + records = [self._to_avro_record(e) for e in entries] + try: + buffer = BytesIO() + fastavro.writer(buffer, INDEX_MANIFEST_ENTRY_SCHEMA, records) + with self.file_io.new_output_stream(path) as output_stream: + output_stream.write(buffer.getvalue()) + except Exception as e: + self.file_io.delete_quietly(path) + raise RuntimeError(f"Failed to write index manifest file: {e}") from e + return file_name + + def _to_avro_record(self, entry: IndexManifestEntry) -> dict: + index_file = entry.index_file + dv_ranges = None + if index_file.dv_ranges: + dv_ranges = [ + {"f0": dv.data_file_name, "f1": dv.offset, "f2": dv.length, + "_CARDINALITY": dv.cardinality} + for dv in index_file.dv_ranges.values() + ] + global_index = None + if index_file.global_index_meta is not None: + gim = index_file.global_index_meta + global_index = { + "_ROW_RANGE_START": gim.row_range_start, + "_ROW_RANGE_END": gim.row_range_end, + "_INDEX_FIELD_ID": gim.index_field_id, + "_EXTRA_FIELD_IDS": gim.extra_field_ids, + "_INDEX_META": gim.index_meta, + } + return { + "_VERSION": _INDEX_ENTRY_VERSION, + "_KIND": entry.kind, + "_PARTITION": GenericRowSerializer.to_bytes(entry.partition), + "_BUCKET": entry.bucket, + "_INDEX_TYPE": index_file.index_type, + "_FILE_NAME": index_file.file_name, + "_FILE_SIZE": index_file.file_size, + "_ROW_COUNT": index_file.row_count, + "_DELETIONS_VECTORS_RANGES": dv_ranges, + "_EXTERNAL_PATH": index_file.external_path, + "_GLOBAL_INDEX": global_index, + } diff --git a/paimon-python/pypaimon/ray/__init__.py b/paimon-python/pypaimon/ray/__init__.py index f36eb0253dd8..9161f3cbb3b7 100644 --- a/paimon-python/pypaimon/ray/__init__.py +++ b/paimon-python/pypaimon/ray/__init__.py @@ -16,5 +16,16 @@ # under the License. from pypaimon.ray.ray_paimon import read_paimon, write_paimon +from pypaimon.ray.data_evolution_merge_into import ( + WhenMatched, + WhenNotMatched, + merge_into, +) -__all__ = ["read_paimon", "write_paimon"] +__all__ = [ + "read_paimon", + "write_paimon", + "merge_into", + "WhenMatched", + "WhenNotMatched", +] diff --git a/paimon-python/pypaimon/ray/condition_expr.py b/paimon-python/pypaimon/ray/condition_expr.py new file mode 100644 index 000000000000..c20c3c111fb8 --- /dev/null +++ b/paimon-python/pypaimon/ray/condition_expr.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import ast +import operator +from typing import Mapping, Set + +_PREFIXES = ("s", "t") + +_COMPARE_OPS = { + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, +} + +_ALLOWED_NODES = ( + ast.BoolOp, ast.And, ast.Or, + ast.UnaryOp, ast.Not, ast.USub, ast.UAdd, + ast.Compare, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, + ast.Constant, ast.Attribute, ast.Name, ast.Load, +) + + +class ConditionExpr: + """A parsed merge condition over the joined row. + + Columns are referenced as ``s.col`` / ``t.col``; the evaluator reads them + from a combined ``{"s.col": ..., "t.col": ...}`` mapping. Only comparisons, + boolean and/or/not, and literals are supported, so the expression is safe to + evaluate (no ``eval``) and its referenced columns can be extracted statically. + """ + + def __init__(self, source: str, body: ast.AST): + self.source = source + self._body = body + + def eval(self, combined: Mapping) -> bool: + return bool(_eval(self._body, combined)) + + def target_columns(self) -> Set[str]: + return self._columns("t") + + def source_columns(self) -> Set[str]: + return self._columns("s") + + def _columns(self, prefix: str) -> Set[str]: + cols = set() + for node in ast.walk(self._body): + if (isinstance(node, ast.Attribute) + and isinstance(node.value, ast.Name) + and node.value.id == prefix): + cols.add(node.attr) + return cols + + +def parse(source: str) -> ConditionExpr: + try: + tree = ast.parse(source, mode="eval") + except SyntaxError as e: + raise ValueError(f"Invalid merge condition {source!r}: {e}") + for node in ast.walk(tree): + if isinstance(node, ast.Expression): + continue + if not isinstance(node, _ALLOWED_NODES): + raise ValueError( + f"Unsupported syntax in merge condition {source!r}: " + f"{type(node).__name__}. Only comparisons of s./t. columns and " + f"literals combined with and/or/not are allowed." + ) + if isinstance(node, ast.Attribute): + if not (isinstance(node.value, ast.Name) and node.value.id in _PREFIXES): + raise ValueError( + f"Column reference in merge condition {source!r} must be " + f"'s.' or 't.'." + ) + if isinstance(node, ast.Name) and node.id not in _PREFIXES: + raise ValueError( + f"Unknown name {node.id!r} in merge condition {source!r}; " + f"only 's' and 't' are allowed." + ) + return ConditionExpr(source, tree.body) + + +def _eval(node, combined): + if isinstance(node, ast.BoolOp): + values = (_eval(v, combined) for v in node.values) + if isinstance(node.op, ast.And): + return all(values) + return any(values) + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return not _eval(node.operand, combined) + if isinstance(node, ast.Compare): + left = _operand(node.left, combined) + ok = True + for op, comparator in zip(node.ops, node.comparators): + right = _operand(comparator, combined) + ok = ok and _apply(op, left, right) + left = right + return ok + return _operand(node, combined) + + +def _operand(node, combined): + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.Attribute): + return combined.get(f"{node.value.id}.{node.attr}") + if isinstance(node, ast.UnaryOp): + value = _operand(node.operand, combined) + return -value if isinstance(node.op, ast.USub) else value + return _eval(node, combined) + + +def _apply(op, left, right): + if isinstance(op, ast.Eq): + return left == right + if isinstance(op, ast.NotEq): + return left != right + if left is None or right is None: + return False + return _COMPARE_OPS[type(op)](left, right) diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py new file mode 100644 index 000000000000..2876a0a5c838 --- /dev/null +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -0,0 +1,1056 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""MERGE INTO ... USING ... for Paimon data-evolution tables via Ray Datasets.""" + +from dataclasses import dataclass +from typing import ( + Any, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +import pyarrow as pa + +from pypaimon.ray.condition_expr import ConditionExpr +from pypaimon.ray.condition_expr import parse as parse_condition + +SetSpec = Union[str, Dict[str, Any]] +Condition = str +OnSpec = Union[Sequence[str], Mapping[str, str]] + + +@dataclass +class WhenMatched: + update: SetSpec + condition: Optional[Condition] = None + + +@dataclass +class WhenNotMatched: + insert: SetSpec + condition: Optional[Condition] = None + + +@dataclass +class _NormalizedClause: + spec: Dict[str, Any] + condition: Optional[ConditionExpr] + + +def merge_into( + target: str, + source: Any, + catalog_options: Dict[str, str], + *, + on: OnSpec, + when_matched: Sequence[WhenMatched] = (), + when_not_matched: Sequence[WhenNotMatched] = (), + num_partitions: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, + concurrency: Optional[int] = None, + allow_multiple_matches: bool = False, +) -> Dict[str, int]: + _require_ray_join() + num_partitions = _resolve_num_partitions(num_partitions) + when_matched = list(when_matched) + when_not_matched = list(when_not_matched) + if not when_matched and not when_not_matched: + raise ValueError( + "At least one of when_matched or when_not_matched must be non-empty." + ) + + target_on_cols, source_on_cols = _normalize_on(on) + + from pypaimon.catalog.catalog_factory import CatalogFactory + + catalog = CatalogFactory.create(catalog_options) + table = catalog.get_table(target) + if not table.options.data_evolution_enabled(): + raise ValueError( + f"merge_into requires 'data-evolution.enabled' = 'true' on '{target}'." + ) + if not table.options.row_tracking_enabled(): + raise ValueError( + f"merge_into requires 'row-tracking.enabled' = 'true' on '{target}'." + ) + + target_field_names = list(table.field_names) + on_map = dict(zip(target_on_cols, source_on_cols)) + matched_specs = [ + _NormalizedClause( + spec=_normalize_set_spec(c.update, target_field_names, on_map), + condition=parse_condition(c.condition) if c.condition else None, + ) + for c in when_matched + ] + not_matched_specs = [ + _NormalizedClause( + spec=_normalize_set_spec(c.insert, target_field_names, on_map), + condition=parse_condition(c.condition) if c.condition else None, + ) + for c in when_not_matched + ] + + update_cols: set = set() + for clause in matched_specs: + update_cols.update(clause.spec.keys()) + _reject_blob_updates(table, update_cols) + + source_ds = _normalize_source(source, catalog_options) + _validate_source_on_cols(source_ds, source_on_cols) + + base_snapshot = table.snapshot_manager().get_latest_snapshot() + + global_index_action = ( + table.options.global_index_column_update_action() + or GLOBAL_INDEX_ACTION_THROW_ERROR + ) + + from pypaimon.schema.data_types import PyarrowFieldParser + + target_pa_schema = PyarrowFieldParser.from_paimon_schema( + table.table_schema.fields + ) + + update_ds = None + insert_ds = None + update_cols_union: List[str] = [] + + # With both clauses on a non-empty target, matched and not-matched routing + # share the same source/target equi-join. Build them from one materialized + # LEFT_OUTER join instead of reading and shuffling the target table twice. + if matched_specs and not_matched_specs and base_snapshot is not None: + update_cols_union = _union_update_cols(matched_specs) + update_ds, insert_ds = _build_unified_both( + target_identifier=target, + source_ds=source_ds, + target_on=target_on_cols, + source_on=source_on_cols, + matched_clauses=matched_specs, + not_matched_clauses=not_matched_specs, + target_field_names=target_field_names, + target_pa_schema=target_pa_schema, + update_cols=update_cols_union, + catalog_options=catalog_options, + num_partitions=num_partitions, + ) + else: + # Empty target → no rows can match; matched UPDATE is a no-op. + if matched_specs and base_snapshot is not None: + update_cols_union = _union_update_cols(matched_specs) + update_ds = _build_matched_update_ds( + target_identifier=target, + source_ds=source_ds, + target_on=target_on_cols, + source_on=source_on_cols, + clauses=matched_specs, + target_field_names=target_field_names, + target_pa_schema=target_pa_schema, + update_cols=update_cols_union, + catalog_options=catalog_options, + num_partitions=num_partitions, + ) + + if not_matched_specs: + # Empty target: nothing can match, so every source row inserts. + # Skip all joins (ray's hash join crashes on empty partitions). + insert_ds = _build_not_matched_insert_ds( + target_identifier=target, + source_ds=source_ds, + target_on=target_on_cols, + source_on=source_on_cols, + clauses=not_matched_specs, + target_field_names=target_field_names, + target_pa_schema=target_pa_schema, + catalog_options=catalog_options, + num_partitions=num_partitions, + target_empty=base_snapshot is None, + ) + + update_msgs: list = [] + num_updated = 0 + if update_ds is not None: + update_msgs, num_updated = _distributed_update_apply( + update_ds, + table, + update_cols_union, + num_partitions=num_partitions, + ray_remote_args=ray_remote_args, + allow_multiple_matches=allow_multiple_matches, + ) + + all_msgs: list = list(update_msgs) + num_inserted = 0 + if insert_ds is not None: + insert_msgs = _distributed_write_collect_msgs( + insert_ds, table, ray_remote_args=ray_remote_args, concurrency=concurrency + ) + num_inserted = sum(f.row_count for m in insert_msgs for f in m.new_files) + all_msgs.extend(insert_msgs) + # Mirror Spark's checkUpdateResult: scope the global-index action to the + # partitions the update actually wrote and the updated indexed columns. + all_msgs.extend( + _apply_global_index_update_action( + table, base_snapshot, update_cols_union, update_msgs, global_index_action + ) + ) + if all_msgs: + wb = table.new_batch_write_builder() + tc = wb.new_commit() + tc.commit(all_msgs) + tc.close() + + return {"num_updated": num_updated, "num_inserted": num_inserted} + + +def _normalize_on(on: OnSpec) -> Tuple[List[str], List[str]]: + if isinstance(on, Mapping): + target_cols = list(on.keys()) + source_cols = list(on.values()) + else: + target_cols = list(on) + source_cols = list(on) + if not target_cols: + raise ValueError("'on' must be non-empty.") + return target_cols, source_cols + + +def _build_matched_update_ds( + *, + target_identifier: str, + source_ds, + target_on: Sequence[str], + source_on: Sequence[str], + clauses: List[_NormalizedClause], + target_field_names: Sequence[str], + target_pa_schema: pa.Schema, + update_cols: Sequence[str], + catalog_options: Dict[str, str], + num_partitions: int, +): + from pypaimon.ray.ray_paimon import read_paimon + from pypaimon.table.special_fields import SpecialFields + + row_id_name = SpecialFields.ROW_ID.name + needed_cols = _resolve_target_projection( + clauses, target_on, update_cols, target_field_names, + ) + projection = [row_id_name] + [c for c in needed_cols if c != row_id_name] + + target_ds = read_paimon(target_identifier, catalog_options, projection=projection) + update_schema = _build_update_schema(target_pa_schema, update_cols, row_id_name) + + target_renamed = target_ds.rename_columns( + {c: f"t.{c}" for c in target_ds.schema().names} + ) + source_schema = source_ds.schema() + source_cols = list(source_schema.names) if source_schema is not None else list(source_on) + source_renamed = source_ds.rename_columns({c: f"s.{c}" for c in source_cols}) + + joined = target_renamed.join( + source_renamed, + join_type="inner", + num_partitions=num_partitions, + on=tuple(f"t.{c}" for c in target_on), + right_on=tuple(f"s.{c}" for c in source_on), + ) + + captured_clauses = clauses + captured_update_cols = list(update_cols) + captured_field_names = list(target_field_names) + captured_row_id_name = row_id_name + captured_on_pairs = list(zip(source_on, target_on)) + captured_schema = update_schema + + if _clauses_use_vector_fast_path(clauses): + first_spec = clauses[0].spec + + def _fast(batch: pa.Table) -> pa.Table: + return _vectorized_matched_transform( + batch, + first_spec, + captured_on_pairs, + captured_update_cols, + captured_row_id_name, + captured_schema, + ) + + return joined.map_batches(_fast, batch_format="pyarrow") + + def _transform(batch: pa.Table) -> pa.Table: + return _apply_matched_transform( + batch, + captured_clauses, + captured_on_pairs, + captured_update_cols, + captured_field_names, + captured_row_id_name, + captured_schema, + ) + + return joined.map_batches(_transform, batch_format="pyarrow") + + +def _apply_matched_transform( + batch: pa.Table, + clauses: List[_NormalizedClause], + on_pairs: Sequence[Tuple[str, str]], + update_cols: Sequence[str], + field_names: Sequence[str], + row_id_name: str, + update_schema: pa.Schema, +) -> pa.Table: + rows = batch.to_pylist() + out_row_ids: list = [] + out_cols: Dict[str, list] = {c: [] for c in update_cols} + for row in rows: + s_row = {k[2:]: v for k, v in row.items() if k.startswith("s.")} + t_row = {k[2:]: v for k, v in row.items() if k.startswith("t.")} + for s_key, t_key in on_pairs: + if s_key not in s_row and t_key in t_row: + s_row[s_key] = t_row[t_key] + combined = _prefixed(s_row, t_row) + for clause in clauses: + if clause.condition is not None and not clause.condition.eval(combined): + continue + new_values = _apply_set(clause.spec, s_row, t_row, field_names) + out_row_ids.append(t_row[row_id_name]) + for col in update_cols: + out_cols[col].append(new_values.get(col, t_row.get(col))) + break + return pa.Table.from_pydict( + {row_id_name: out_row_ids, **out_cols}, + schema=update_schema, + ) + + +def _build_update_schema( + target_pa_schema: pa.Schema, + update_cols: Sequence[str], + row_id_name: str, +) -> pa.Schema: + return pa.schema( + [pa.field(row_id_name, pa.int64(), nullable=False)] + + [target_pa_schema.field(col) for col in update_cols] + ) + + +def _distributed_update_apply( + update_ds, + table, + write_update_cols: Sequence[str], + *, + num_partitions: int, + ray_remote_args: Optional[Dict[str, Any]] = None, + allow_multiple_matches: bool = False, +) -> Tuple[list, int]: + import numpy as np + import pickle + import uuid + + from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER + from pypaimon.table.special_fields import SpecialFields + from pypaimon.write.table_update_by_row_id import TableUpdateByRowId + + row_id_name = SpecialFields.ROW_ID.name + cols = list(write_update_cols) + + for col in cols: + if col not in table.field_names: + raise ValueError(f"Column '{col}' is not in target table schema.") + + planner = TableUpdateByRowId( + table, + "_merge_into_planner_" + uuid.uuid4().hex[:8], + BATCH_COMMIT_IDENTIFIER, + ) + sorted_first_row_ids = list(planner.first_row_ids) + if not sorted_first_row_ids: + return [], 0 + + # Broadcast the file-info snapshot to every worker so they skip the + # per-task manifest scan and observe a single consistent target view. + precomputed_info = ( + planner.snapshot_id, + planner.first_row_ids, + planner._first_row_id_index, + planner.total_row_count, + ) + + frid_col = "_FIRST_ROW_ID" + captured_sorted = sorted_first_row_ids + captured_sorted_arr = np.asarray(captured_sorted, dtype=np.int64) + first = captured_sorted_arr[0] + captured_precomputed = precomputed_info + total_row_count = planner.total_row_count + + def _assign_frid(batch: pa.Table) -> pa.Table: + if batch.num_rows == 0: + return batch.append_column(frid_col, pa.array([], type=pa.int64())) + rid_col = batch.column(row_id_name) + if rid_col.null_count: + raise ValueError( + "_ROW_ID is null; planner snapshot is stale " + "or matched rows come from a different table." + ) + rids = rid_col.to_numpy(zero_copy_only=False) + # Out-of-range _ROW_IDs would silently map via searchsorted wrap-around. + out_of_range = (rids < first) | (rids >= total_row_count) + if out_of_range.any(): + bad = rids[out_of_range][0] + raise ValueError( + f"_ROW_ID {bad} is out of valid range " + f"[{first}, {total_row_count}); planner snapshot is stale " + f"or matched rows come from a different table." + ) + idx = np.searchsorted(captured_sorted_arr, rids, side="right") - 1 + frids = captured_sorted_arr[idx] + return batch.append_column(frid_col, pa.array(frids, type=pa.int64())) + + with_frid = update_ds.map_batches(_assign_frid, batch_format="pyarrow") + + captured_table = table + captured_cols = cols + + def _apply_group(group: pa.Table) -> pa.Table: + if group.num_rows == 0: + return pa.Table.from_pydict({ + "msgs_blob": pa.array([], type=pa.binary()), + "n_updated": pa.array([], type=pa.int64()), + }) + + # One target _ROW_ID matched by several source rows. Default: refuse + # (the winning value is otherwise undefined, as in Spark DE's + # checkCardinality=false). Opt-in keeps the first match deterministically. + group_row_ids = group.column(row_id_name).to_pylist() + if len(set(group_row_ids)) != len(group_row_ids): + if not allow_multiple_matches: + raise ValueError( + "MERGE matched multiple source rows to the same target " + "_ROW_ID. Deduplicate the source, or pass " + "allow_multiple_matches=True to keep the first match." + ) + seen: set = set() + keep_indices: list = [] + for i, rid in enumerate(group_row_ids): + if rid not in seen: + seen.add(rid) + keep_indices.append(i) + group = group.take(pa.array(keep_indices, type=pa.int64())) + + for_update = group.drop_columns([frid_col]) + worker = TableUpdateByRowId( + captured_table, + "_merge_into_shard_" + uuid.uuid4().hex[:8], + BATCH_COMMIT_IDENTIFIER, + precomputed_files_info=captured_precomputed, + ) + msgs = worker.update_columns(for_update, list(captured_cols)) + return pa.Table.from_pydict({ + "msgs_blob": [pickle.dumps(msgs)], + "n_updated": pa.array([for_update.num_rows], type=pa.int64()), + }) + + # One group per target data file (distinct _FIRST_ROW_ID). Drive the write + # shuffle with the same num_partitions knob as the join (Spark's single + # shuffle.partitions), bounded by the file count so small merges don't spawn + # empty reduce tasks and large ones scale past a fixed cap. + group_partitions = max(1, min(len(captured_sorted), num_partitions)) + msgs_ds = with_frid.groupby(frid_col, num_partitions=group_partitions).map_groups( + _apply_group, batch_format="pyarrow" + ) + + all_msgs: list = [] + num_updated = 0 + for batch in msgs_ds.iter_batches(batch_format="pyarrow"): + for blob in batch.column("msgs_blob").to_pylist(): + all_msgs.extend(pickle.loads(blob)) + for n in batch.column("n_updated").to_pylist(): + num_updated += n + return all_msgs, num_updated + + +GLOBAL_INDEX_ACTION_THROW_ERROR = "THROW_ERROR" +GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX = "DROP_PARTITION_INDEX" + + +def _resolve_num_partitions(num_partitions: Optional[int]) -> int: + if num_partitions is not None: + return num_partitions + try: + import ray + + cpus = ray.cluster_resources().get("CPU", 16) + return max(16, int(cpus) * 2) + except Exception: + return 16 + + +def _clauses_use_vector_fast_path( + clauses: List[_NormalizedClause], +) -> bool: + if not clauses: + return False + for c in clauses: + if c.condition is not None: + return False + for v in c.spec.values(): + if callable(v): + return False + return True + + +def _vectorized_matched_transform( + batch: pa.Table, + spec: Dict[str, Any], + on_pairs: Sequence[Tuple[str, str]], + update_cols: Sequence[str], + row_id_name: str, + update_schema: pa.Schema, +) -> pa.Table: + available = set(batch.schema.names) + arrays: list = [batch.column(f"t.{row_id_name}")] + for col in update_cols: + out_type = update_schema.field(col).type + if col in spec: + arrays.append(_resolve_spec_array(spec[col], batch, available, on_pairs, out_type)) + else: + arrays.append(batch.column(f"t.{col}")) + return pa.Table.from_arrays(arrays, schema=update_schema) + + +def _vectorized_insert_transform( + batch: pa.Table, + spec: Dict[str, Any], + target_field_names: Sequence[str], + target_pa_schema: pa.Schema, +) -> pa.Table: + available = set(batch.schema.names) + arrays: list = [] + for col in target_field_names: + out_type = target_pa_schema.field(col).type + if col in spec: + arrays.append(_resolve_spec_array(spec[col], batch, available, (), out_type)) + else: + arrays.append(pa.nulls(batch.num_rows, type=out_type)) + return pa.Table.from_arrays(arrays, schema=target_pa_schema) + + +def _resolve_spec_array( + val: Any, + batch: pa.Table, + available: set, + on_pairs: Sequence[Tuple[str, str]], + out_type: pa.DataType, +): + if isinstance(val, str) and val.startswith("s."): + ref = val[2:] + if f"s.{ref}" in available: + return batch.column(f"s.{ref}") + # Equi-join drops the right-side join key; fall back to target's value. + for sk, tk in on_pairs: + if sk == ref and f"t.{tk}" in available: + return batch.column(f"t.{tk}") + return pa.nulls(batch.num_rows, type=out_type) + if isinstance(val, str) and val.startswith("t."): + ref = val[2:] + col_name = f"t.{ref}" + return batch.column(col_name) if col_name in available else pa.nulls( + batch.num_rows, type=out_type + ) + return pa.array([val] * batch.num_rows, type=out_type) + + +def _build_not_matched_insert_ds( + *, + target_identifier: str, + source_ds, + target_on: Sequence[str], + source_on: Sequence[str], + clauses: List[_NormalizedClause], + target_field_names: Sequence[str], + target_pa_schema: pa.Schema, + catalog_options: Dict[str, str], + num_partitions: int, + target_empty: bool = False, +): + from pypaimon.ray.ray_paimon import read_paimon + from pypaimon.ray.shuffle import _coerce_large_string_types + + captured_clauses = clauses + captured_field_names = list(target_field_names) + out_schema = target_pa_schema + + source_schema = source_ds.schema() + source_cols = list(source_schema.names) if source_schema is not None else list(source_on) + source_renamed = source_ds.rename_columns({c: f"s.{c}" for c in source_cols}) + + if target_empty: + unmatched = source_renamed + else: + target_ds = read_paimon( + target_identifier, catalog_options, projection=list(target_on) + ) + target_renamed = target_ds.rename_columns( + {c: f"t.{c}" for c in target_on} + ) + unmatched = source_renamed.join( + target_renamed, + join_type="left_anti", + num_partitions=num_partitions, + on=tuple(f"s.{c}" for c in source_on), + right_on=tuple(f"t.{c}" for c in target_on), + ) + + if _clauses_use_vector_fast_path(clauses): + first_spec = clauses[0].spec + + def _fast(batch: pa.Table) -> pa.Table: + return _coerce_large_string_types( + _vectorized_insert_transform( + batch, first_spec, captured_field_names, out_schema + ) + ) + + return unmatched.map_batches(_fast, batch_format="pyarrow") + + def _transform(batch: pa.Table) -> pa.Table: + return _apply_insert_transform( + batch, captured_clauses, captured_field_names, out_schema + ) + + return unmatched.map_batches(_transform, batch_format="pyarrow") + + +def _apply_insert_transform( + batch: pa.Table, + clauses: List[_NormalizedClause], + field_names: Sequence[str], + out_schema: pa.Schema, +) -> pa.Table: + from pypaimon.ray.shuffle import _coerce_large_string_types + + rows = batch.to_pylist() + out = [] + for row in rows: + s_row = {k[2:]: v for k, v in row.items() if k.startswith("s.")} + combined = _prefixed(s_row, None) + for clause in clauses: + if clause.condition is not None and not clause.condition.eval(combined): + continue + out.append( + _apply_set( + clause.spec, s_row, None, field_names, null_unspecified=True + ) + ) + break + aligned = [{name: r.get(name) for name in field_names} for r in out] + return _coerce_large_string_types(pa.Table.from_pylist(aligned, schema=out_schema)) + + +def _build_unified_both( + *, + target_identifier: str, + source_ds, + target_on: Sequence[str], + source_on: Sequence[str], + matched_clauses: List[_NormalizedClause], + not_matched_clauses: List[_NormalizedClause], + target_field_names: Sequence[str], + target_pa_schema: pa.Schema, + update_cols: Sequence[str], + catalog_options: Dict[str, str], + num_partitions: int, +): + import pyarrow.compute as pc + + from pypaimon.ray.ray_paimon import read_paimon + from pypaimon.ray.shuffle import _coerce_large_string_types + from pypaimon.table.special_fields import SpecialFields + + row_id_name = SpecialFields.ROW_ID.name + + needed_cols = _resolve_target_projection( + matched_clauses, target_on, update_cols, target_field_names, + ) + projection = [row_id_name] + [c for c in needed_cols if c != row_id_name] + target_ds = read_paimon(target_identifier, catalog_options, projection=projection) + target_renamed = target_ds.rename_columns( + {c: f"t.{c}" for c in target_ds.schema().names} + ) + source_schema = source_ds.schema() + source_cols = list(source_schema.names) if source_schema is not None else list(source_on) + source_renamed = source_ds.rename_columns({c: f"s.{c}" for c in source_cols}) + + # One LEFT_OUTER join feeds both branches: rows with a non-null target side + # are matched (UPDATE), null target side means no key match (INSERT). The + # join shuffle is the dominant cost, so materialize once and route both ways + # instead of reading and shuffling the target table twice. + joined = source_renamed.join( + target_renamed, + join_type="left_outer", + num_partitions=num_partitions, + on=tuple(f"s.{c}" for c in source_on), + right_on=tuple(f"t.{c}" for c in target_on), + ).materialize() + + t_row_id_col = f"t.{row_id_name}" + on_pairs = list(zip(source_on, target_on)) + update_schema = _build_update_schema(target_pa_schema, update_cols, row_id_name) + + use_fast_matched = _clauses_use_vector_fast_path(matched_clauses) + first_matched_spec = matched_clauses[0].spec if use_fast_matched else None + m_update_cols = list(update_cols) + m_field_names = list(target_field_names) + + def _matched_batch(batch: pa.Table) -> pa.Table: + sub = batch.filter(pc.is_valid(batch.column(t_row_id_col))) + if use_fast_matched: + return _vectorized_matched_transform( + sub, first_matched_spec, on_pairs, m_update_cols, + row_id_name, update_schema, + ) + return _apply_matched_transform( + sub, matched_clauses, on_pairs, m_update_cols, + m_field_names, row_id_name, update_schema, + ) + + update_ds = joined.map_batches(_matched_batch, batch_format="pyarrow") + + i_field_names = list(target_field_names) + use_fast_insert = _clauses_use_vector_fast_path(not_matched_clauses) + first_insert_spec = not_matched_clauses[0].spec if use_fast_insert else None + + def _insert_batch(batch: pa.Table) -> pa.Table: + sub = batch.filter(pc.is_null(batch.column(t_row_id_col))) + if use_fast_insert: + return _coerce_large_string_types( + _vectorized_insert_transform( + sub, first_insert_spec, i_field_names, target_pa_schema + ) + ) + return _apply_insert_transform( + sub, not_matched_clauses, i_field_names, target_pa_schema + ) + + insert_ds = joined.map_batches(_insert_batch, batch_format="pyarrow") + + return update_ds, insert_ds + + +def _distributed_write_collect_msgs( + insert_ds, + table, + *, + ray_remote_args: Optional[Dict[str, Any]], + concurrency: Optional[int], +) -> list: + from pypaimon.write.ray_datasink import PaimonDatasink + + class _CollectingDatasink(PaimonDatasink): + def __init__(self, t): + super().__init__(t, overwrite=False) + self.collected: list = [] + + def on_write_complete(self, write_result): + if hasattr(write_result, "write_returns"): + write_returns = write_result.write_returns + elif isinstance(write_result, list): + write_returns = write_result + else: + raise TypeError( + f"Unexpected write_result type {type(write_result).__name__}" + ) + self.collected = [ + m + for batch in write_returns + for m in batch + if not m.is_empty() + ] + + sink = _CollectingDatasink(table) + write_kwargs: Dict[str, Any] = {} + if ray_remote_args is not None: + write_kwargs["ray_remote_args"] = ray_remote_args + if concurrency is not None: + write_kwargs["concurrency"] = concurrency + insert_ds.write_datasink(sink, **write_kwargs) + return sink.collected + + +def _apply_global_index_update_action( + table, snapshot, update_cols: Sequence[str], update_msgs, action: str +) -> list: + """Handle updates touching globally-indexed columns, mirroring Spark's + ``checkUpdateResult``. + + Scoped exactly like Spark: only index entries whose partition was written + by the update *and* whose indexed column is among the updated columns are + affected. THROW_ERROR (default) raises; DROP_PARTITION_INDEX drops those + entries (returned as index-delete commit messages, rebuild afterwards). + Like Spark, the INSERT path is left untouched. + """ + if snapshot is None or not update_cols or not update_msgs: + return [] + entries = _scan_global_index_entries(table, snapshot) + if not entries: + return [] + field_by_id = {f.id: f.name for f in table.fields} + update_set = set(update_cols) + affected_partitions = {tuple(m.partition) for m in update_msgs} + affected = [ + e for e in entries + if field_by_id.get(e.index_file.global_index_meta.index_field_id) in update_set + and tuple(e.partition.values) in affected_partitions + ] + if not affected: + return [] + if action == GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX: + return _build_index_delete_msgs(affected) + conflicted = sorted( + {field_by_id.get(e.index_file.global_index_meta.index_field_id) for e in affected} + ) + raise NotImplementedError( + f"MERGE INTO would update columns {conflicted} that have a global " + f"index; not supported (refusing to leave the index stale). Set " + f"'global-index.column-update-action' = 'DROP_PARTITION_INDEX' to drop " + f"the affected index instead." + ) + + +def _build_index_delete_msgs(entries) -> list: + """Group scanned index entries by partition into index-delete messages.""" + from pypaimon.manifest.index_manifest_entry import IndexManifestEntry + from pypaimon.write.commit_message import CommitMessage + + by_partition: Dict[tuple, list] = {} + for e in entries: + key = tuple(e.partition.values) + by_partition.setdefault(key, []).append( + IndexManifestEntry( + kind=1, partition=e.partition, bucket=e.bucket, index_file=e.index_file + ) + ) + return [ + CommitMessage(partition=key, bucket=0, new_files=[], index_files=dels) + for key, dels in by_partition.items() + ] + + +def _scan_global_index_entries(table, snapshot): + from pypaimon.index.index_file_handler import IndexFileHandler + + handler = IndexFileHandler(table=table) + return handler.scan( + snapshot, lambda e: e.index_file.global_index_meta is not None + ) + + +def _require_ray_join() -> None: + """merge_into relies on ``Dataset.join`` (ray>=2.50). Read/sink users on + older ray are unaffected unless they call this, so check only here.""" + import ray + from ray.data import Dataset + + if not hasattr(Dataset, "join"): + raise RuntimeError( + f"merge_into requires ray>=2.50 (Dataset.join); " + f"installed ray is {ray.__version__}." + ) + + +def _reject_blob_updates(table, update_cols: set) -> None: + blob_cols = [ + f.name + for f in table.table_schema.fields + if f.name in update_cols and getattr(f.type, "type", None) == "BLOB" + ] + if blob_cols: + raise NotImplementedError( + f"merge_into cannot update blob columns {blob_cols}; " + f"the row-id rewrite path skips .blob files." + ) + + +def _union_update_cols(clauses: List[_NormalizedClause]) -> List[str]: + seen: List[str] = [] + seen_set: set = set() + for clause in clauses: + for col in clause.spec.keys(): + if col not in seen_set: + seen.append(col) + seen_set.add(col) + return seen + + +def _needed_target_cols( + clauses: List[_NormalizedClause], + on: Sequence[str], + update_cols: Sequence[str], + all_target_cols: Sequence[str], +) -> list: + # Target needs only: join keys, t.col refs, and cols that may fall back + # (not set by every clause). Cols all clauses set from source aren't read. + needed = set(on) + set_by_all = set(update_cols) + for clause in clauses: + for value in clause.spec.values(): + if callable(value): + return list(all_target_cols) + if isinstance(value, str) and value.startswith("t."): + needed.add(value[2:]) + set_by_all &= set(clause.spec.keys()) + needed |= set(update_cols) - set_by_all + return [c for c in all_target_cols if c in needed] + + +def _resolve_target_projection( + clauses: List[_NormalizedClause], + target_on: Sequence[str], + update_cols: Sequence[str], + target_field_names: Sequence[str], +) -> list: + # Precise: SET-side needs plus the target columns each parsed condition + # references. Anything not referenced (e.g. blob) is never read. + needed = set(_needed_target_cols(clauses, target_on, update_cols, target_field_names)) + target_set = set(target_field_names) + for clause in clauses: + if clause.condition is not None: + needed |= clause.condition.target_columns() & target_set + return [c for c in target_field_names if c in needed] + + +def _normalize_set_spec( + spec: SetSpec, + target_field_names: Sequence[str], + on_map: Optional[Mapping[str, str]] = None, +) -> Dict[str, Any]: + on_map = on_map or {} + if isinstance(spec, str): + if spec != "*": + raise ValueError( + f"SET spec strings other than '*' are not supported; got {spec!r}." + ) + # A renamed ON key resolves via the source's ON column, not its own name. + return {col: f"s.{on_map.get(col, col)}" for col in target_field_names} + if not isinstance(spec, dict): + raise ValueError( + f"SET spec must be '*' or a dict, got {type(spec).__name__}." + ) + target_set = set(target_field_names) + for col in spec: + if col not in target_set: + raise ValueError( + f"SET key '{col}' is not a column of the target table " + f"(columns: {list(target_field_names)})." + ) + return dict(spec) + + +def _normalize_source(source: Any, catalog_options: Dict[str, str]): + import ray.data + + if isinstance(source, ray.data.Dataset): + return source + if isinstance(source, str): + from pypaimon.ray.ray_paimon import read_paimon + return read_paimon(source, catalog_options) + if isinstance(source, pa.Table): + return ray.data.from_arrow(source) + try: + import pandas as pd + except ImportError: + pd = None + if pd is not None and isinstance(source, pd.DataFrame): + return ray.data.from_pandas(source) + raise TypeError( + "source must be a ray.data.Dataset, a Paimon table identifier string, " + f"a pyarrow.Table, or a pandas.DataFrame; got {type(source).__name__}." + ) + + +def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None: + schema = source_ds.schema() + if schema is None: + return + names = set(schema.names) + missing = [c for c in on if c not in names] + if missing: + raise ValueError( + f"'on' columns {missing} missing from source schema {list(names)}." + ) + + +def _apply_set( + spec: Dict[str, Any], + s_row: Optional[Dict[str, Any]], + t_row: Optional[Dict[str, Any]], + target_field_names: Sequence[str], + null_unspecified: bool = False, +) -> Dict[str, Any]: + combined = _prefixed(s_row, t_row) + if t_row is not None: + base = t_row + elif s_row is not None and not null_unspecified: + base = s_row + else: + base = {} + out: Dict[str, Any] = {} + for col in target_field_names: + if col in spec: + out[col] = _eval_set_value(spec[col], combined, s_row, t_row) + elif col in base: + out[col] = base[col] + else: + out[col] = None + return out + + +def _prefixed( + s_row: Optional[Dict[str, Any]], t_row: Optional[Dict[str, Any]] +) -> Dict[str, Any]: + out: Dict[str, Any] = {} + if s_row is not None: + for k, v in s_row.items(): + out[f"s.{k}"] = v + if t_row is not None: + for k, v in t_row.items(): + out[f"t.{k}"] = v + return out + + +def _eval_set_value( + value: Any, + combined: Mapping[str, Any], + s_row: Optional[Dict[str, Any]], + t_row: Optional[Dict[str, Any]], +) -> Any: + if callable(value): + return value(combined) + if isinstance(value, str): + if value.startswith("s.") and s_row is not None: + return s_row.get(value[2:]) + if value.startswith("t.") and t_row is not None: + return t_row.get(value[2:]) + return value diff --git a/paimon-python/pypaimon/tests/index_manifest_write_test.py b/paimon-python/pypaimon/tests/index_manifest_write_test.py new file mode 100644 index 000000000000..7107fe2fa59e --- /dev/null +++ b/paimon-python/pypaimon/tests/index_manifest_write_test.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import shutil +import tempfile +import unittest +import uuid + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.globalindex.global_index_meta import GlobalIndexMeta +from pypaimon.index.index_file_meta import IndexFileMeta +from pypaimon.manifest.index_manifest_entry import IndexManifestEntry +from pypaimon.manifest.index_manifest_file import IndexManifestFile +from pypaimon.table.row.generic_row import GenericRow + + +class IndexManifestWriteTest(unittest.TestCase): + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('vec', pa.string()), + ]) + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _table(self): + name = f'default.idx_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema(self.pa_schema) + self.catalog.create_table(name, s, False) + return self.catalog.get_table(name) + + def _entry(self, file_name, field_id, meta=b'm'): + partition = GenericRow([], []) + index_file = IndexFileMeta( + index_type='BTREE', + file_name=file_name, + file_size=123, + row_count=10, + global_index_meta=GlobalIndexMeta( + row_range_start=0, + row_range_end=10, + index_field_id=field_id, + extra_field_ids=[field_id + 1], + index_meta=meta, + ), + ) + return IndexManifestEntry(kind=0, partition=partition, bucket=0, index_file=index_file) + + def test_write_read_roundtrip(self): + imf = IndexManifestFile(self._table()) + name = imf.write([self._entry('idx-a', 1), self._entry('idx-b', 2)]) + out = imf.read(name) + self.assertEqual(2, len(out)) + by_name = {e.index_file.file_name: e for e in out} + a = by_name['idx-a'] + self.assertEqual('BTREE', a.index_file.index_type) + self.assertEqual(123, a.index_file.file_size) + self.assertEqual(10, a.index_file.row_count) + self.assertEqual(0, a.kind) + gim = a.index_file.global_index_meta + self.assertEqual(1, gim.index_field_id) + self.assertEqual(0, gim.row_range_start) + self.assertEqual(10, gim.row_range_end) + self.assertEqual([2], gim.extra_field_ids) + self.assertEqual(b'm', bytes(gim.index_meta)) + + def test_combine_drops_named_files(self): + imf = IndexManifestFile(self._table()) + previous = imf.write([self._entry('idx-a', 1), self._entry('idx-b', 2)]) + deletes = [self._entry('idx-a', 1)] + new_name = imf.combine(previous, deletes) + self.assertNotEqual(previous, new_name) + survivors = {e.index_file.file_name for e in imf.read(new_name)} + self.assertEqual({'idx-b'}, survivors) + + def test_combine_unknown_delete_is_noop_on_content(self): + imf = IndexManifestFile(self._table()) + previous = imf.write([self._entry('idx-a', 1)]) + new_name = imf.combine(previous, [self._entry('idx-zzz', 9)]) + survivors = {e.index_file.file_name for e in imf.read(new_name)} + self.assertEqual({'idx-a'}, survivors) + + def test_combine_empty_deletes_returns_previous(self): + imf = IndexManifestFile(self._table()) + previous = imf.write([self._entry('idx-a', 1)]) + self.assertEqual(previous, imf.combine(previous, [])) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py new file mode 100644 index 000000000000..6479c921056f --- /dev/null +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -0,0 +1,1043 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import os +import shutil +import tempfile +import unittest +import uuid + +import pyarrow as pa +import ray + +from pypaimon import CatalogFactory, Schema +from pypaimon.ray import WhenMatched, WhenNotMatched, merge_into + + +class RayDataEvolutionMergeIntoTest(unittest.TestCase): + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + + de_options = { + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog_options = {'warehouse': cls.warehouse} + cls.catalog = CatalogFactory.create(cls.catalog_options) + cls.catalog.create_database('default', True) + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True, num_cpus=2) + + @classmethod + def tearDownClass(cls): + try: + if ray.is_initialized(): + ray.shutdown() + except Exception: + pass + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_table(self, options=None): + opts = options if options is not None else self.de_options + name = f'default.tbl_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema(self.pa_schema, options=opts) + self.catalog.create_table(name, s, False) + return name + + def _source(self, ids=(1,)): + return pa.Table.from_pydict( + { + 'id': pa.array(list(ids), type=pa.int32()), + 'name': ['x'] * len(ids), + 'age': [10] * len(ids), + }, + schema=self.pa_schema, + ) + + def _write(self, target, data): + table = self.catalog.get_table(target) + wb = table.new_batch_write_builder() + writer = wb.new_write() + writer.write_arrow(data) + wb.new_commit().commit(writer.prepare_commit()) + writer.close() + + def _read_sorted(self, target): + table = self.catalog.get_table(target) + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + return rb.new_read().to_arrow(splits).sort_by('id').to_pydict() + + def _snapshot_id(self, target): + table = self.catalog.get_table(target) + snap = table.snapshot_manager().get_latest_snapshot() + return snap.id if snap is not None else None + + def test_no_clause_raises(self): + target = self._create_table() + with self.assertRaises(ValueError): + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + ) + + def test_non_de_table_rejected(self): + target = self._create_table(options={'row-tracking.enabled': 'true'}) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + ) + self.assertIn('data-evolution.enabled', str(ctx.exception)) + + def test_no_row_tracking_rejected(self): + target = self._create_table(options={'data-evolution.enabled': 'true'}) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + ) + self.assertIn('row-tracking.enabled', str(ctx.exception)) + + def test_source_missing_on_col_raises(self): + target = self._create_table() + bad_source = pa.Table.from_pydict( + {'name': ['x'], 'age': [10]}, + schema=pa.schema([('name', pa.string()), ('age', pa.int32())]), + ) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=bad_source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + ) + self.assertIn("'id'", str(ctx.exception)) + + def test_matched_update_star(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3, 4], type=pa.int32()), + 'name': ['b2', 'c2', 'd'], + 'age': pa.array([22, 33, 40], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b2', 'c2']) + self.assertEqual(out['age'], [10, 22, 33]) + + def test_matched_update_dict(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'age': 's.age'})], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + self.assertEqual(out['age'], [10, 99]) + + def test_matched_update_with_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([5, 100, 50], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched( + update={'age': 's.age'}, + condition="s.age > t.age", + ), + ], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['age'], [10, 100, 50]) + + def test_invalid_condition_expression_raises(self): + with self.assertRaises(ValueError): + merge_into( + target=self._create_table(), + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'age': 's.age'}, + condition="evil(1)")], + ) + + def test_matched_multiple_clauses_first_match_wins(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['s1', 's2', 's3'], + 'age': pa.array([5, 25, 100], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched( + update={'age': 1}, + condition="s.age < t.age", + ), + WhenMatched(update={'age': 999}), + ], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['age'], [1, 999, 999]) + + def test_matched_partial_clause_falls_back_to_target(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([42], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['new'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update={'name': 's.name'}), + WhenMatched(update={'age': 's.age'}), + ], + ) + out = self._read_sorted(target) + self.assertEqual(out['name'], ['new']) + self.assertEqual(out['age'], [42]) + + def test_not_matched_insert_appends_unmatched(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3, 4], type=pa.int32()), + 'name': ['b2', 'c2', 'd'], + 'age': pa.array([22, 33, 40], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[WhenNotMatched(insert='*')], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3, 4]) + self.assertEqual(out['name'], ['a', 'b', 'c', 'd']) + self.assertEqual(out['age'], [10, 20, 30, 40]) + + def test_not_matched_insert_with_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3, 4], type=pa.int32()), + 'name': ['b', 'c', 'd'], + 'age': pa.array([5, 50, 100], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched( + insert='*', + condition="s.age >= 50", + ), + ], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 3, 4]) + self.assertEqual(out['age'], [10, 50, 100]) + + def test_not_matched_multiple_clauses_first_match_wins(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3], type=pa.int32()), + 'name': ['b', 'c'], + 'age': pa.array([5, 99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched( + insert={'id': 's.id', 'name': 'small', 'age': 1}, + condition="s.age < 10", + ), + WhenNotMatched(insert={'id': 's.id', 'name': 'big', 'age': 2}), + ], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'small', 'big']) + self.assertEqual(out['age'], [10, 1, 2]) + + def test_combined_update_and_insert(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3], type=pa.int32()), + 'name': ['b2', 'c'], + 'age': pa.array([22, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + metrics = merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + when_not_matched=[WhenNotMatched(insert='*')], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b2', 'c']) + self.assertEqual(out['age'], [10, 22, 30]) + self.assertEqual(metrics, {'num_updated': 1, 'num_inserted': 1}) + + def test_combined_matched_clause_condition_no_merge_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['n1', 'n2', 'n3'], + 'age': pa.array([100, 5, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched( + update={'name': 's.name'}, + condition="s.age > 50", + ) + ], + when_not_matched=[WhenNotMatched(insert='*')], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + rows = sorted(zip(out['id'], out['name'], out['age'])) + self.assertEqual(rows, [(1, 'n1', 10), (2, 'b', 20), (3, 'n3', 30)]) + + def test_on_with_renamed_columns(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source_schema = pa.schema([ + ('uid', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = pa.Table.from_pydict( + { + 'uid': pa.array([2, 3], type=pa.int32()), + 'name': ['b2', 'c'], + 'age': pa.array([22, 30], type=pa.int32()), + }, + schema=source_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_matched=[WhenMatched(update={'age': 's.age'})], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['age'], [10, 22]) + + def test_on_with_renamed_columns_star(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source_schema = pa.schema([ + ('uid', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = pa.Table.from_pydict( + { + 'uid': pa.array([2, 3], type=pa.int32()), + 'name': ['b2', 'c'], + 'age': pa.array([22, 30], type=pa.int32()), + }, + schema=source_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_matched=[WhenMatched(update='*')], + when_not_matched=[WhenNotMatched(insert='*')], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b2', 'c']) + self.assertEqual(out['age'], [10, 22, 30]) + + def test_insert_into_empty_target(self): + target = self._create_table() + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[WhenNotMatched(insert='*')], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b', 'c']) + self.assertEqual(out['age'], [10, 20, 30]) + + def test_insert_dict_fills_unspecified_with_null(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2], type=pa.int32()), + 'name': ['source-name'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[WhenNotMatched(insert={'id': 's.id', 'age': 99})], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', None]) + self.assertEqual(out['age'], [10, 99]) + + def test_multi_source_match_raises_by_default(self): + # One target row matched by several source rows: the winning value is + # undefined (Spark DE's checkCardinality=false), so we refuse by default. + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 1], type=pa.int32()), + 'name': ['x', 'y'], + 'age': pa.array([100, 200], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + with self.assertRaises(Exception) as ctx: + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + ) + self.assertIn("multiple source rows", str(ctx.exception)) + + def test_multi_source_match_allow_keeps_first(self): + # Opt-in: allow_multiple_matches keeps the first match deterministically. + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 1], type=pa.int32()), + 'name': ['x', 'y'], + 'age': pa.array([100, 200], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + allow_multiple_matches=True, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1]) + # One source row wins; we don't pin which. + self.assertIn(out['name'][0], ['x', 'y']) + self.assertIn(out['age'][0], [100, 200]) + + def test_blob_update_is_rejected(self): + import types + + from pypaimon.ray.data_evolution_merge_into import _reject_blob_updates + from pypaimon.schema.data_types import AtomicType, DataField + + fake_table = types.SimpleNamespace( + table_schema=types.SimpleNamespace( + fields=[ + DataField(0, 'id', AtomicType('INT')), + DataField(1, 'payload', AtomicType('BLOB')), + ] + ) + ) + with self.assertRaises(NotImplementedError): + _reject_blob_updates(fake_table, {'payload'}) + _reject_blob_updates(fake_table, {'id'}) + + def test_combined_writes_single_snapshot(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + before = self._snapshot_id(target) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3], type=pa.int32()), + 'name': ['b2', 'c'], + 'age': pa.array([22, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + when_not_matched=[WhenNotMatched(insert='*')], + ) + + after = self._snapshot_id(target) + self.assertEqual(after, before + 1) + + def test_self_merge_via_normal_join(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched( + update={'age': lambda r: r['t.age'] + 1}, + ), + ], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['age'], [11, 21, 31]) + + def test_matched_update_can_change_on_column(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['x'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['y'], + 'age': pa.array([20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'id': 999, 'name': 'y'})], + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [999]) + self.assertEqual(out['name'], ['y']) + self.assertEqual(out['age'], [10]) + + def test_empty_target_matched_update_is_noop(self): + target = self._create_table() + before = self._snapshot_id(target) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*')], + ) + + self.assertEqual(self._snapshot_id(target), before) + + +class RayMergeIntoGlobalIndexGateTest(unittest.TestCase): + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + + de_options = { + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + } + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _table(self): + name = f'default.gidx_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema(self.pa_schema, options=self.de_options) + self.catalog.create_table(name, s, False) + return self.catalog.get_table(name) + + def _entry(self, table, column, partition_values=()): + from pypaimon.globalindex.global_index_meta import GlobalIndexMeta + from pypaimon.index.index_file_meta import IndexFileMeta + from pypaimon.manifest.index_manifest_entry import IndexManifestEntry + from pypaimon.table.row.generic_row import GenericRow + + field_id = next(f.id for f in table.fields if f.name == column) + index_file = IndexFileMeta( + index_type='BTREE', file_name=f'idx-{column}', file_size=1, row_count=1, + global_index_meta=GlobalIndexMeta( + row_range_start=0, row_range_end=1, index_field_id=field_id, + ), + ) + return IndexManifestEntry( + kind=0, partition=GenericRow(list(partition_values), []), + bucket=0, index_file=index_file, + ) + + def _update_msg(self, partition=()): + from pypaimon.write.commit_message import CommitMessage + return CommitMessage(partition=partition, bucket=0, new_files=[]) + + def test_update_throw_error_raises(self): + from unittest.mock import patch + from pypaimon.ray import data_evolution_merge_into as m + + table = self._table() + with patch.object(m, '_scan_global_index_entries', + return_value=[self._entry(table, 'age')]): + with self.assertRaises(NotImplementedError): + m._apply_global_index_update_action( + table, object(), ['age'], [self._update_msg()], + m.GLOBAL_INDEX_ACTION_THROW_ERROR, + ) + + def test_update_drop_returns_delete_msgs(self): + from unittest.mock import patch + from pypaimon.ray import data_evolution_merge_into as m + + table = self._table() + with patch.object(m, '_scan_global_index_entries', + return_value=[self._entry(table, 'age')]): + msgs = m._apply_global_index_update_action( + table, object(), ['age'], [self._update_msg()], + m.GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX, + ) + self.assertEqual(1, len(msgs)) + self.assertFalse(msgs[0].is_empty()) + self.assertEqual('idx-age', msgs[0].index_files[0].index_file.file_name) + self.assertEqual(1, msgs[0].index_files[0].kind) + + def test_update_unaffected_column_is_noop(self): + from unittest.mock import patch + from pypaimon.ray import data_evolution_merge_into as m + + table = self._table() + with patch.object(m, '_scan_global_index_entries', + return_value=[self._entry(table, 'age')]): + msgs = m._apply_global_index_update_action( + table, object(), ['name'], [self._update_msg()], + m.GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX, + ) + self.assertEqual([], msgs) + + def test_update_untouched_partition_is_noop(self): + from unittest.mock import patch + from pypaimon.ray import data_evolution_merge_into as m + + table = self._table() + entry = self._entry(table, 'age', partition_values=('EU',)) + with patch.object(m, '_scan_global_index_entries', return_value=[entry]): + msgs = m._apply_global_index_update_action( + table, object(), ['age'], [self._update_msg(partition=('US',))], + m.GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX, + ) + self.assertEqual([], msgs) + + +class TargetProjectionTest(unittest.TestCase): + + def _clause(self, spec, condition=None): + from pypaimon.ray import data_evolution_merge_into as m + from pypaimon.ray.condition_expr import parse + return m._NormalizedClause( + spec=spec, condition=parse(condition) if condition else None + ) + + def test_unconditional_set_excludes_target_update_col(self): + from pypaimon.ray import data_evolution_merge_into as m + cols = m._resolve_target_projection( + [self._clause({'feature': 's.feature'})], + ['id'], ['feature'], ['id', 'feature', 'image'], + ) + self.assertEqual(['id'], cols) + + def test_condition_columns_are_projected(self): + from pypaimon.ray import data_evolution_merge_into as m + cols = m._resolve_target_projection( + [self._clause({'feature': 's.feature'}, condition="t.age > 0")], + ['id'], ['feature'], ['id', 'age', 'feature', 'image'], + ) + self.assertEqual(['id', 'age'], cols) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/write/commit_message.py b/paimon-python/pypaimon/write/commit_message.py index 7bce06d8ab13..db6d20ff1fae 100644 --- a/paimon-python/pypaimon/write/commit_message.py +++ b/paimon-python/pypaimon/write/commit_message.py @@ -15,11 +15,14 @@ # specific language governing permissions and limitations # under the License. -from dataclasses import dataclass -from typing import List, Tuple, Optional +from dataclasses import dataclass, field +from typing import List, Tuple, Optional, TYPE_CHECKING from pypaimon.manifest.schema.data_file_meta import DataFileMeta +if TYPE_CHECKING: + from pypaimon.manifest.index_manifest_entry import IndexManifestEntry + @dataclass class CommitMessage: @@ -27,6 +30,7 @@ class CommitMessage: bucket: int new_files: List[DataFileMeta] check_from_snapshot: Optional[int] = -1 + index_files: List['IndexManifestEntry'] = field(default_factory=list) def is_empty(self): - return not self.new_files + return not self.new_files and not self.index_files diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py index 486e28924014..93f0ec82a592 100644 --- a/paimon-python/pypaimon/write/file_store_commit.py +++ b/paimon-python/pypaimon/write/file_store_commit.py @@ -142,6 +142,10 @@ def commit(self, commit_messages: List[CommitMessage], commit_identifier: int): logger.info("Finished collecting changes, including: %d entries", len(commit_entries)) + index_deletes = [] + for msg in commit_messages: + index_deletes.extend(msg.index_files) + commit_kind = "APPEND" detect_conflicts = False allow_rollback = False @@ -157,7 +161,8 @@ def commit(self, commit_messages: List[CommitMessage], commit_identifier: int): commit_identifier=commit_identifier, commit_entries_plan=lambda snapshot: commit_entries, detect_conflicts=detect_conflicts, - allow_rollback=allow_rollback) + allow_rollback=allow_rollback, + index_deletes=index_deletes) def overwrite(self, overwrite_partition, commit_messages: List[CommitMessage], commit_identifier: int): """Commit the given commit messages in overwrite mode.""" @@ -243,7 +248,7 @@ def truncate_table(self, commit_identifier: int) -> None: ) def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan, - detect_conflicts=False, allow_rollback=False): + detect_conflicts=False, allow_rollback=False, index_deletes=None): retry_count = 0 retry_result = None @@ -254,7 +259,7 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan, # No entries to commit (e.g. drop_partitions with no matching data): skip commit # to avoid creating manifest/snapshot with empty partition_stats (causes read errors). - if not commit_entries: + if not commit_entries and not index_deletes: break result = self._try_commit_once( @@ -265,6 +270,7 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan, latest_snapshot=latest_snapshot, detect_conflicts=detect_conflicts, allow_rollback=allow_rollback, + index_deletes=index_deletes, ) if result.is_success(): @@ -316,7 +322,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str commit_entries: List[ManifestEntry], commit_identifier: int, latest_snapshot: Optional[Snapshot], detect_conflicts: bool = False, - allow_rollback: bool = False) -> CommitResult: + allow_rollback: bool = False, + index_deletes=None) -> CommitResult: start_millis = int(time.time() * 1000) if self._is_duplicate_commit(retry_result, latest_snapshot, commit_identifier, commit_kind): return SuccessResult() @@ -327,6 +334,7 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str # process new_manifest new_manifest_file = f"manifest-{str(uuid.uuid4())}-0" + new_index_manifest = None # process snapshot new_snapshot_id = latest_snapshot.id + 1 if latest_snapshot else 1 @@ -384,6 +392,13 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str index_manifest = None if latest_snapshot and commit_kind == "APPEND": index_manifest = latest_snapshot.index_manifest + if index_deletes: + from pypaimon.manifest.index_manifest_file import IndexManifestFile + previous_index_manifest = index_manifest + index_manifest = IndexManifestFile(self.table).combine( + previous_index_manifest, index_deletes) + if index_manifest != previous_index_manifest: + new_index_manifest = index_manifest snapshot_data = Snapshot( version=3, @@ -403,7 +418,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str # Generate partition statistics for the commit statistics = self._generate_partition_statistics(commit_entries) except Exception as e: - self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list) + self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list, + new_index_manifest) logger.warning(f"Exception occurs when preparing snapshot: {e}", exc_info=True) raise RuntimeError(f"Failed to prepare snapshot: {e}") @@ -423,7 +439,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str commit_kind, commit_time_s, ) - self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list) + self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list, + new_index_manifest) return RetryResult(latest_snapshot, None) except Exception as e: # Commit exception, not sure about the situation and should not clean up the files @@ -604,10 +621,14 @@ def _commit_retry_wait(self, retry_count: int): def _cleanup_preparation_failure(self, delta_manifest_list: Optional[str], - base_manifest_list: Optional[str]): + base_manifest_list: Optional[str], + index_manifest: Optional[str] = None): try: manifest_path = self.manifest_list_manager.manifest_path + if index_manifest: + self.table.file_io.delete_quietly(f"{manifest_path}/{index_manifest}") + if delta_manifest_list: manifest_files = self.manifest_list_manager.read(delta_manifest_list) for manifest_meta in manifest_files: diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py index fe2fb9a64b79..55b755e2d4d5 100644 --- a/paimon-python/pypaimon/write/table_update.py +++ b/paimon-python/pypaimon/write/table_update.py @@ -109,8 +109,6 @@ def with_update_type(self, update_cols: List[str]): for col in update_cols: if col not in self.table.field_names: raise ValueError(f"Column {col} is not in table schema.") - if len(update_cols) == len(self.table.field_names): - update_cols = None self.update_cols = update_cols return self diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py b/paimon-python/pypaimon/write/table_update_by_row_id.py index ac9c68c3623b..ab61ed21505e 100644 --- a/paimon-python/pypaimon/write/table_update_by_row_id.py +++ b/paimon-python/pypaimon/write/table_update_by_row_id.py @@ -42,19 +42,30 @@ class TableUpdateByRowId: FIRST_ROW_ID_COLUMN = '_FIRST_ROW_ID' - def __init__(self, table, commit_user: str, commit_identifier: int): + def __init__( + self, table, commit_user: str, commit_identifier: int, + precomputed_files_info: Optional[Tuple[ + int, List[int], + Dict[int, Tuple[DataSplit, List[DataFileMeta]]], + int, + ]] = None, + ): from pypaimon.table.file_store_table import FileStoreTable self.table: FileStoreTable = table self.commit_user = commit_user self.commit_identifier = commit_identifier - # Snapshot the current state once: a single ``first_row_id -> (split, files)`` - # map is enough to drive every downstream lookup (partition, row-count, read). - (self.snapshot_id, - self.first_row_ids, - self._first_row_id_index, - self.total_row_count) = self._load_existing_files_info() + if precomputed_files_info is not None: + (self.snapshot_id, + self.first_row_ids, + self._first_row_id_index, + self.total_row_count) = precomputed_files_info + else: + (self.snapshot_id, + self.first_row_ids, + self._first_row_id_index, + self.total_row_count) = self._load_existing_files_info() self.commit_messages: List[CommitMessage] = []