diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index 3ee4db328979..074cf768ce6f 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -277,3 +277,48 @@ write_builder = table.new_batch_write_builder().overwrite()
# overwrite partition 'dt=2024-01-01'
write_builder = table.new_batch_write_builder().overwrite({'dt': '2024-01-01'})
```
+
+## Merge Into
+
+`merge_into` updates (and optionally inserts) rows of a **data-evolution** table
+from a source, like SQL `MERGE INTO`. Matched rows are updated in place by
+`_ROW_ID`; only the touched columns are rewritten. Requires `ray >= 2.50` and a
+target table with `'data-evolution.enabled'` and `'row-tracking.enabled'` set.
+
+```python
+from pypaimon.ray import merge_into, WhenMatched, WhenNotMatched
+
+metrics = merge_into(
+ target="database_name.table_name",
+ source=ray_dataset, # ray.data.Dataset / pa.Table / pandas / table-name str
+ catalog_options={"warehouse": "/path/to/warehouse"},
+ on=["id"], # or {"target_col": "source_col"} for renamed keys
+ when_matched=[WhenMatched(update={"score": "s.score"})], # or update="*"
+ when_not_matched=[WhenNotMatched(insert="*")], # optional
+)
+print(metrics) # {"num_updated": 3, "num_inserted": 2}
+```
+
+- `update` / `insert`: `"*"` (all columns from source), or a dict mapping target
+ columns to `"s.
"`, `"t."`, or a literal.
+- `condition` (optional): a string expression over `s.` / `t.` using
+ `> < >= <= == != and or not`; only referenced columns are read. Example:
+ `WhenMatched(update={"score": "s.score"}, condition="s.version > t.version")`.
+
+**Parameters:**
+- `on`: key columns, or `{target_col: source_col}` for renamed keys.
+- `num_partitions`: shuffle parallelism for the join and the write; defaults to
+ `max(16, cluster_cpus * 2)`, raise it for large merges.
+- `ray_remote_args`, `concurrency`: scheduling for the insert path.
+- `allow_multiple_matches`: if `False` (default), a target row matched by
+ multiple source rows raises; `True` keeps the first match.
+
+**Returns:** `{"num_updated", "num_inserted"}`.
+
+**Notes:**
+- Blob columns cannot be updated and are never read into the join.
+- Updating a globally-indexed column raises by default; set
+ `'global-index.column-update-action' = 'DROP_PARTITION_INDEX'` to drop the
+ affected index instead (rebuild afterwards).
+- Cost scales with how many data files the updated rows touch; scattered updates
+ over a large table rewrite the updated column of many files.
diff --git a/paimon-python/dev/requirements-dev.txt b/paimon-python/dev/requirements-dev.txt
index d4e9a0645b17..9ef88817f726 100644
--- a/paimon-python/dev/requirements-dev.txt
+++ b/paimon-python/dev/requirements-dev.txt
@@ -21,8 +21,9 @@
duckdb==1.3.2
flake8==4.0.1
pytest~=7.0
-# Ray: 2.48+ has no wheel for Python 3.8; use 2.10.0 on 3.8, 2.48.0 on 3.9+
-ray>=2.10.0
+# merge_into needs Dataset.join (added in Ray 2.50). Python 3.8 has no 2.50 wheel.
+ray>=2.10.0; python_version < "3.9"
+ray>=2.50.0; python_version >= "3.9"
requests
parameterized
# Vortex 0.71.0 regresses native predicate pushdown on single-row files.
diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py
index 2d140b9539b6..06b9b7e86967 100644
--- a/paimon-python/pypaimon/common/options/core_options.py
+++ b/paimon-python/pypaimon/common/options/core_options.py
@@ -398,6 +398,16 @@ class CoreOptions:
)
)
+ GLOBAL_INDEX_COLUMN_UPDATE_ACTION: ConfigOption[str] = (
+ ConfigOptions.key("global-index.column-update-action")
+ .string_type()
+ .default_value("THROW_ERROR")
+ .with_description(
+ "Defines the action to take when an update modifies columns that "
+ "are covered by a global index. THROW_ERROR or DROP_PARTITION_INDEX."
+ )
+ )
+
LOCAL_CACHE_ENABLED: ConfigOption[bool] = (
ConfigOptions.key("local-cache.enabled")
.boolean_type()
@@ -652,6 +662,9 @@ def row_tracking_enabled(self, default=None):
def data_evolution_enabled(self, default=None):
return self.options.get(CoreOptions.DATA_EVOLUTION_ENABLED, default)
+ def global_index_column_update_action(self, default=None):
+ return self.options.get(CoreOptions.GLOBAL_INDEX_COLUMN_UPDATE_ACTION, default)
+
def deletion_vectors_enabled(self, default=None):
return self.options.get(CoreOptions.DELETION_VECTORS_ENABLED, default)
diff --git a/paimon-python/pypaimon/manifest/index_manifest_entry.py b/paimon-python/pypaimon/manifest/index_manifest_entry.py
index 7a5e7d1a4f53..9ec5f103dba5 100644
--- a/paimon-python/pypaimon/manifest/index_manifest_entry.py
+++ b/paimon-python/pypaimon/manifest/index_manifest_entry.py
@@ -41,22 +41,3 @@ def __eq__(self, other):
def __hash__(self):
return hash((self.kind, tuple(self.partition.values),
self.bucket, self.index_file))
-
-
-INDEX_MANIFEST_ENTRY = {
- "type": "record",
- "name": "IndexManifestEntry",
- "fields": [
- {"name": "_VERSION", "type": "int"},
- {"name": "_KIND", "type": "byte"},
- {"name": "_PARTITION", "type": "bytes"},
- {"name": "_BUCKET", "type": "int"},
- {"name": "_INDEX_TYPE", "type": "string"},
- {"name": "_FILE_NAME", "type": "string"},
- {"name": "_FILE_SIZE", "type": "long"},
- {"name": "_ROW_COUNT", "type": "long"},
- {"name": "_DELETIONS_VECTORS_RANGES", "type": {"type": "array", "elementType": "DeletionVectorMeta"}},
- {"name": "_EXTERNAL_PATH", "type": ["null", "string"]},
- {"name": "_GLOBAL_INDEX", "type": "GlobalIndexMeta"}
- ]
-}
diff --git a/paimon-python/pypaimon/manifest/index_manifest_file.py b/paimon-python/pypaimon/manifest/index_manifest_file.py
index 4e65e95e0cb1..5312b0975255 100644
--- a/paimon-python/pypaimon/manifest/index_manifest_file.py
+++ b/paimon-python/pypaimon/manifest/index_manifest_file.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+import uuid
from io import BytesIO
from typing import List, Optional
@@ -24,11 +25,60 @@
from pypaimon.index.deletion_vector_meta import DeletionVectorMeta
from pypaimon.index.index_file_meta import IndexFileMeta
from pypaimon.manifest.index_manifest_entry import IndexManifestEntry
-from pypaimon.table.row.generic_row import GenericRowDeserializer
+from pypaimon.table.row.generic_row import (GenericRowDeserializer,
+ GenericRowSerializer)
+from pypaimon.utils.file_store_path_factory import FileStorePathFactory
+
+_DELETION_VECTOR_META_SCHEMA = {
+ "type": "record",
+ "name": "DeletionVectorMeta",
+ "fields": [
+ {"name": "f0", "type": "string"},
+ {"name": "f1", "type": "long"},
+ {"name": "f2", "type": "int"},
+ {"name": "_CARDINALITY", "type": ["null", "long"], "default": None},
+ ],
+}
+
+_GLOBAL_INDEX_META_SCHEMA = {
+ "type": "record",
+ "name": "GlobalIndexMeta",
+ "fields": [
+ {"name": "_ROW_RANGE_START", "type": "long"},
+ {"name": "_ROW_RANGE_END", "type": "long"},
+ {"name": "_INDEX_FIELD_ID", "type": "int"},
+ {"name": "_EXTRA_FIELD_IDS",
+ "type": ["null", {"type": "array", "items": "int"}], "default": None},
+ {"name": "_INDEX_META", "type": ["null", "bytes"], "default": None},
+ ],
+}
+
+INDEX_MANIFEST_ENTRY_SCHEMA = {
+ "type": "record",
+ "name": "IndexManifestEntry",
+ "fields": [
+ {"name": "_VERSION", "type": "int"},
+ {"name": "_KIND", "type": "int"},
+ {"name": "_PARTITION", "type": "bytes"},
+ {"name": "_BUCKET", "type": "int"},
+ {"name": "_INDEX_TYPE", "type": "string"},
+ {"name": "_FILE_NAME", "type": "string"},
+ {"name": "_FILE_SIZE", "type": "long"},
+ {"name": "_ROW_COUNT", "type": "long"},
+ {"name": "_DELETIONS_VECTORS_RANGES",
+ "type": ["null", {"type": "array", "items": _DELETION_VECTOR_META_SCHEMA}],
+ "default": None},
+ {"name": "_EXTERNAL_PATH", "type": ["null", "string"], "default": None},
+ {"name": "_GLOBAL_INDEX",
+ "type": ["null", _GLOBAL_INDEX_META_SCHEMA], "default": None},
+ ],
+}
+
+_INDEX_ENTRY_VERSION = 1
class IndexManifestFile:
- """Index manifest file reader for reading index manifest entries."""
+ """Index manifest file reader/writer for index manifest entries."""
DELETION_VECTORS_INDEX = "DELETION_VECTORS"
@@ -172,5 +222,73 @@ def _parse_global_index_meta(self, global_index_record) -> Optional[GlobalIndexM
row_range_start=global_index_record.get('_ROW_RANGE_START', 0),
row_range_end=global_index_record.get('_ROW_RANGE_END', 0),
index_field_id=global_index_record.get('_INDEX_FIELD_ID', 0),
+ extra_field_ids=global_index_record.get('_EXTRA_FIELD_IDS'),
index_meta=global_index_record.get('_INDEX_META')
)
+
+ def combine(
+ self,
+ previous_name: Optional[str],
+ deletes: List[IndexManifestEntry],
+ ) -> Optional[str]:
+ """Apply DELETE entries to the previous index manifest and write a new one.
+
+ Mirrors Java GlobalIndexCombiner: the stored manifest only holds ADD
+ entries; deleting means dropping the entries whose index file name
+ appears in *deletes*. Returns the new manifest file name, or
+ *previous_name* unchanged when there is nothing to delete.
+ """
+ if not deletes:
+ return previous_name
+ previous = self.read(previous_name) if previous_name else []
+ delete_names = {e.index_file.file_name for e in deletes}
+ survivors = [e for e in previous if e.index_file.file_name not in delete_names]
+ return self.write(survivors)
+
+ def write(self, entries: List[IndexManifestEntry]) -> str:
+ """Serialize *entries* to a new Avro index manifest, return its name."""
+ file_name = f"{FileStorePathFactory.INDEX_MANIFEST_PREFIX}{uuid.uuid4()}"
+ path = f"{self.manifest_path}/{file_name}"
+ records = [self._to_avro_record(e) for e in entries]
+ try:
+ buffer = BytesIO()
+ fastavro.writer(buffer, INDEX_MANIFEST_ENTRY_SCHEMA, records)
+ with self.file_io.new_output_stream(path) as output_stream:
+ output_stream.write(buffer.getvalue())
+ except Exception as e:
+ self.file_io.delete_quietly(path)
+ raise RuntimeError(f"Failed to write index manifest file: {e}") from e
+ return file_name
+
+ def _to_avro_record(self, entry: IndexManifestEntry) -> dict:
+ index_file = entry.index_file
+ dv_ranges = None
+ if index_file.dv_ranges:
+ dv_ranges = [
+ {"f0": dv.data_file_name, "f1": dv.offset, "f2": dv.length,
+ "_CARDINALITY": dv.cardinality}
+ for dv in index_file.dv_ranges.values()
+ ]
+ global_index = None
+ if index_file.global_index_meta is not None:
+ gim = index_file.global_index_meta
+ global_index = {
+ "_ROW_RANGE_START": gim.row_range_start,
+ "_ROW_RANGE_END": gim.row_range_end,
+ "_INDEX_FIELD_ID": gim.index_field_id,
+ "_EXTRA_FIELD_IDS": gim.extra_field_ids,
+ "_INDEX_META": gim.index_meta,
+ }
+ return {
+ "_VERSION": _INDEX_ENTRY_VERSION,
+ "_KIND": entry.kind,
+ "_PARTITION": GenericRowSerializer.to_bytes(entry.partition),
+ "_BUCKET": entry.bucket,
+ "_INDEX_TYPE": index_file.index_type,
+ "_FILE_NAME": index_file.file_name,
+ "_FILE_SIZE": index_file.file_size,
+ "_ROW_COUNT": index_file.row_count,
+ "_DELETIONS_VECTORS_RANGES": dv_ranges,
+ "_EXTERNAL_PATH": index_file.external_path,
+ "_GLOBAL_INDEX": global_index,
+ }
diff --git a/paimon-python/pypaimon/ray/__init__.py b/paimon-python/pypaimon/ray/__init__.py
index f36eb0253dd8..9161f3cbb3b7 100644
--- a/paimon-python/pypaimon/ray/__init__.py
+++ b/paimon-python/pypaimon/ray/__init__.py
@@ -16,5 +16,16 @@
# under the License.
from pypaimon.ray.ray_paimon import read_paimon, write_paimon
+from pypaimon.ray.data_evolution_merge_into import (
+ WhenMatched,
+ WhenNotMatched,
+ merge_into,
+)
-__all__ = ["read_paimon", "write_paimon"]
+__all__ = [
+ "read_paimon",
+ "write_paimon",
+ "merge_into",
+ "WhenMatched",
+ "WhenNotMatched",
+]
diff --git a/paimon-python/pypaimon/ray/condition_expr.py b/paimon-python/pypaimon/ray/condition_expr.py
new file mode 100644
index 000000000000..c20c3c111fb8
--- /dev/null
+++ b/paimon-python/pypaimon/ray/condition_expr.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ast
+import operator
+from typing import Mapping, Set
+
+_PREFIXES = ("s", "t")
+
+_COMPARE_OPS = {
+ ast.Lt: operator.lt,
+ ast.LtE: operator.le,
+ ast.Gt: operator.gt,
+ ast.GtE: operator.ge,
+}
+
+_ALLOWED_NODES = (
+ ast.BoolOp, ast.And, ast.Or,
+ ast.UnaryOp, ast.Not, ast.USub, ast.UAdd,
+ ast.Compare, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE,
+ ast.Constant, ast.Attribute, ast.Name, ast.Load,
+)
+
+
+class ConditionExpr:
+ """A parsed merge condition over the joined row.
+
+ Columns are referenced as ``s.col`` / ``t.col``; the evaluator reads them
+ from a combined ``{"s.col": ..., "t.col": ...}`` mapping. Only comparisons,
+ boolean and/or/not, and literals are supported, so the expression is safe to
+ evaluate (no ``eval``) and its referenced columns can be extracted statically.
+ """
+
+ def __init__(self, source: str, body: ast.AST):
+ self.source = source
+ self._body = body
+
+ def eval(self, combined: Mapping) -> bool:
+ return bool(_eval(self._body, combined))
+
+ def target_columns(self) -> Set[str]:
+ return self._columns("t")
+
+ def source_columns(self) -> Set[str]:
+ return self._columns("s")
+
+ def _columns(self, prefix: str) -> Set[str]:
+ cols = set()
+ for node in ast.walk(self._body):
+ if (isinstance(node, ast.Attribute)
+ and isinstance(node.value, ast.Name)
+ and node.value.id == prefix):
+ cols.add(node.attr)
+ return cols
+
+
+def parse(source: str) -> ConditionExpr:
+ try:
+ tree = ast.parse(source, mode="eval")
+ except SyntaxError as e:
+ raise ValueError(f"Invalid merge condition {source!r}: {e}")
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Expression):
+ continue
+ if not isinstance(node, _ALLOWED_NODES):
+ raise ValueError(
+ f"Unsupported syntax in merge condition {source!r}: "
+ f"{type(node).__name__}. Only comparisons of s./t. columns and "
+ f"literals combined with and/or/not are allowed."
+ )
+ if isinstance(node, ast.Attribute):
+ if not (isinstance(node.value, ast.Name) and node.value.id in _PREFIXES):
+ raise ValueError(
+ f"Column reference in merge condition {source!r} must be "
+ f"'s.' or 't.'."
+ )
+ if isinstance(node, ast.Name) and node.id not in _PREFIXES:
+ raise ValueError(
+ f"Unknown name {node.id!r} in merge condition {source!r}; "
+ f"only 's' and 't' are allowed."
+ )
+ return ConditionExpr(source, tree.body)
+
+
+def _eval(node, combined):
+ if isinstance(node, ast.BoolOp):
+ values = (_eval(v, combined) for v in node.values)
+ if isinstance(node.op, ast.And):
+ return all(values)
+ return any(values)
+ if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
+ return not _eval(node.operand, combined)
+ if isinstance(node, ast.Compare):
+ left = _operand(node.left, combined)
+ ok = True
+ for op, comparator in zip(node.ops, node.comparators):
+ right = _operand(comparator, combined)
+ ok = ok and _apply(op, left, right)
+ left = right
+ return ok
+ return _operand(node, combined)
+
+
+def _operand(node, combined):
+ if isinstance(node, ast.Constant):
+ return node.value
+ if isinstance(node, ast.Attribute):
+ return combined.get(f"{node.value.id}.{node.attr}")
+ if isinstance(node, ast.UnaryOp):
+ value = _operand(node.operand, combined)
+ return -value if isinstance(node.op, ast.USub) else value
+ return _eval(node, combined)
+
+
+def _apply(op, left, right):
+ if isinstance(op, ast.Eq):
+ return left == right
+ if isinstance(op, ast.NotEq):
+ return left != right
+ if left is None or right is None:
+ return False
+ return _COMPARE_OPS[type(op)](left, right)
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
new file mode 100644
index 000000000000..2876a0a5c838
--- /dev/null
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
@@ -0,0 +1,1056 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""MERGE INTO ... USING ... for Paimon data-evolution tables via Ray Datasets."""
+
+from dataclasses import dataclass
+from typing import (
+ Any,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
+
+import pyarrow as pa
+
+from pypaimon.ray.condition_expr import ConditionExpr
+from pypaimon.ray.condition_expr import parse as parse_condition
+
+SetSpec = Union[str, Dict[str, Any]]
+Condition = str
+OnSpec = Union[Sequence[str], Mapping[str, str]]
+
+
+@dataclass
+class WhenMatched:
+ update: SetSpec
+ condition: Optional[Condition] = None
+
+
+@dataclass
+class WhenNotMatched:
+ insert: SetSpec
+ condition: Optional[Condition] = None
+
+
+@dataclass
+class _NormalizedClause:
+ spec: Dict[str, Any]
+ condition: Optional[ConditionExpr]
+
+
+def merge_into(
+ target: str,
+ source: Any,
+ catalog_options: Dict[str, str],
+ *,
+ on: OnSpec,
+ when_matched: Sequence[WhenMatched] = (),
+ when_not_matched: Sequence[WhenNotMatched] = (),
+ num_partitions: Optional[int] = None,
+ ray_remote_args: Optional[Dict[str, Any]] = None,
+ concurrency: Optional[int] = None,
+ allow_multiple_matches: bool = False,
+) -> Dict[str, int]:
+ _require_ray_join()
+ num_partitions = _resolve_num_partitions(num_partitions)
+ when_matched = list(when_matched)
+ when_not_matched = list(when_not_matched)
+ if not when_matched and not when_not_matched:
+ raise ValueError(
+ "At least one of when_matched or when_not_matched must be non-empty."
+ )
+
+ target_on_cols, source_on_cols = _normalize_on(on)
+
+ from pypaimon.catalog.catalog_factory import CatalogFactory
+
+ catalog = CatalogFactory.create(catalog_options)
+ table = catalog.get_table(target)
+ if not table.options.data_evolution_enabled():
+ raise ValueError(
+ f"merge_into requires 'data-evolution.enabled' = 'true' on '{target}'."
+ )
+ if not table.options.row_tracking_enabled():
+ raise ValueError(
+ f"merge_into requires 'row-tracking.enabled' = 'true' on '{target}'."
+ )
+
+ target_field_names = list(table.field_names)
+ on_map = dict(zip(target_on_cols, source_on_cols))
+ matched_specs = [
+ _NormalizedClause(
+ spec=_normalize_set_spec(c.update, target_field_names, on_map),
+ condition=parse_condition(c.condition) if c.condition else None,
+ )
+ for c in when_matched
+ ]
+ not_matched_specs = [
+ _NormalizedClause(
+ spec=_normalize_set_spec(c.insert, target_field_names, on_map),
+ condition=parse_condition(c.condition) if c.condition else None,
+ )
+ for c in when_not_matched
+ ]
+
+ update_cols: set = set()
+ for clause in matched_specs:
+ update_cols.update(clause.spec.keys())
+ _reject_blob_updates(table, update_cols)
+
+ source_ds = _normalize_source(source, catalog_options)
+ _validate_source_on_cols(source_ds, source_on_cols)
+
+ base_snapshot = table.snapshot_manager().get_latest_snapshot()
+
+ global_index_action = (
+ table.options.global_index_column_update_action()
+ or GLOBAL_INDEX_ACTION_THROW_ERROR
+ )
+
+ from pypaimon.schema.data_types import PyarrowFieldParser
+
+ target_pa_schema = PyarrowFieldParser.from_paimon_schema(
+ table.table_schema.fields
+ )
+
+ update_ds = None
+ insert_ds = None
+ update_cols_union: List[str] = []
+
+ # With both clauses on a non-empty target, matched and not-matched routing
+ # share the same source/target equi-join. Build them from one materialized
+ # LEFT_OUTER join instead of reading and shuffling the target table twice.
+ if matched_specs and not_matched_specs and base_snapshot is not None:
+ update_cols_union = _union_update_cols(matched_specs)
+ update_ds, insert_ds = _build_unified_both(
+ target_identifier=target,
+ source_ds=source_ds,
+ target_on=target_on_cols,
+ source_on=source_on_cols,
+ matched_clauses=matched_specs,
+ not_matched_clauses=not_matched_specs,
+ target_field_names=target_field_names,
+ target_pa_schema=target_pa_schema,
+ update_cols=update_cols_union,
+ catalog_options=catalog_options,
+ num_partitions=num_partitions,
+ )
+ else:
+ # Empty target → no rows can match; matched UPDATE is a no-op.
+ if matched_specs and base_snapshot is not None:
+ update_cols_union = _union_update_cols(matched_specs)
+ update_ds = _build_matched_update_ds(
+ target_identifier=target,
+ source_ds=source_ds,
+ target_on=target_on_cols,
+ source_on=source_on_cols,
+ clauses=matched_specs,
+ target_field_names=target_field_names,
+ target_pa_schema=target_pa_schema,
+ update_cols=update_cols_union,
+ catalog_options=catalog_options,
+ num_partitions=num_partitions,
+ )
+
+ if not_matched_specs:
+ # Empty target: nothing can match, so every source row inserts.
+ # Skip all joins (ray's hash join crashes on empty partitions).
+ insert_ds = _build_not_matched_insert_ds(
+ target_identifier=target,
+ source_ds=source_ds,
+ target_on=target_on_cols,
+ source_on=source_on_cols,
+ clauses=not_matched_specs,
+ target_field_names=target_field_names,
+ target_pa_schema=target_pa_schema,
+ catalog_options=catalog_options,
+ num_partitions=num_partitions,
+ target_empty=base_snapshot is None,
+ )
+
+ update_msgs: list = []
+ num_updated = 0
+ if update_ds is not None:
+ update_msgs, num_updated = _distributed_update_apply(
+ update_ds,
+ table,
+ update_cols_union,
+ num_partitions=num_partitions,
+ ray_remote_args=ray_remote_args,
+ allow_multiple_matches=allow_multiple_matches,
+ )
+
+ all_msgs: list = list(update_msgs)
+ num_inserted = 0
+ if insert_ds is not None:
+ insert_msgs = _distributed_write_collect_msgs(
+ insert_ds, table, ray_remote_args=ray_remote_args, concurrency=concurrency
+ )
+ num_inserted = sum(f.row_count for m in insert_msgs for f in m.new_files)
+ all_msgs.extend(insert_msgs)
+ # Mirror Spark's checkUpdateResult: scope the global-index action to the
+ # partitions the update actually wrote and the updated indexed columns.
+ all_msgs.extend(
+ _apply_global_index_update_action(
+ table, base_snapshot, update_cols_union, update_msgs, global_index_action
+ )
+ )
+ if all_msgs:
+ wb = table.new_batch_write_builder()
+ tc = wb.new_commit()
+ tc.commit(all_msgs)
+ tc.close()
+
+ return {"num_updated": num_updated, "num_inserted": num_inserted}
+
+
+def _normalize_on(on: OnSpec) -> Tuple[List[str], List[str]]:
+ if isinstance(on, Mapping):
+ target_cols = list(on.keys())
+ source_cols = list(on.values())
+ else:
+ target_cols = list(on)
+ source_cols = list(on)
+ if not target_cols:
+ raise ValueError("'on' must be non-empty.")
+ return target_cols, source_cols
+
+
+def _build_matched_update_ds(
+ *,
+ target_identifier: str,
+ source_ds,
+ target_on: Sequence[str],
+ source_on: Sequence[str],
+ clauses: List[_NormalizedClause],
+ target_field_names: Sequence[str],
+ target_pa_schema: pa.Schema,
+ update_cols: Sequence[str],
+ catalog_options: Dict[str, str],
+ num_partitions: int,
+):
+ from pypaimon.ray.ray_paimon import read_paimon
+ from pypaimon.table.special_fields import SpecialFields
+
+ row_id_name = SpecialFields.ROW_ID.name
+ needed_cols = _resolve_target_projection(
+ clauses, target_on, update_cols, target_field_names,
+ )
+ projection = [row_id_name] + [c for c in needed_cols if c != row_id_name]
+
+ target_ds = read_paimon(target_identifier, catalog_options, projection=projection)
+ update_schema = _build_update_schema(target_pa_schema, update_cols, row_id_name)
+
+ target_renamed = target_ds.rename_columns(
+ {c: f"t.{c}" for c in target_ds.schema().names}
+ )
+ source_schema = source_ds.schema()
+ source_cols = list(source_schema.names) if source_schema is not None else list(source_on)
+ source_renamed = source_ds.rename_columns({c: f"s.{c}" for c in source_cols})
+
+ joined = target_renamed.join(
+ source_renamed,
+ join_type="inner",
+ num_partitions=num_partitions,
+ on=tuple(f"t.{c}" for c in target_on),
+ right_on=tuple(f"s.{c}" for c in source_on),
+ )
+
+ captured_clauses = clauses
+ captured_update_cols = list(update_cols)
+ captured_field_names = list(target_field_names)
+ captured_row_id_name = row_id_name
+ captured_on_pairs = list(zip(source_on, target_on))
+ captured_schema = update_schema
+
+ if _clauses_use_vector_fast_path(clauses):
+ first_spec = clauses[0].spec
+
+ def _fast(batch: pa.Table) -> pa.Table:
+ return _vectorized_matched_transform(
+ batch,
+ first_spec,
+ captured_on_pairs,
+ captured_update_cols,
+ captured_row_id_name,
+ captured_schema,
+ )
+
+ return joined.map_batches(_fast, batch_format="pyarrow")
+
+ def _transform(batch: pa.Table) -> pa.Table:
+ return _apply_matched_transform(
+ batch,
+ captured_clauses,
+ captured_on_pairs,
+ captured_update_cols,
+ captured_field_names,
+ captured_row_id_name,
+ captured_schema,
+ )
+
+ return joined.map_batches(_transform, batch_format="pyarrow")
+
+
+def _apply_matched_transform(
+ batch: pa.Table,
+ clauses: List[_NormalizedClause],
+ on_pairs: Sequence[Tuple[str, str]],
+ update_cols: Sequence[str],
+ field_names: Sequence[str],
+ row_id_name: str,
+ update_schema: pa.Schema,
+) -> pa.Table:
+ rows = batch.to_pylist()
+ out_row_ids: list = []
+ out_cols: Dict[str, list] = {c: [] for c in update_cols}
+ for row in rows:
+ s_row = {k[2:]: v for k, v in row.items() if k.startswith("s.")}
+ t_row = {k[2:]: v for k, v in row.items() if k.startswith("t.")}
+ for s_key, t_key in on_pairs:
+ if s_key not in s_row and t_key in t_row:
+ s_row[s_key] = t_row[t_key]
+ combined = _prefixed(s_row, t_row)
+ for clause in clauses:
+ if clause.condition is not None and not clause.condition.eval(combined):
+ continue
+ new_values = _apply_set(clause.spec, s_row, t_row, field_names)
+ out_row_ids.append(t_row[row_id_name])
+ for col in update_cols:
+ out_cols[col].append(new_values.get(col, t_row.get(col)))
+ break
+ return pa.Table.from_pydict(
+ {row_id_name: out_row_ids, **out_cols},
+ schema=update_schema,
+ )
+
+
+def _build_update_schema(
+ target_pa_schema: pa.Schema,
+ update_cols: Sequence[str],
+ row_id_name: str,
+) -> pa.Schema:
+ return pa.schema(
+ [pa.field(row_id_name, pa.int64(), nullable=False)]
+ + [target_pa_schema.field(col) for col in update_cols]
+ )
+
+
+def _distributed_update_apply(
+ update_ds,
+ table,
+ write_update_cols: Sequence[str],
+ *,
+ num_partitions: int,
+ ray_remote_args: Optional[Dict[str, Any]] = None,
+ allow_multiple_matches: bool = False,
+) -> Tuple[list, int]:
+ import numpy as np
+ import pickle
+ import uuid
+
+ from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
+ from pypaimon.table.special_fields import SpecialFields
+ from pypaimon.write.table_update_by_row_id import TableUpdateByRowId
+
+ row_id_name = SpecialFields.ROW_ID.name
+ cols = list(write_update_cols)
+
+ for col in cols:
+ if col not in table.field_names:
+ raise ValueError(f"Column '{col}' is not in target table schema.")
+
+ planner = TableUpdateByRowId(
+ table,
+ "_merge_into_planner_" + uuid.uuid4().hex[:8],
+ BATCH_COMMIT_IDENTIFIER,
+ )
+ sorted_first_row_ids = list(planner.first_row_ids)
+ if not sorted_first_row_ids:
+ return [], 0
+
+ # Broadcast the file-info snapshot to every worker so they skip the
+ # per-task manifest scan and observe a single consistent target view.
+ precomputed_info = (
+ planner.snapshot_id,
+ planner.first_row_ids,
+ planner._first_row_id_index,
+ planner.total_row_count,
+ )
+
+ frid_col = "_FIRST_ROW_ID"
+ captured_sorted = sorted_first_row_ids
+ captured_sorted_arr = np.asarray(captured_sorted, dtype=np.int64)
+ first = captured_sorted_arr[0]
+ captured_precomputed = precomputed_info
+ total_row_count = planner.total_row_count
+
+ def _assign_frid(batch: pa.Table) -> pa.Table:
+ if batch.num_rows == 0:
+ return batch.append_column(frid_col, pa.array([], type=pa.int64()))
+ rid_col = batch.column(row_id_name)
+ if rid_col.null_count:
+ raise ValueError(
+ "_ROW_ID is null; planner snapshot is stale "
+ "or matched rows come from a different table."
+ )
+ rids = rid_col.to_numpy(zero_copy_only=False)
+ # Out-of-range _ROW_IDs would silently map via searchsorted wrap-around.
+ out_of_range = (rids < first) | (rids >= total_row_count)
+ if out_of_range.any():
+ bad = rids[out_of_range][0]
+ raise ValueError(
+ f"_ROW_ID {bad} is out of valid range "
+ f"[{first}, {total_row_count}); planner snapshot is stale "
+ f"or matched rows come from a different table."
+ )
+ idx = np.searchsorted(captured_sorted_arr, rids, side="right") - 1
+ frids = captured_sorted_arr[idx]
+ return batch.append_column(frid_col, pa.array(frids, type=pa.int64()))
+
+ with_frid = update_ds.map_batches(_assign_frid, batch_format="pyarrow")
+
+ captured_table = table
+ captured_cols = cols
+
+ def _apply_group(group: pa.Table) -> pa.Table:
+ if group.num_rows == 0:
+ return pa.Table.from_pydict({
+ "msgs_blob": pa.array([], type=pa.binary()),
+ "n_updated": pa.array([], type=pa.int64()),
+ })
+
+ # One target _ROW_ID matched by several source rows. Default: refuse
+ # (the winning value is otherwise undefined, as in Spark DE's
+ # checkCardinality=false). Opt-in keeps the first match deterministically.
+ group_row_ids = group.column(row_id_name).to_pylist()
+ if len(set(group_row_ids)) != len(group_row_ids):
+ if not allow_multiple_matches:
+ raise ValueError(
+ "MERGE matched multiple source rows to the same target "
+ "_ROW_ID. Deduplicate the source, or pass "
+ "allow_multiple_matches=True to keep the first match."
+ )
+ seen: set = set()
+ keep_indices: list = []
+ for i, rid in enumerate(group_row_ids):
+ if rid not in seen:
+ seen.add(rid)
+ keep_indices.append(i)
+ group = group.take(pa.array(keep_indices, type=pa.int64()))
+
+ for_update = group.drop_columns([frid_col])
+ worker = TableUpdateByRowId(
+ captured_table,
+ "_merge_into_shard_" + uuid.uuid4().hex[:8],
+ BATCH_COMMIT_IDENTIFIER,
+ precomputed_files_info=captured_precomputed,
+ )
+ msgs = worker.update_columns(for_update, list(captured_cols))
+ return pa.Table.from_pydict({
+ "msgs_blob": [pickle.dumps(msgs)],
+ "n_updated": pa.array([for_update.num_rows], type=pa.int64()),
+ })
+
+ # One group per target data file (distinct _FIRST_ROW_ID). Drive the write
+ # shuffle with the same num_partitions knob as the join (Spark's single
+ # shuffle.partitions), bounded by the file count so small merges don't spawn
+ # empty reduce tasks and large ones scale past a fixed cap.
+ group_partitions = max(1, min(len(captured_sorted), num_partitions))
+ msgs_ds = with_frid.groupby(frid_col, num_partitions=group_partitions).map_groups(
+ _apply_group, batch_format="pyarrow"
+ )
+
+ all_msgs: list = []
+ num_updated = 0
+ for batch in msgs_ds.iter_batches(batch_format="pyarrow"):
+ for blob in batch.column("msgs_blob").to_pylist():
+ all_msgs.extend(pickle.loads(blob))
+ for n in batch.column("n_updated").to_pylist():
+ num_updated += n
+ return all_msgs, num_updated
+
+
+GLOBAL_INDEX_ACTION_THROW_ERROR = "THROW_ERROR"
+GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX = "DROP_PARTITION_INDEX"
+
+
+def _resolve_num_partitions(num_partitions: Optional[int]) -> int:
+ if num_partitions is not None:
+ return num_partitions
+ try:
+ import ray
+
+ cpus = ray.cluster_resources().get("CPU", 16)
+ return max(16, int(cpus) * 2)
+ except Exception:
+ return 16
+
+
+def _clauses_use_vector_fast_path(
+ clauses: List[_NormalizedClause],
+) -> bool:
+ if not clauses:
+ return False
+ for c in clauses:
+ if c.condition is not None:
+ return False
+ for v in c.spec.values():
+ if callable(v):
+ return False
+ return True
+
+
+def _vectorized_matched_transform(
+ batch: pa.Table,
+ spec: Dict[str, Any],
+ on_pairs: Sequence[Tuple[str, str]],
+ update_cols: Sequence[str],
+ row_id_name: str,
+ update_schema: pa.Schema,
+) -> pa.Table:
+ available = set(batch.schema.names)
+ arrays: list = [batch.column(f"t.{row_id_name}")]
+ for col in update_cols:
+ out_type = update_schema.field(col).type
+ if col in spec:
+ arrays.append(_resolve_spec_array(spec[col], batch, available, on_pairs, out_type))
+ else:
+ arrays.append(batch.column(f"t.{col}"))
+ return pa.Table.from_arrays(arrays, schema=update_schema)
+
+
+def _vectorized_insert_transform(
+ batch: pa.Table,
+ spec: Dict[str, Any],
+ target_field_names: Sequence[str],
+ target_pa_schema: pa.Schema,
+) -> pa.Table:
+ available = set(batch.schema.names)
+ arrays: list = []
+ for col in target_field_names:
+ out_type = target_pa_schema.field(col).type
+ if col in spec:
+ arrays.append(_resolve_spec_array(spec[col], batch, available, (), out_type))
+ else:
+ arrays.append(pa.nulls(batch.num_rows, type=out_type))
+ return pa.Table.from_arrays(arrays, schema=target_pa_schema)
+
+
+def _resolve_spec_array(
+ val: Any,
+ batch: pa.Table,
+ available: set,
+ on_pairs: Sequence[Tuple[str, str]],
+ out_type: pa.DataType,
+):
+ if isinstance(val, str) and val.startswith("s."):
+ ref = val[2:]
+ if f"s.{ref}" in available:
+ return batch.column(f"s.{ref}")
+ # Equi-join drops the right-side join key; fall back to target's value.
+ for sk, tk in on_pairs:
+ if sk == ref and f"t.{tk}" in available:
+ return batch.column(f"t.{tk}")
+ return pa.nulls(batch.num_rows, type=out_type)
+ if isinstance(val, str) and val.startswith("t."):
+ ref = val[2:]
+ col_name = f"t.{ref}"
+ return batch.column(col_name) if col_name in available else pa.nulls(
+ batch.num_rows, type=out_type
+ )
+ return pa.array([val] * batch.num_rows, type=out_type)
+
+
+def _build_not_matched_insert_ds(
+ *,
+ target_identifier: str,
+ source_ds,
+ target_on: Sequence[str],
+ source_on: Sequence[str],
+ clauses: List[_NormalizedClause],
+ target_field_names: Sequence[str],
+ target_pa_schema: pa.Schema,
+ catalog_options: Dict[str, str],
+ num_partitions: int,
+ target_empty: bool = False,
+):
+ from pypaimon.ray.ray_paimon import read_paimon
+ from pypaimon.ray.shuffle import _coerce_large_string_types
+
+ captured_clauses = clauses
+ captured_field_names = list(target_field_names)
+ out_schema = target_pa_schema
+
+ source_schema = source_ds.schema()
+ source_cols = list(source_schema.names) if source_schema is not None else list(source_on)
+ source_renamed = source_ds.rename_columns({c: f"s.{c}" for c in source_cols})
+
+ if target_empty:
+ unmatched = source_renamed
+ else:
+ target_ds = read_paimon(
+ target_identifier, catalog_options, projection=list(target_on)
+ )
+ target_renamed = target_ds.rename_columns(
+ {c: f"t.{c}" for c in target_on}
+ )
+ unmatched = source_renamed.join(
+ target_renamed,
+ join_type="left_anti",
+ num_partitions=num_partitions,
+ on=tuple(f"s.{c}" for c in source_on),
+ right_on=tuple(f"t.{c}" for c in target_on),
+ )
+
+ if _clauses_use_vector_fast_path(clauses):
+ first_spec = clauses[0].spec
+
+ def _fast(batch: pa.Table) -> pa.Table:
+ return _coerce_large_string_types(
+ _vectorized_insert_transform(
+ batch, first_spec, captured_field_names, out_schema
+ )
+ )
+
+ return unmatched.map_batches(_fast, batch_format="pyarrow")
+
+ def _transform(batch: pa.Table) -> pa.Table:
+ return _apply_insert_transform(
+ batch, captured_clauses, captured_field_names, out_schema
+ )
+
+ return unmatched.map_batches(_transform, batch_format="pyarrow")
+
+
+def _apply_insert_transform(
+ batch: pa.Table,
+ clauses: List[_NormalizedClause],
+ field_names: Sequence[str],
+ out_schema: pa.Schema,
+) -> pa.Table:
+ from pypaimon.ray.shuffle import _coerce_large_string_types
+
+ rows = batch.to_pylist()
+ out = []
+ for row in rows:
+ s_row = {k[2:]: v for k, v in row.items() if k.startswith("s.")}
+ combined = _prefixed(s_row, None)
+ for clause in clauses:
+ if clause.condition is not None and not clause.condition.eval(combined):
+ continue
+ out.append(
+ _apply_set(
+ clause.spec, s_row, None, field_names, null_unspecified=True
+ )
+ )
+ break
+ aligned = [{name: r.get(name) for name in field_names} for r in out]
+ return _coerce_large_string_types(pa.Table.from_pylist(aligned, schema=out_schema))
+
+
+def _build_unified_both(
+ *,
+ target_identifier: str,
+ source_ds,
+ target_on: Sequence[str],
+ source_on: Sequence[str],
+ matched_clauses: List[_NormalizedClause],
+ not_matched_clauses: List[_NormalizedClause],
+ target_field_names: Sequence[str],
+ target_pa_schema: pa.Schema,
+ update_cols: Sequence[str],
+ catalog_options: Dict[str, str],
+ num_partitions: int,
+):
+ import pyarrow.compute as pc
+
+ from pypaimon.ray.ray_paimon import read_paimon
+ from pypaimon.ray.shuffle import _coerce_large_string_types
+ from pypaimon.table.special_fields import SpecialFields
+
+ row_id_name = SpecialFields.ROW_ID.name
+
+ needed_cols = _resolve_target_projection(
+ matched_clauses, target_on, update_cols, target_field_names,
+ )
+ projection = [row_id_name] + [c for c in needed_cols if c != row_id_name]
+ target_ds = read_paimon(target_identifier, catalog_options, projection=projection)
+ target_renamed = target_ds.rename_columns(
+ {c: f"t.{c}" for c in target_ds.schema().names}
+ )
+ source_schema = source_ds.schema()
+ source_cols = list(source_schema.names) if source_schema is not None else list(source_on)
+ source_renamed = source_ds.rename_columns({c: f"s.{c}" for c in source_cols})
+
+ # One LEFT_OUTER join feeds both branches: rows with a non-null target side
+ # are matched (UPDATE), null target side means no key match (INSERT). The
+ # join shuffle is the dominant cost, so materialize once and route both ways
+ # instead of reading and shuffling the target table twice.
+ joined = source_renamed.join(
+ target_renamed,
+ join_type="left_outer",
+ num_partitions=num_partitions,
+ on=tuple(f"s.{c}" for c in source_on),
+ right_on=tuple(f"t.{c}" for c in target_on),
+ ).materialize()
+
+ t_row_id_col = f"t.{row_id_name}"
+ on_pairs = list(zip(source_on, target_on))
+ update_schema = _build_update_schema(target_pa_schema, update_cols, row_id_name)
+
+ use_fast_matched = _clauses_use_vector_fast_path(matched_clauses)
+ first_matched_spec = matched_clauses[0].spec if use_fast_matched else None
+ m_update_cols = list(update_cols)
+ m_field_names = list(target_field_names)
+
+ def _matched_batch(batch: pa.Table) -> pa.Table:
+ sub = batch.filter(pc.is_valid(batch.column(t_row_id_col)))
+ if use_fast_matched:
+ return _vectorized_matched_transform(
+ sub, first_matched_spec, on_pairs, m_update_cols,
+ row_id_name, update_schema,
+ )
+ return _apply_matched_transform(
+ sub, matched_clauses, on_pairs, m_update_cols,
+ m_field_names, row_id_name, update_schema,
+ )
+
+ update_ds = joined.map_batches(_matched_batch, batch_format="pyarrow")
+
+ i_field_names = list(target_field_names)
+ use_fast_insert = _clauses_use_vector_fast_path(not_matched_clauses)
+ first_insert_spec = not_matched_clauses[0].spec if use_fast_insert else None
+
+ def _insert_batch(batch: pa.Table) -> pa.Table:
+ sub = batch.filter(pc.is_null(batch.column(t_row_id_col)))
+ if use_fast_insert:
+ return _coerce_large_string_types(
+ _vectorized_insert_transform(
+ sub, first_insert_spec, i_field_names, target_pa_schema
+ )
+ )
+ return _apply_insert_transform(
+ sub, not_matched_clauses, i_field_names, target_pa_schema
+ )
+
+ insert_ds = joined.map_batches(_insert_batch, batch_format="pyarrow")
+
+ return update_ds, insert_ds
+
+
+def _distributed_write_collect_msgs(
+ insert_ds,
+ table,
+ *,
+ ray_remote_args: Optional[Dict[str, Any]],
+ concurrency: Optional[int],
+) -> list:
+ from pypaimon.write.ray_datasink import PaimonDatasink
+
+ class _CollectingDatasink(PaimonDatasink):
+ def __init__(self, t):
+ super().__init__(t, overwrite=False)
+ self.collected: list = []
+
+ def on_write_complete(self, write_result):
+ if hasattr(write_result, "write_returns"):
+ write_returns = write_result.write_returns
+ elif isinstance(write_result, list):
+ write_returns = write_result
+ else:
+ raise TypeError(
+ f"Unexpected write_result type {type(write_result).__name__}"
+ )
+ self.collected = [
+ m
+ for batch in write_returns
+ for m in batch
+ if not m.is_empty()
+ ]
+
+ sink = _CollectingDatasink(table)
+ write_kwargs: Dict[str, Any] = {}
+ if ray_remote_args is not None:
+ write_kwargs["ray_remote_args"] = ray_remote_args
+ if concurrency is not None:
+ write_kwargs["concurrency"] = concurrency
+ insert_ds.write_datasink(sink, **write_kwargs)
+ return sink.collected
+
+
+def _apply_global_index_update_action(
+ table, snapshot, update_cols: Sequence[str], update_msgs, action: str
+) -> list:
+ """Handle updates touching globally-indexed columns, mirroring Spark's
+ ``checkUpdateResult``.
+
+ Scoped exactly like Spark: only index entries whose partition was written
+ by the update *and* whose indexed column is among the updated columns are
+ affected. THROW_ERROR (default) raises; DROP_PARTITION_INDEX drops those
+ entries (returned as index-delete commit messages, rebuild afterwards).
+ Like Spark, the INSERT path is left untouched.
+ """
+ if snapshot is None or not update_cols or not update_msgs:
+ return []
+ entries = _scan_global_index_entries(table, snapshot)
+ if not entries:
+ return []
+ field_by_id = {f.id: f.name for f in table.fields}
+ update_set = set(update_cols)
+ affected_partitions = {tuple(m.partition) for m in update_msgs}
+ affected = [
+ e for e in entries
+ if field_by_id.get(e.index_file.global_index_meta.index_field_id) in update_set
+ and tuple(e.partition.values) in affected_partitions
+ ]
+ if not affected:
+ return []
+ if action == GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX:
+ return _build_index_delete_msgs(affected)
+ conflicted = sorted(
+ {field_by_id.get(e.index_file.global_index_meta.index_field_id) for e in affected}
+ )
+ raise NotImplementedError(
+ f"MERGE INTO would update columns {conflicted} that have a global "
+ f"index; not supported (refusing to leave the index stale). Set "
+ f"'global-index.column-update-action' = 'DROP_PARTITION_INDEX' to drop "
+ f"the affected index instead."
+ )
+
+
+def _build_index_delete_msgs(entries) -> list:
+ """Group scanned index entries by partition into index-delete messages."""
+ from pypaimon.manifest.index_manifest_entry import IndexManifestEntry
+ from pypaimon.write.commit_message import CommitMessage
+
+ by_partition: Dict[tuple, list] = {}
+ for e in entries:
+ key = tuple(e.partition.values)
+ by_partition.setdefault(key, []).append(
+ IndexManifestEntry(
+ kind=1, partition=e.partition, bucket=e.bucket, index_file=e.index_file
+ )
+ )
+ return [
+ CommitMessage(partition=key, bucket=0, new_files=[], index_files=dels)
+ for key, dels in by_partition.items()
+ ]
+
+
+def _scan_global_index_entries(table, snapshot):
+ from pypaimon.index.index_file_handler import IndexFileHandler
+
+ handler = IndexFileHandler(table=table)
+ return handler.scan(
+ snapshot, lambda e: e.index_file.global_index_meta is not None
+ )
+
+
+def _require_ray_join() -> None:
+ """merge_into relies on ``Dataset.join`` (ray>=2.50). Read/sink users on
+ older ray are unaffected unless they call this, so check only here."""
+ import ray
+ from ray.data import Dataset
+
+ if not hasattr(Dataset, "join"):
+ raise RuntimeError(
+ f"merge_into requires ray>=2.50 (Dataset.join); "
+ f"installed ray is {ray.__version__}."
+ )
+
+
+def _reject_blob_updates(table, update_cols: set) -> None:
+ blob_cols = [
+ f.name
+ for f in table.table_schema.fields
+ if f.name in update_cols and getattr(f.type, "type", None) == "BLOB"
+ ]
+ if blob_cols:
+ raise NotImplementedError(
+ f"merge_into cannot update blob columns {blob_cols}; "
+ f"the row-id rewrite path skips .blob files."
+ )
+
+
+def _union_update_cols(clauses: List[_NormalizedClause]) -> List[str]:
+ seen: List[str] = []
+ seen_set: set = set()
+ for clause in clauses:
+ for col in clause.spec.keys():
+ if col not in seen_set:
+ seen.append(col)
+ seen_set.add(col)
+ return seen
+
+
+def _needed_target_cols(
+ clauses: List[_NormalizedClause],
+ on: Sequence[str],
+ update_cols: Sequence[str],
+ all_target_cols: Sequence[str],
+) -> list:
+ # Target needs only: join keys, t.col refs, and cols that may fall back
+ # (not set by every clause). Cols all clauses set from source aren't read.
+ needed = set(on)
+ set_by_all = set(update_cols)
+ for clause in clauses:
+ for value in clause.spec.values():
+ if callable(value):
+ return list(all_target_cols)
+ if isinstance(value, str) and value.startswith("t."):
+ needed.add(value[2:])
+ set_by_all &= set(clause.spec.keys())
+ needed |= set(update_cols) - set_by_all
+ return [c for c in all_target_cols if c in needed]
+
+
+def _resolve_target_projection(
+ clauses: List[_NormalizedClause],
+ target_on: Sequence[str],
+ update_cols: Sequence[str],
+ target_field_names: Sequence[str],
+) -> list:
+ # Precise: SET-side needs plus the target columns each parsed condition
+ # references. Anything not referenced (e.g. blob) is never read.
+ needed = set(_needed_target_cols(clauses, target_on, update_cols, target_field_names))
+ target_set = set(target_field_names)
+ for clause in clauses:
+ if clause.condition is not None:
+ needed |= clause.condition.target_columns() & target_set
+ return [c for c in target_field_names if c in needed]
+
+
+def _normalize_set_spec(
+ spec: SetSpec,
+ target_field_names: Sequence[str],
+ on_map: Optional[Mapping[str, str]] = None,
+) -> Dict[str, Any]:
+ on_map = on_map or {}
+ if isinstance(spec, str):
+ if spec != "*":
+ raise ValueError(
+ f"SET spec strings other than '*' are not supported; got {spec!r}."
+ )
+ # A renamed ON key resolves via the source's ON column, not its own name.
+ return {col: f"s.{on_map.get(col, col)}" for col in target_field_names}
+ if not isinstance(spec, dict):
+ raise ValueError(
+ f"SET spec must be '*' or a dict, got {type(spec).__name__}."
+ )
+ target_set = set(target_field_names)
+ for col in spec:
+ if col not in target_set:
+ raise ValueError(
+ f"SET key '{col}' is not a column of the target table "
+ f"(columns: {list(target_field_names)})."
+ )
+ return dict(spec)
+
+
+def _normalize_source(source: Any, catalog_options: Dict[str, str]):
+ import ray.data
+
+ if isinstance(source, ray.data.Dataset):
+ return source
+ if isinstance(source, str):
+ from pypaimon.ray.ray_paimon import read_paimon
+ return read_paimon(source, catalog_options)
+ if isinstance(source, pa.Table):
+ return ray.data.from_arrow(source)
+ try:
+ import pandas as pd
+ except ImportError:
+ pd = None
+ if pd is not None and isinstance(source, pd.DataFrame):
+ return ray.data.from_pandas(source)
+ raise TypeError(
+ "source must be a ray.data.Dataset, a Paimon table identifier string, "
+ f"a pyarrow.Table, or a pandas.DataFrame; got {type(source).__name__}."
+ )
+
+
+def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None:
+ schema = source_ds.schema()
+ if schema is None:
+ return
+ names = set(schema.names)
+ missing = [c for c in on if c not in names]
+ if missing:
+ raise ValueError(
+ f"'on' columns {missing} missing from source schema {list(names)}."
+ )
+
+
+def _apply_set(
+ spec: Dict[str, Any],
+ s_row: Optional[Dict[str, Any]],
+ t_row: Optional[Dict[str, Any]],
+ target_field_names: Sequence[str],
+ null_unspecified: bool = False,
+) -> Dict[str, Any]:
+ combined = _prefixed(s_row, t_row)
+ if t_row is not None:
+ base = t_row
+ elif s_row is not None and not null_unspecified:
+ base = s_row
+ else:
+ base = {}
+ out: Dict[str, Any] = {}
+ for col in target_field_names:
+ if col in spec:
+ out[col] = _eval_set_value(spec[col], combined, s_row, t_row)
+ elif col in base:
+ out[col] = base[col]
+ else:
+ out[col] = None
+ return out
+
+
+def _prefixed(
+ s_row: Optional[Dict[str, Any]], t_row: Optional[Dict[str, Any]]
+) -> Dict[str, Any]:
+ out: Dict[str, Any] = {}
+ if s_row is not None:
+ for k, v in s_row.items():
+ out[f"s.{k}"] = v
+ if t_row is not None:
+ for k, v in t_row.items():
+ out[f"t.{k}"] = v
+ return out
+
+
+def _eval_set_value(
+ value: Any,
+ combined: Mapping[str, Any],
+ s_row: Optional[Dict[str, Any]],
+ t_row: Optional[Dict[str, Any]],
+) -> Any:
+ if callable(value):
+ return value(combined)
+ if isinstance(value, str):
+ if value.startswith("s.") and s_row is not None:
+ return s_row.get(value[2:])
+ if value.startswith("t.") and t_row is not None:
+ return t_row.get(value[2:])
+ return value
diff --git a/paimon-python/pypaimon/tests/index_manifest_write_test.py b/paimon-python/pypaimon/tests/index_manifest_write_test.py
new file mode 100644
index 000000000000..7107fe2fa59e
--- /dev/null
+++ b/paimon-python/pypaimon/tests/index_manifest_write_test.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import shutil
+import tempfile
+import unittest
+import uuid
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+from pypaimon.globalindex.global_index_meta import GlobalIndexMeta
+from pypaimon.index.index_file_meta import IndexFileMeta
+from pypaimon.manifest.index_manifest_entry import IndexManifestEntry
+from pypaimon.manifest.index_manifest_file import IndexManifestFile
+from pypaimon.table.row.generic_row import GenericRow
+
+
+class IndexManifestWriteTest(unittest.TestCase):
+
+ pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('vec', pa.string()),
+ ])
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+ cls.catalog.create_database('default', True)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def _table(self):
+ name = f'default.idx_{uuid.uuid4().hex[:8]}'
+ s = Schema.from_pyarrow_schema(self.pa_schema)
+ self.catalog.create_table(name, s, False)
+ return self.catalog.get_table(name)
+
+ def _entry(self, file_name, field_id, meta=b'm'):
+ partition = GenericRow([], [])
+ index_file = IndexFileMeta(
+ index_type='BTREE',
+ file_name=file_name,
+ file_size=123,
+ row_count=10,
+ global_index_meta=GlobalIndexMeta(
+ row_range_start=0,
+ row_range_end=10,
+ index_field_id=field_id,
+ extra_field_ids=[field_id + 1],
+ index_meta=meta,
+ ),
+ )
+ return IndexManifestEntry(kind=0, partition=partition, bucket=0, index_file=index_file)
+
+ def test_write_read_roundtrip(self):
+ imf = IndexManifestFile(self._table())
+ name = imf.write([self._entry('idx-a', 1), self._entry('idx-b', 2)])
+ out = imf.read(name)
+ self.assertEqual(2, len(out))
+ by_name = {e.index_file.file_name: e for e in out}
+ a = by_name['idx-a']
+ self.assertEqual('BTREE', a.index_file.index_type)
+ self.assertEqual(123, a.index_file.file_size)
+ self.assertEqual(10, a.index_file.row_count)
+ self.assertEqual(0, a.kind)
+ gim = a.index_file.global_index_meta
+ self.assertEqual(1, gim.index_field_id)
+ self.assertEqual(0, gim.row_range_start)
+ self.assertEqual(10, gim.row_range_end)
+ self.assertEqual([2], gim.extra_field_ids)
+ self.assertEqual(b'm', bytes(gim.index_meta))
+
+ def test_combine_drops_named_files(self):
+ imf = IndexManifestFile(self._table())
+ previous = imf.write([self._entry('idx-a', 1), self._entry('idx-b', 2)])
+ deletes = [self._entry('idx-a', 1)]
+ new_name = imf.combine(previous, deletes)
+ self.assertNotEqual(previous, new_name)
+ survivors = {e.index_file.file_name for e in imf.read(new_name)}
+ self.assertEqual({'idx-b'}, survivors)
+
+ def test_combine_unknown_delete_is_noop_on_content(self):
+ imf = IndexManifestFile(self._table())
+ previous = imf.write([self._entry('idx-a', 1)])
+ new_name = imf.combine(previous, [self._entry('idx-zzz', 9)])
+ survivors = {e.index_file.file_name for e in imf.read(new_name)}
+ self.assertEqual({'idx-a'}, survivors)
+
+ def test_combine_empty_deletes_returns_previous(self):
+ imf = IndexManifestFile(self._table())
+ previous = imf.write([self._entry('idx-a', 1)])
+ self.assertEqual(previous, imf.combine(previous, []))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
new file mode 100644
index 000000000000..6479c921056f
--- /dev/null
+++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
@@ -0,0 +1,1043 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import shutil
+import tempfile
+import unittest
+import uuid
+
+import pyarrow as pa
+import ray
+
+from pypaimon import CatalogFactory, Schema
+from pypaimon.ray import WhenMatched, WhenNotMatched, merge_into
+
+
+class RayDataEvolutionMergeIntoTest(unittest.TestCase):
+
+ pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('name', pa.string()),
+ ('age', pa.int32()),
+ ])
+
+ de_options = {
+ 'row-tracking.enabled': 'true',
+ 'data-evolution.enabled': 'true',
+ }
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog_options = {'warehouse': cls.warehouse}
+ cls.catalog = CatalogFactory.create(cls.catalog_options)
+ cls.catalog.create_database('default', True)
+ if not ray.is_initialized():
+ ray.init(ignore_reinit_error=True, num_cpus=2)
+
+ @classmethod
+ def tearDownClass(cls):
+ try:
+ if ray.is_initialized():
+ ray.shutdown()
+ except Exception:
+ pass
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def _create_table(self, options=None):
+ opts = options if options is not None else self.de_options
+ name = f'default.tbl_{uuid.uuid4().hex[:8]}'
+ s = Schema.from_pyarrow_schema(self.pa_schema, options=opts)
+ self.catalog.create_table(name, s, False)
+ return name
+
+ def _source(self, ids=(1,)):
+ return pa.Table.from_pydict(
+ {
+ 'id': pa.array(list(ids), type=pa.int32()),
+ 'name': ['x'] * len(ids),
+ 'age': [10] * len(ids),
+ },
+ schema=self.pa_schema,
+ )
+
+ def _write(self, target, data):
+ table = self.catalog.get_table(target)
+ wb = table.new_batch_write_builder()
+ writer = wb.new_write()
+ writer.write_arrow(data)
+ wb.new_commit().commit(writer.prepare_commit())
+ writer.close()
+
+ def _read_sorted(self, target):
+ table = self.catalog.get_table(target)
+ rb = table.new_read_builder()
+ splits = rb.new_scan().plan().splits()
+ return rb.new_read().to_arrow(splits).sort_by('id').to_pydict()
+
+ def _snapshot_id(self, target):
+ table = self.catalog.get_table(target)
+ snap = table.snapshot_manager().get_latest_snapshot()
+ return snap.id if snap is not None else None
+
+ def test_no_clause_raises(self):
+ target = self._create_table()
+ with self.assertRaises(ValueError):
+ merge_into(
+ target=target,
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ )
+
+ def test_non_de_table_rejected(self):
+ target = self._create_table(options={'row-tracking.enabled': 'true'})
+ with self.assertRaises(ValueError) as ctx:
+ merge_into(
+ target=target,
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ )
+ self.assertIn('data-evolution.enabled', str(ctx.exception))
+
+ def test_no_row_tracking_rejected(self):
+ target = self._create_table(options={'data-evolution.enabled': 'true'})
+ with self.assertRaises(ValueError) as ctx:
+ merge_into(
+ target=target,
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ )
+ self.assertIn('row-tracking.enabled', str(ctx.exception))
+
+ def test_source_missing_on_col_raises(self):
+ target = self._create_table()
+ bad_source = pa.Table.from_pydict(
+ {'name': ['x'], 'age': [10]},
+ schema=pa.schema([('name', pa.string()), ('age', pa.int32())]),
+ )
+ with self.assertRaises(ValueError) as ctx:
+ merge_into(
+ target=target,
+ source=bad_source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ )
+ self.assertIn("'id'", str(ctx.exception))
+
+ def test_matched_update_star(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2, 3, 4], type=pa.int32()),
+ 'name': ['b2', 'c2', 'd'],
+ 'age': pa.array([22, 33, 40], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a', 'b2', 'c2'])
+ self.assertEqual(out['age'], [10, 22, 33])
+
+ def test_matched_update_dict(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2], type=pa.int32()),
+ 'name': ['ignored'],
+ 'age': pa.array([99], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update={'age': 's.age'})],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2])
+ self.assertEqual(out['name'], ['a', 'b'])
+ self.assertEqual(out['age'], [10, 99])
+
+ def test_matched_update_with_condition(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([5, 100, 50], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(
+ update={'age': 's.age'},
+ condition="s.age > t.age",
+ ),
+ ],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['age'], [10, 100, 50])
+
+ def test_invalid_condition_expression_raises(self):
+ with self.assertRaises(ValueError):
+ merge_into(
+ target=self._create_table(),
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update={'age': 's.age'},
+ condition="evil(1)")],
+ )
+
+ def test_matched_multiple_clauses_first_match_wins(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['s1', 's2', 's3'],
+ 'age': pa.array([5, 25, 100], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(
+ update={'age': 1},
+ condition="s.age < t.age",
+ ),
+ WhenMatched(update={'age': 999}),
+ ],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['age'], [1, 999, 999])
+
+ def test_matched_partial_clause_falls_back_to_target(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['old'],
+ 'age': pa.array([42], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['new'],
+ 'age': pa.array([99], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update={'name': 's.name'}),
+ WhenMatched(update={'age': 's.age'}),
+ ],
+ )
+ out = self._read_sorted(target)
+ self.assertEqual(out['name'], ['new'])
+ self.assertEqual(out['age'], [42])
+
+ def test_not_matched_insert_appends_unmatched(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2, 3, 4], type=pa.int32()),
+ 'name': ['b2', 'c2', 'd'],
+ 'age': pa.array([22, 33, 40], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[WhenNotMatched(insert='*')],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3, 4])
+ self.assertEqual(out['name'], ['a', 'b', 'c', 'd'])
+ self.assertEqual(out['age'], [10, 20, 30, 40])
+
+ def test_not_matched_insert_with_condition(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2, 3, 4], type=pa.int32()),
+ 'name': ['b', 'c', 'd'],
+ 'age': pa.array([5, 50, 100], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[
+ WhenNotMatched(
+ insert='*',
+ condition="s.age >= 50",
+ ),
+ ],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 3, 4])
+ self.assertEqual(out['age'], [10, 50, 100])
+
+ def test_not_matched_multiple_clauses_first_match_wins(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2, 3], type=pa.int32()),
+ 'name': ['b', 'c'],
+ 'age': pa.array([5, 99], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[
+ WhenNotMatched(
+ insert={'id': 's.id', 'name': 'small', 'age': 1},
+ condition="s.age < 10",
+ ),
+ WhenNotMatched(insert={'id': 's.id', 'name': 'big', 'age': 2}),
+ ],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a', 'small', 'big'])
+ self.assertEqual(out['age'], [10, 1, 2])
+
+ def test_combined_update_and_insert(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2, 3], type=pa.int32()),
+ 'name': ['b2', 'c'],
+ 'age': pa.array([22, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ metrics = merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ when_not_matched=[WhenNotMatched(insert='*')],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a', 'b2', 'c'])
+ self.assertEqual(out['age'], [10, 22, 30])
+ self.assertEqual(metrics, {'num_updated': 1, 'num_inserted': 1})
+
+ def test_combined_matched_clause_condition_no_merge_condition(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['n1', 'n2', 'n3'],
+ 'age': pa.array([100, 5, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(
+ update={'name': 's.name'},
+ condition="s.age > 50",
+ )
+ ],
+ when_not_matched=[WhenNotMatched(insert='*')],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ rows = sorted(zip(out['id'], out['name'], out['age']))
+ self.assertEqual(rows, [(1, 'n1', 10), (2, 'b', 20), (3, 'n3', 30)])
+
+ def test_on_with_renamed_columns(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source_schema = pa.schema([
+ ('uid', pa.int32()),
+ ('name', pa.string()),
+ ('age', pa.int32()),
+ ])
+ source = pa.Table.from_pydict(
+ {
+ 'uid': pa.array([2, 3], type=pa.int32()),
+ 'name': ['b2', 'c'],
+ 'age': pa.array([22, 30], type=pa.int32()),
+ },
+ schema=source_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on={'id': 'uid'},
+ when_matched=[WhenMatched(update={'age': 's.age'})],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2])
+ self.assertEqual(out['age'], [10, 22])
+
+ def test_on_with_renamed_columns_star(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source_schema = pa.schema([
+ ('uid', pa.int32()),
+ ('name', pa.string()),
+ ('age', pa.int32()),
+ ])
+ source = pa.Table.from_pydict(
+ {
+ 'uid': pa.array([2, 3], type=pa.int32()),
+ 'name': ['b2', 'c'],
+ 'age': pa.array([22, 30], type=pa.int32()),
+ },
+ schema=source_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on={'id': 'uid'},
+ when_matched=[WhenMatched(update='*')],
+ when_not_matched=[WhenNotMatched(insert='*')],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a', 'b2', 'c'])
+ self.assertEqual(out['age'], [10, 22, 30])
+
+ def test_insert_into_empty_target(self):
+ target = self._create_table()
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[WhenNotMatched(insert='*')],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a', 'b', 'c'])
+ self.assertEqual(out['age'], [10, 20, 30])
+
+ def test_insert_dict_fills_unspecified_with_null(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2], type=pa.int32()),
+ 'name': ['source-name'],
+ 'age': pa.array([99], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[WhenNotMatched(insert={'id': 's.id', 'age': 99})],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2])
+ self.assertEqual(out['name'], ['a', None])
+ self.assertEqual(out['age'], [10, 99])
+
+ def test_multi_source_match_raises_by_default(self):
+ # One target row matched by several source rows: the winning value is
+ # undefined (Spark DE's checkCardinality=false), so we refuse by default.
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 1], type=pa.int32()),
+ 'name': ['x', 'y'],
+ 'age': pa.array([100, 200], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ with self.assertRaises(Exception) as ctx:
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ )
+ self.assertIn("multiple source rows", str(ctx.exception))
+
+ def test_multi_source_match_allow_keeps_first(self):
+ # Opt-in: allow_multiple_matches keeps the first match deterministically.
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 1], type=pa.int32()),
+ 'name': ['x', 'y'],
+ 'age': pa.array([100, 200], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ allow_multiple_matches=True,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1])
+ # One source row wins; we don't pin which.
+ self.assertIn(out['name'][0], ['x', 'y'])
+ self.assertIn(out['age'][0], [100, 200])
+
+ def test_blob_update_is_rejected(self):
+ import types
+
+ from pypaimon.ray.data_evolution_merge_into import _reject_blob_updates
+ from pypaimon.schema.data_types import AtomicType, DataField
+
+ fake_table = types.SimpleNamespace(
+ table_schema=types.SimpleNamespace(
+ fields=[
+ DataField(0, 'id', AtomicType('INT')),
+ DataField(1, 'payload', AtomicType('BLOB')),
+ ]
+ )
+ )
+ with self.assertRaises(NotImplementedError):
+ _reject_blob_updates(fake_table, {'payload'})
+ _reject_blob_updates(fake_table, {'id'})
+
+ def test_combined_writes_single_snapshot(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+ before = self._snapshot_id(target)
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2, 3], type=pa.int32()),
+ 'name': ['b2', 'c'],
+ 'age': pa.array([22, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ when_not_matched=[WhenNotMatched(insert='*')],
+ )
+
+ after = self._snapshot_id(target)
+ self.assertEqual(after, before + 1)
+
+ def test_self_merge_via_normal_join(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ merge_into(
+ target=target,
+ source=target,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(
+ update={'age': lambda r: r['t.age'] + 1},
+ ),
+ ],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['age'], [11, 21, 31])
+
+ def test_matched_update_can_change_on_column(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['x'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['y'],
+ 'age': pa.array([20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update={'id': 999, 'name': 'y'})],
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [999])
+ self.assertEqual(out['name'], ['y'])
+ self.assertEqual(out['age'], [10])
+
+ def test_empty_target_matched_update_is_noop(self):
+ target = self._create_table()
+ before = self._snapshot_id(target)
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*')],
+ )
+
+ self.assertEqual(self._snapshot_id(target), before)
+
+
+class RayMergeIntoGlobalIndexGateTest(unittest.TestCase):
+
+ pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('name', pa.string()),
+ ('age', pa.int32()),
+ ])
+
+ de_options = {
+ 'row-tracking.enabled': 'true',
+ 'data-evolution.enabled': 'true',
+ }
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+ cls.catalog.create_database('default', True)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def _table(self):
+ name = f'default.gidx_{uuid.uuid4().hex[:8]}'
+ s = Schema.from_pyarrow_schema(self.pa_schema, options=self.de_options)
+ self.catalog.create_table(name, s, False)
+ return self.catalog.get_table(name)
+
+ def _entry(self, table, column, partition_values=()):
+ from pypaimon.globalindex.global_index_meta import GlobalIndexMeta
+ from pypaimon.index.index_file_meta import IndexFileMeta
+ from pypaimon.manifest.index_manifest_entry import IndexManifestEntry
+ from pypaimon.table.row.generic_row import GenericRow
+
+ field_id = next(f.id for f in table.fields if f.name == column)
+ index_file = IndexFileMeta(
+ index_type='BTREE', file_name=f'idx-{column}', file_size=1, row_count=1,
+ global_index_meta=GlobalIndexMeta(
+ row_range_start=0, row_range_end=1, index_field_id=field_id,
+ ),
+ )
+ return IndexManifestEntry(
+ kind=0, partition=GenericRow(list(partition_values), []),
+ bucket=0, index_file=index_file,
+ )
+
+ def _update_msg(self, partition=()):
+ from pypaimon.write.commit_message import CommitMessage
+ return CommitMessage(partition=partition, bucket=0, new_files=[])
+
+ def test_update_throw_error_raises(self):
+ from unittest.mock import patch
+ from pypaimon.ray import data_evolution_merge_into as m
+
+ table = self._table()
+ with patch.object(m, '_scan_global_index_entries',
+ return_value=[self._entry(table, 'age')]):
+ with self.assertRaises(NotImplementedError):
+ m._apply_global_index_update_action(
+ table, object(), ['age'], [self._update_msg()],
+ m.GLOBAL_INDEX_ACTION_THROW_ERROR,
+ )
+
+ def test_update_drop_returns_delete_msgs(self):
+ from unittest.mock import patch
+ from pypaimon.ray import data_evolution_merge_into as m
+
+ table = self._table()
+ with patch.object(m, '_scan_global_index_entries',
+ return_value=[self._entry(table, 'age')]):
+ msgs = m._apply_global_index_update_action(
+ table, object(), ['age'], [self._update_msg()],
+ m.GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX,
+ )
+ self.assertEqual(1, len(msgs))
+ self.assertFalse(msgs[0].is_empty())
+ self.assertEqual('idx-age', msgs[0].index_files[0].index_file.file_name)
+ self.assertEqual(1, msgs[0].index_files[0].kind)
+
+ def test_update_unaffected_column_is_noop(self):
+ from unittest.mock import patch
+ from pypaimon.ray import data_evolution_merge_into as m
+
+ table = self._table()
+ with patch.object(m, '_scan_global_index_entries',
+ return_value=[self._entry(table, 'age')]):
+ msgs = m._apply_global_index_update_action(
+ table, object(), ['name'], [self._update_msg()],
+ m.GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX,
+ )
+ self.assertEqual([], msgs)
+
+ def test_update_untouched_partition_is_noop(self):
+ from unittest.mock import patch
+ from pypaimon.ray import data_evolution_merge_into as m
+
+ table = self._table()
+ entry = self._entry(table, 'age', partition_values=('EU',))
+ with patch.object(m, '_scan_global_index_entries', return_value=[entry]):
+ msgs = m._apply_global_index_update_action(
+ table, object(), ['age'], [self._update_msg(partition=('US',))],
+ m.GLOBAL_INDEX_ACTION_DROP_PARTITION_INDEX,
+ )
+ self.assertEqual([], msgs)
+
+
+class TargetProjectionTest(unittest.TestCase):
+
+ def _clause(self, spec, condition=None):
+ from pypaimon.ray import data_evolution_merge_into as m
+ from pypaimon.ray.condition_expr import parse
+ return m._NormalizedClause(
+ spec=spec, condition=parse(condition) if condition else None
+ )
+
+ def test_unconditional_set_excludes_target_update_col(self):
+ from pypaimon.ray import data_evolution_merge_into as m
+ cols = m._resolve_target_projection(
+ [self._clause({'feature': 's.feature'})],
+ ['id'], ['feature'], ['id', 'feature', 'image'],
+ )
+ self.assertEqual(['id'], cols)
+
+ def test_condition_columns_are_projected(self):
+ from pypaimon.ray import data_evolution_merge_into as m
+ cols = m._resolve_target_projection(
+ [self._clause({'feature': 's.feature'}, condition="t.age > 0")],
+ ['id'], ['feature'], ['id', 'age', 'feature', 'image'],
+ )
+ self.assertEqual(['id', 'age'], cols)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paimon-python/pypaimon/write/commit_message.py b/paimon-python/pypaimon/write/commit_message.py
index 7bce06d8ab13..db6d20ff1fae 100644
--- a/paimon-python/pypaimon/write/commit_message.py
+++ b/paimon-python/pypaimon/write/commit_message.py
@@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.
-from dataclasses import dataclass
-from typing import List, Tuple, Optional
+from dataclasses import dataclass, field
+from typing import List, Tuple, Optional, TYPE_CHECKING
from pypaimon.manifest.schema.data_file_meta import DataFileMeta
+if TYPE_CHECKING:
+ from pypaimon.manifest.index_manifest_entry import IndexManifestEntry
+
@dataclass
class CommitMessage:
@@ -27,6 +30,7 @@ class CommitMessage:
bucket: int
new_files: List[DataFileMeta]
check_from_snapshot: Optional[int] = -1
+ index_files: List['IndexManifestEntry'] = field(default_factory=list)
def is_empty(self):
- return not self.new_files
+ return not self.new_files and not self.index_files
diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py
index 486e28924014..93f0ec82a592 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -142,6 +142,10 @@ def commit(self, commit_messages: List[CommitMessage], commit_identifier: int):
logger.info("Finished collecting changes, including: %d entries", len(commit_entries))
+ index_deletes = []
+ for msg in commit_messages:
+ index_deletes.extend(msg.index_files)
+
commit_kind = "APPEND"
detect_conflicts = False
allow_rollback = False
@@ -157,7 +161,8 @@ def commit(self, commit_messages: List[CommitMessage], commit_identifier: int):
commit_identifier=commit_identifier,
commit_entries_plan=lambda snapshot: commit_entries,
detect_conflicts=detect_conflicts,
- allow_rollback=allow_rollback)
+ allow_rollback=allow_rollback,
+ index_deletes=index_deletes)
def overwrite(self, overwrite_partition, commit_messages: List[CommitMessage], commit_identifier: int):
"""Commit the given commit messages in overwrite mode."""
@@ -243,7 +248,7 @@ def truncate_table(self, commit_identifier: int) -> None:
)
def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan,
- detect_conflicts=False, allow_rollback=False):
+ detect_conflicts=False, allow_rollback=False, index_deletes=None):
retry_count = 0
retry_result = None
@@ -254,7 +259,7 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan,
# No entries to commit (e.g. drop_partitions with no matching data): skip commit
# to avoid creating manifest/snapshot with empty partition_stats (causes read errors).
- if not commit_entries:
+ if not commit_entries and not index_deletes:
break
result = self._try_commit_once(
@@ -265,6 +270,7 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan,
latest_snapshot=latest_snapshot,
detect_conflicts=detect_conflicts,
allow_rollback=allow_rollback,
+ index_deletes=index_deletes,
)
if result.is_success():
@@ -316,7 +322,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str
commit_entries: List[ManifestEntry], commit_identifier: int,
latest_snapshot: Optional[Snapshot],
detect_conflicts: bool = False,
- allow_rollback: bool = False) -> CommitResult:
+ allow_rollback: bool = False,
+ index_deletes=None) -> CommitResult:
start_millis = int(time.time() * 1000)
if self._is_duplicate_commit(retry_result, latest_snapshot, commit_identifier, commit_kind):
return SuccessResult()
@@ -327,6 +334,7 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str
# process new_manifest
new_manifest_file = f"manifest-{str(uuid.uuid4())}-0"
+ new_index_manifest = None
# process snapshot
new_snapshot_id = latest_snapshot.id + 1 if latest_snapshot else 1
@@ -384,6 +392,13 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str
index_manifest = None
if latest_snapshot and commit_kind == "APPEND":
index_manifest = latest_snapshot.index_manifest
+ if index_deletes:
+ from pypaimon.manifest.index_manifest_file import IndexManifestFile
+ previous_index_manifest = index_manifest
+ index_manifest = IndexManifestFile(self.table).combine(
+ previous_index_manifest, index_deletes)
+ if index_manifest != previous_index_manifest:
+ new_index_manifest = index_manifest
snapshot_data = Snapshot(
version=3,
@@ -403,7 +418,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str
# Generate partition statistics for the commit
statistics = self._generate_partition_statistics(commit_entries)
except Exception as e:
- self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list)
+ self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list,
+ new_index_manifest)
logger.warning(f"Exception occurs when preparing snapshot: {e}", exc_info=True)
raise RuntimeError(f"Failed to prepare snapshot: {e}")
@@ -423,7 +439,8 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str
commit_kind,
commit_time_s,
)
- self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list)
+ self._cleanup_preparation_failure(delta_manifest_list, base_manifest_list,
+ new_index_manifest)
return RetryResult(latest_snapshot, None)
except Exception as e:
# Commit exception, not sure about the situation and should not clean up the files
@@ -604,10 +621,14 @@ def _commit_retry_wait(self, retry_count: int):
def _cleanup_preparation_failure(self,
delta_manifest_list: Optional[str],
- base_manifest_list: Optional[str]):
+ base_manifest_list: Optional[str],
+ index_manifest: Optional[str] = None):
try:
manifest_path = self.manifest_list_manager.manifest_path
+ if index_manifest:
+ self.table.file_io.delete_quietly(f"{manifest_path}/{index_manifest}")
+
if delta_manifest_list:
manifest_files = self.manifest_list_manager.read(delta_manifest_list)
for manifest_meta in manifest_files:
diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py
index fe2fb9a64b79..55b755e2d4d5 100644
--- a/paimon-python/pypaimon/write/table_update.py
+++ b/paimon-python/pypaimon/write/table_update.py
@@ -109,8 +109,6 @@ def with_update_type(self, update_cols: List[str]):
for col in update_cols:
if col not in self.table.field_names:
raise ValueError(f"Column {col} is not in table schema.")
- if len(update_cols) == len(self.table.field_names):
- update_cols = None
self.update_cols = update_cols
return self
diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py b/paimon-python/pypaimon/write/table_update_by_row_id.py
index ac9c68c3623b..ab61ed21505e 100644
--- a/paimon-python/pypaimon/write/table_update_by_row_id.py
+++ b/paimon-python/pypaimon/write/table_update_by_row_id.py
@@ -42,19 +42,30 @@ class TableUpdateByRowId:
FIRST_ROW_ID_COLUMN = '_FIRST_ROW_ID'
- def __init__(self, table, commit_user: str, commit_identifier: int):
+ def __init__(
+ self, table, commit_user: str, commit_identifier: int,
+ precomputed_files_info: Optional[Tuple[
+ int, List[int],
+ Dict[int, Tuple[DataSplit, List[DataFileMeta]]],
+ int,
+ ]] = None,
+ ):
from pypaimon.table.file_store_table import FileStoreTable
self.table: FileStoreTable = table
self.commit_user = commit_user
self.commit_identifier = commit_identifier
- # Snapshot the current state once: a single ``first_row_id -> (split, files)``
- # map is enough to drive every downstream lookup (partition, row-count, read).
- (self.snapshot_id,
- self.first_row_ids,
- self._first_row_id_index,
- self.total_row_count) = self._load_existing_files_info()
+ if precomputed_files_info is not None:
+ (self.snapshot_id,
+ self.first_row_ids,
+ self._first_row_id_index,
+ self.total_row_count) = precomputed_files_info
+ else:
+ (self.snapshot_id,
+ self.first_row_ids,
+ self._first_row_id_index,
+ self.total_row_count) = self._load_existing_files_info()
self.commit_messages: List[CommitMessage] = []