From f5bf967e3757a813d53230b87c2c28000a533067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=86=E5=AE=87?= Date: Thu, 28 May 2026 15:35:14 +0800 Subject: [PATCH] [python] supports chunk shuffle in file meta layer. --- .../scanner/chunk_shuffle_split_generator.py | 374 +++++++++ .../pypaimon/read/scanner/file_scanner.py | 66 +- paimon-python/pypaimon/read/table_scan.py | 4 + .../chunk_shuffle_split_generator_test.py | 754 ++++++++++++++++++ 4 files changed, 1197 insertions(+), 1 deletion(-) create mode 100644 paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py create mode 100644 paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py diff --git a/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py new file mode 100644 index 000000000000..8ffb9a066971 --- /dev/null +++ b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py @@ -0,0 +1,374 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import random +from abc import abstractmethod +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +from pypaimon.globalindex.indexed_split import IndexedSplit +from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.manifest.schema.manifest_entry import ManifestEntry +from pypaimon.read.scanner.split_generator import AbstractSplitGenerator +from pypaimon.read.sliced_split import SlicedSplit +from pypaimon.read.split import DataSplit, Split +from pypaimon.table.row.generic_row import GenericRow +from pypaimon.utils.range import Range + + +def _null_safe_partition_key(partition_values) -> tuple: + """Wrap each partition value with a None-aware tag so tuples that mix + null and non-null partition values can be ordered without raising + ``TypeError: '<' not supported between instances of 'NoneType' and 'str'``. + Paimon supports null partition values; Python 3 refuses to compare + None against str/int directly. + """ + return tuple((v is None, v) for v in partition_values) + + +@dataclass +class _Chunk: + """A unit of work for one DataLoader read. ``segments`` carries + subclass-specific payload (file segments for append, aligned-group + segments for data evolution). + """ + partition: GenericRow + bucket: int + segments: List[Any] + + +class ChunkShuffleSplitGeneratorBase(AbstractSplitGenerator): + """Common scaffolding for deterministic chunk-shuffled split generation. + + Pipeline (template method, in :meth:`create_splits`): + 1. Stable-sort entries (key from :meth:`_sort_key`) so manifest-read + parallelism cannot bleed into the output. + 2. Group by (partition, bucket); iterate groups in sorted-key order. + 3. Per group, call :meth:`_slice_group_into_chunks` to produce a list + of segment lists — one segment list per chunk. + 4. Wrap each chunk with its (partition, bucket) into ``_Chunk``, + concatenate across groups. + 5. ``random.Random(seed).shuffle`` all chunks. + 6. If sharded, take this worker's slice via balanced ``_compute_shard_range``. + 7. Map each chunk through :meth:`_chunk_to_split`. + + Subclasses implement the three abstract hooks. Reader paths + (``RawFileSplitRead`` for append, ``DataEvolutionSplitRead`` for DE) + are unchanged because chunks ride on existing wrappers + (``SlicedSplit`` / ``IndexedSplit``). + """ + + def __init__( + self, + table, + target_split_size: int, + open_file_cost: int, + deletion_files_map=None, + seed: int = 0, + chunk_size: int = 0, + ): + super().__init__(table, target_split_size, open_file_cost, deletion_files_map) + self.seed = seed + self.chunk_size = chunk_size + + def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]: + if not file_entries: + return [] + + sorted_entries = sorted(file_entries, key=self._sort_key) + + partitioned: "defaultdict[Tuple[tuple, int], List[ManifestEntry]]" = defaultdict(list) + for entry in sorted_entries: + partitioned[(tuple(entry.partition.values), entry.bucket)].append(entry) + + all_chunks: List[_Chunk] = [] + for key in sorted( + partitioned.keys(), + key=lambda k: (_null_safe_partition_key(k[0]), k[1]), + ): + entries_in_group = partitioned[key] + partition_row = entries_in_group[0].partition + bucket = entries_in_group[0].bucket + for segments in self._slice_group_into_chunks(entries_in_group): + all_chunks.append(_Chunk(partition_row, bucket, segments)) + + rng = random.Random(self.seed) + rng.shuffle(all_chunks) + + if self.idx_of_this_subtask is not None: + start, end = self._compute_shard_range(len(all_chunks)) + all_chunks = all_chunks[start:end] + + return [self._chunk_to_split(c) for c in all_chunks] + + @abstractmethod + def _sort_key(self, entry: ManifestEntry): + """Return a comparable, deterministic key for stable sort.""" + + @abstractmethod + def _slice_group_into_chunks(self, entries: List[ManifestEntry]) -> List[List[Any]]: + """Cut one (partition, bucket) group into chunks of segments. + + Each returned inner list represents one chunk; segment shape is + subclass-defined. + """ + + @abstractmethod + def _chunk_to_split(self, chunk: _Chunk) -> Split: + """Wrap a chunk into a Split that the existing readers consume.""" + + +# --------------------------------------------------------------------------- +# Append (non-DE, non-DV) implementation +# --------------------------------------------------------------------------- + + +@dataclass +class _FileSegment: + """A contiguous slice of a data file inside one chunk. + + start/end are half-open row offsets within the file when the chunk + boundary falls inside the file; both are None when the chunk owns + the full file (so SlicedSplit's shard_file_idx_map can skip it and + treat the file as full — see sliced_split.py:73-78). + """ + file: DataFileMeta + start: Optional[int] + end: Optional[int] + + +class AppendChunkShuffleSplitGenerator(ChunkShuffleSplitGeneratorBase): + """Chunk-shuffled splits for plain append tables (non-PK, non-DV, non-DE).""" + + def _sort_key(self, entry: ManifestEntry): + return ( + _null_safe_partition_key(entry.partition.values), + entry.bucket, + entry.file.file_name, + ) + + def _slice_group_into_chunks( + self, entries: List[ManifestEntry] + ) -> List[List[_FileSegment]]: + """Cut a (partition, bucket) group into chunks of at most + ``self.chunk_size`` rows. ``chunk_size`` is a hard upper bound: + the last chunk may be smaller, but no chunk exceeds it. + """ + chunks: List[List[_FileSegment]] = [] + current: List[_FileSegment] = [] + current_rows = 0 + + for entry in entries: + file = entry.file + offset = 0 + remaining = file.row_count + while remaining > 0: + avail = self.chunk_size - current_rows + if avail <= 0: + chunks.append(current) + current = [] + current_rows = 0 + avail = self.chunk_size + + take = min(remaining, avail) + + if take == file.row_count and offset == 0: + current.append(_FileSegment(file, None, None)) + else: + current.append(_FileSegment(file, offset, offset + take)) + + current_rows += take + offset += take + remaining -= take + + if current: + chunks.append(current) + + return chunks + + def _chunk_to_split(self, chunk: _Chunk) -> Split: + files: List[DataFileMeta] = [] + shard_file_idx_map = {} + for seg in chunk.segments: + files.append(seg.file) + if seg.start is not None and seg.end is not None: + shard_file_idx_map[seg.file.file_name] = (seg.start, seg.end) + + for f in files: + f.set_file_path( + self.table.table_path, + chunk.partition, + chunk.bucket, + self.default_part_value, + ) + + data_split = DataSplit( + files=files, + partition=chunk.partition, + bucket=chunk.bucket, + raw_convertible=True, + data_deletion_files=None, + ) + + if shard_file_idx_map: + return SlicedSplit(data_split, shard_file_idx_map) + return data_split + + +# --------------------------------------------------------------------------- +# Data Evolution implementation +# --------------------------------------------------------------------------- + + +@dataclass +class _AlignedGroupSegment: + """A row_id sub-range over one row-id-aligned file group. + + ``files`` is the entire group (may include blob/vector siblings), + so the reader sees every column file even when only a slice of the + group's row_id range lands in this chunk. ``row_range`` is the + inclusive global row_id range this segment owns. + """ + files: List[DataFileMeta] + row_range: Range + + +class DataEvolutionChunkShuffleSplitGenerator(ChunkShuffleSplitGeneratorBase): + """Chunk-shuffled splits for data-evolution append tables. + + The minimum cuttable unit is a row_id-aligned file group: cutting + inside one group would orphan column files relative to the row_id + range, so we keep groups intact and only slice along their row_id + axis. Each chunk maps to an :class:`IndexedSplit` whose ``row_ranges`` + bound the readable slice for that chunk. + """ + + def _sort_key(self, entry: ManifestEntry): + first_row_id = ( + entry.file.first_row_id + if entry.file.first_row_id is not None + else float('-inf') + ) + is_special = 1 if ( + DataFileMeta.is_blob_file(entry.file.file_name) + or DataFileMeta.is_vector_file(entry.file.file_name) + ) else 0 + return ( + _null_safe_partition_key(entry.partition.values), + entry.bucket, + first_row_id, + is_special, + entry.file.file_name, + ) + + def _slice_group_into_chunks( + self, entries: List[ManifestEntry] + ) -> List[List[_AlignedGroupSegment]]: + files = [e.file for e in entries] + # (Range, [files]) pairs sorted by row_id — see helper docstring. + aligned_groups = self._split_by_row_id_with_range(files) + + chunks: List[List[_AlignedGroupSegment]] = [] + current: List[_AlignedGroupSegment] = [] + current_rows = 0 + + for group_range, group_files in aligned_groups: + offset = 0 + group_rows = group_range.count() + while offset < group_rows: + avail = self.chunk_size - current_rows + if avail <= 0: + chunks.append(current) + current = [] + current_rows = 0 + avail = self.chunk_size + + take = min(group_rows - offset, avail) + seg_range = Range( + group_range.from_ + offset, + group_range.from_ + offset + take - 1, + ) + current.append(_AlignedGroupSegment(group_files, seg_range)) + current_rows += take + offset += take + + if current: + chunks.append(current) + + return chunks + + def _chunk_to_split(self, chunk: _Chunk) -> Split: + all_files: List[DataFileMeta] = [] + seen_file_names = set() + row_ranges: List[Range] = [] + + for seg in chunk.segments: + for f in seg.files: + if f.file_name not in seen_file_names: + seen_file_names.add(f.file_name) + all_files.append(f) + row_ranges.append(seg.row_range) + + for f in all_files: + f.set_file_path( + self.table.table_path, + chunk.partition, + chunk.bucket, + self.default_part_value, + ) + + row_ranges.sort(key=lambda r: r.from_) + + data_split = DataSplit( + files=all_files, + partition=chunk.partition, + bucket=chunk.bucket, + raw_convertible=False, + data_deletion_files=None, + ) + return IndexedSplit(data_split, row_ranges, scores=None) + + @staticmethod + def _split_by_row_id_with_range( + files: List[DataFileMeta], + ) -> List[Tuple[Range, List[DataFileMeta]]]: + """Group files by overlapping row_id range, returning (range, files) + pairs sorted by ``range.from_``. + + Mirrors :meth:`DataEvolutionSplitGenerator._split_by_row_id` but + also returns the merged row_id range per group, which the chunk + slicer needs to drive row-count accumulation. Files without + ``first_row_id`` are skipped (DE invariant guarantees presence; + defensive in case stray entries sneak in). + """ + list_ranges = [f.row_id_range() for f in files if f.row_id_range() is not None] + if not list_ranges: + return [] + sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False) + + range_to_files: "dict[Range, List[DataFileMeta]]" = {} + for f in files: + file_range = f.row_id_range() + if file_range is None: + continue + for r in sorted_ranges: + if r.overlaps(file_range): + range_to_files.setdefault(r, []).append(f) + break + + return sorted(range_to_files.items(), key=lambda kv: kv[0].from_) diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py index 1d2831c194e3..8dcfb07ad750 100755 --- a/paimon-python/pypaimon/read/scanner/file_scanner.py +++ b/paimon-python/pypaimon/read/scanner/file_scanner.py @@ -40,6 +40,10 @@ AppendTableSplitGenerator from pypaimon.read.scanner.bucket_select_converter import \ create_bucket_selector +from pypaimon.read.scanner.chunk_shuffle_split_generator import ( + AppendChunkShuffleSplitGenerator, + DataEvolutionChunkShuffleSplitGenerator, +) from pypaimon.read.scanner.data_evolution_split_generator import \ DataEvolutionSplitGenerator from pypaimon.read.scanner.primary_key_table_split_generator import \ @@ -204,6 +208,7 @@ def __init__( self.number_of_para_subtasks = None self.start_pos_of_this_subtask = None self.end_pos_of_this_subtask = None + self.chunk_shuffle: Optional[Tuple[int, int]] = None self.only_read_real_buckets = options.bucket() == BucketMode.POSTPONE_BUCKET.value self.data_evolution = options.data_evolution_enabled() @@ -243,7 +248,34 @@ def _deletion_files_map(self, entries: List[ManifestEntry]) -> Dict[tuple, Dict[ def scan(self) -> Plan: start_ms = time.time() * 1000 # Create appropriate split generator based on table type - if self.table.is_primary_key_table: + if self.chunk_shuffle is not None: + self._validate_chunk_shuffle_compat() + seed, chunk_size = self.chunk_shuffle + # Both append and DE paths use plan_files() directly: the + # predicate is partition-only (enforced by + # _validate_chunk_shuffle_compat), so manifest_entry-level + # partition pruning in plan_files() is the only filter we + # want — no row_id range pushdown, no global index lookup. + entries = self.plan_files() + if self.data_evolution: + split_generator = DataEvolutionChunkShuffleSplitGenerator( + self.table, + self.target_split_size, + self.open_file_cost, + self._deletion_files_map(entries), + seed=seed, + chunk_size=chunk_size, + ) + else: + split_generator = AppendChunkShuffleSplitGenerator( + self.table, + self.target_split_size, + self.open_file_cost, + self._deletion_files_map(entries), + seed=seed, + chunk_size=chunk_size, + ) + elif self.table.is_primary_key_table: entries = self.plan_files() split_generator = PrimaryKeyTableSplitGenerator( self.table, @@ -425,6 +457,38 @@ def with_global_index_result(self, result) -> 'FileScanner': self._global_index_result = result return self + def with_chunk_shuffle(self, seed: int, chunk_size: int) -> 'FileScanner': + if not isinstance(seed, int): + raise ValueError("chunk_shuffle seed must be an int") + if not isinstance(chunk_size, int) or chunk_size <= 0: + raise ValueError("chunk_shuffle chunk_size must be a positive int") + self.chunk_shuffle = (seed, chunk_size) + return self + + def _validate_chunk_shuffle_compat(self) -> None: + if self.table.is_primary_key_table: + raise ValueError("chunk_shuffle only supports append tables") + if self.deletion_vectors_enabled: + raise ValueError("chunk_shuffle not supported with deletion vectors") + if self.start_pos_of_this_subtask is not None: + raise ValueError("chunk_shuffle cannot combine with with_slice") + if self.limit is not None: + raise ValueError("chunk_shuffle cannot combine with limit") + if self._global_index_result is not None: + raise ValueError("chunk_shuffle cannot combine with global index") + # Only partition predicates are allowed: row-level / column-level + # predicates would silently shrink each chunk's effective row count, + # breaking the chunk_size contract DataLoader callers expect. + if self.predicate is not None: + partition_keys = set(self.table.partition_keys or []) + non_partition_fields = _get_all_fields(self.predicate) - partition_keys + if non_partition_fields: + raise ValueError( + "chunk_shuffle predicate must reference only partition keys; " + "got non-partition fields: " + f"{sorted(non_partition_fields)}" + ) + def scan_with_stats(self) -> Tuple[Plan, ScanStats]: """Run one scan pass while recording :class:`ScanStats` counters. diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py index 623261803503..06eafc38d2d9 100755 --- a/paimon-python/pypaimon/read/table_scan.py +++ b/paimon-python/pypaimon/read/table_scan.py @@ -158,3 +158,7 @@ def with_slice(self, start_pos, end_pos) -> 'TableScan': def with_global_index_result(self, result) -> 'TableScan': self.file_scanner.with_global_index_result(result) return self + + def with_chunk_shuffle(self, seed: int, chunk_size: int) -> 'TableScan': + self.file_scanner.with_chunk_shuffle(seed, chunk_size) + return self diff --git a/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py b/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py new file mode 100644 index 000000000000..00552c0b9d15 --- /dev/null +++ b/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py @@ -0,0 +1,754 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for ChunkShuffleSplitGenerator and TableScan.with_chunk_shuffle. + +Algorithmic tests use Mock entries so they don't touch disk; the +end-to-end test writes a real append table and validates that all +workers together cover the data exactly once. +""" + +import os +import shutil +import tempfile +import unittest +from unittest.mock import Mock + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.globalindex.indexed_split import IndexedSplit +from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.read.scanner.chunk_shuffle_split_generator import ( + AppendChunkShuffleSplitGenerator, + DataEvolutionChunkShuffleSplitGenerator, +) +from pypaimon.read.sliced_split import SlicedSplit +from pypaimon.read.split import DataSplit +from pypaimon.utils.range import Range + + +def _mock_table(table_path='/tmp/_chunk_shuffle_test_path'): + table = Mock() + table.table_path = table_path + table.options = Mock() + return table + + +def _mock_entry(partition_values, bucket, file_name, row_count, file_size=1024): + entry = Mock() + entry.partition = Mock() + entry.partition.values = partition_values + entry.bucket = bucket + entry.file = Mock() + entry.file.file_name = file_name + entry.file.file_size = file_size + entry.file.row_count = row_count + # Swallow set_file_path so we don't need to mock partition path encoding. + entry.file.set_file_path = Mock() + return entry + + +def _make_generator(seed, chunk_size, table=None): + if table is None: + table = _mock_table() + return AppendChunkShuffleSplitGenerator( + table, + target_split_size=128 * 1024 * 1024, + open_file_cost=4 * 1024 * 1024, + deletion_files_map=None, + seed=seed, + chunk_size=chunk_size, + ) + + +def _make_de_generator(seed, chunk_size, table=None): + if table is None: + table = _mock_table() + return DataEvolutionChunkShuffleSplitGenerator( + table, + target_split_size=128 * 1024 * 1024, + open_file_cost=4 * 1024 * 1024, + deletion_files_map=None, + seed=seed, + chunk_size=chunk_size, + ) + + +def _mock_de_entry(partition_values, bucket, file_name, first_row_id, row_count, file_size=1024): + """A DE-flavoured mock entry: file carries first_row_id and a real + Range so :meth:`row_id_range` and ``Range.overlaps`` work.""" + entry = Mock() + entry.partition = Mock() + entry.partition.values = partition_values + entry.bucket = bucket + file = Mock(spec=DataFileMeta) + file.file_name = file_name + file.file_size = file_size + file.row_count = row_count + file.first_row_id = first_row_id + file.row_id_range = lambda f=first_row_id, c=row_count: Range(f, f + c - 1) + file.set_file_path = Mock() + entry.file = file + return entry + + +def _split_signature(split): + """A stable, comparable identity for a split — what the worker would actually read.""" + if isinstance(split, SlicedSplit): + underlying = split.data_split() + files = tuple(f.file_name for f in underlying.files) + idx_map = tuple(sorted(split.shard_file_idx_map().items())) + return (tuple(underlying.partition.values), underlying.bucket, files, idx_map) + if isinstance(split, IndexedSplit): + underlying = split.data_split() + files = tuple(sorted(f.file_name for f in underlying.files)) + ranges = tuple((r.from_, r.to) for r in split.row_ranges()) + return (tuple(underlying.partition.values), underlying.bucket, files, ranges) + if isinstance(split, DataSplit): + files = tuple(f.file_name for f in split.files) + return (tuple(split.partition.values), split.bucket, files, ()) + raise AssertionError("unexpected split type: %r" % type(split)) + + +def _split_rows(split): + """Effective row count this split actually exposes.""" + return split.row_count + + +class ChunkShuffleSplitGeneratorAlgoTest(unittest.TestCase): + + def test_no_entries_returns_empty(self): + gen = _make_generator(seed=1, chunk_size=100) + self.assertEqual(gen.create_splits([]), []) + + def test_full_files_no_truncation(self): + entries = [ + _mock_entry([], 0, 'f1', 100), + _mock_entry([], 0, 'f2', 100), + _mock_entry([], 0, 'f3', 100), + ] + gen = _make_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + # 3 chunks, each holding exactly one whole file → all DataSplit, no SlicedSplit + self.assertEqual(len(splits), 3) + for s in splits: + self.assertIsInstance(s, DataSplit) + self.assertEqual(s.row_count, 100) + + def test_chunk_truncates_inside_file(self): + # one file of 250 rows, chunk_size 100 → 3 chunks: 100, 100, 50 + entries = [_mock_entry([], 0, 'f1', 250)] + gen = _make_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + # All three chunks slice the same file → all SlicedSplit + for s in splits: + self.assertIsInstance(s, SlicedSplit) + # union of (start, end) intervals must cover [0, 250) + intervals = sorted(s.shard_file_idx_map()['f1'] for s in splits) + self.assertEqual(intervals, [(0, 100), (100, 200), (200, 250)]) + total = sum(end - start for start, end in intervals) + self.assertEqual(total, 250) + + def test_chunk_spans_multiple_files(self): + # f1=30, f2=30, f3=30, chunk_size=50 → chunks: [f1(30)+f2(0,20)], [f2(20,30)+f3(0,40 cap 30=30)] ... + entries = [ + _mock_entry([], 0, 'f1', 30), + _mock_entry([], 0, 'f2', 30), + _mock_entry([], 0, 'f3', 30), + ] + gen = _make_generator(seed=1, chunk_size=50) + splits = gen.create_splits(entries) + # total 90 rows, chunk_size 50 → 2 chunks (50 + 40) + self.assertEqual(len(splits), 2) + total_rows = sum(_split_rows(s) for s in splits) + self.assertEqual(total_rows, 90) + + def test_chunk_size_larger_than_total(self): + entries = [ + _mock_entry([], 0, 'f1', 30), + _mock_entry([], 0, 'f2', 30), + ] + gen = _make_generator(seed=1, chunk_size=1000) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 1) + # No truncation — full files inside one chunk → DataSplit not SlicedSplit + self.assertIsInstance(splits[0], DataSplit) + self.assertEqual(_split_rows(splits[0]), 60) + + def test_deterministic_same_seed_same_order(self): + entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(20)] + gen1 = _make_generator(seed=42, chunk_size=50) + gen2 = _make_generator(seed=42, chunk_size=50) + splits1 = gen1.create_splits(entries) + splits2 = gen2.create_splits(entries) + self.assertEqual( + [_split_signature(s) for s in splits1], + [_split_signature(s) for s in splits2], + ) + + def test_different_seed_different_order(self): + entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(50)] + gen1 = _make_generator(seed=1, chunk_size=100) + gen2 = _make_generator(seed=2, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)] + sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)] + # Same set of chunks, different order — high probability they differ on 50 items + self.assertEqual(sorted(sigs1), sorted(sigs2)) + self.assertNotEqual(sigs1, sigs2) + + def test_shuffle_actually_reorders(self): + # 20 files in scan order f0..f19. After shuffle the file order should not be sorted. + entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(20)] + gen = _make_generator(seed=42, chunk_size=100) + splits = gen.create_splits(entries) + file_names = [s.files[0].file_name for s in splits] + self.assertNotEqual(file_names, sorted(file_names)) + + def test_shard_round_trip_no_overlap_no_loss(self): + # 13 files × 100 rows = 1300 rows. 4 workers. + entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(13)] + num_workers = 4 + all_sigs = [] + total_rows = 0 + for worker in range(num_workers): + gen = _make_generator(seed=7, chunk_size=100) + gen.with_shard(worker, num_workers) + splits = gen.create_splits(list(entries)) # copy: shuffle is in-place on chunks list + for s in splits: + all_sigs.append(_split_signature(s)) + total_rows += _split_rows(s) + self.assertEqual(total_rows, 13 * 100) + # No duplicate chunks across workers + self.assertEqual(len(all_sigs), len(set(all_sigs))) + # All chunks together equal an unsharded run + unsharded = _make_generator(seed=7, chunk_size=100).create_splits(list(entries)) + self.assertEqual( + sorted(all_sigs), + sorted(_split_signature(s) for s in unsharded), + ) + + def test_shard_balanced_distribution(self): + # 10 chunks across 3 workers → 4, 3, 3 (front-loaded by _compute_shard_range) + entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(10)] + counts = [] + for worker in range(3): + gen = _make_generator(seed=0, chunk_size=100) + gen.with_shard(worker, 3) + counts.append(len(gen.create_splits(list(entries)))) + self.assertEqual(sorted(counts, reverse=True), [4, 3, 3]) + + def test_chunks_fewer_than_workers(self): + # 2 chunks, 5 workers → 3 workers get nothing + entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(2)] + empties = 0 + non_empties = 0 + for worker in range(5): + gen = _make_generator(seed=0, chunk_size=100) + gen.with_shard(worker, 5) + n = len(gen.create_splits(list(entries))) + if n == 0: + empties += 1 + else: + non_empties += 1 + self.assertEqual(n, 1) + self.assertEqual(empties, 3) + self.assertEqual(non_empties, 2) + + def test_multi_partition_no_chunk_crosses_partition(self): + entries = [ + _mock_entry(['p1'], 0, 'f1', 100), + _mock_entry(['p1'], 0, 'f2', 100), + _mock_entry(['p2'], 0, 'f3', 100), + _mock_entry(['p2'], 0, 'f4', 100), + ] + gen = _make_generator(seed=0, chunk_size=100) + splits = gen.create_splits(entries) + # Each split's underlying files come from one partition only + for s in splits: + partitions_in_files = set() + data_split = s.data_split() if isinstance(s, SlicedSplit) else s + partitions_in_files.add(tuple(data_split.partition.values)) + self.assertEqual(len(partitions_in_files), 1) + + def test_null_and_non_null_partitions_sort_safely(self): + # Mixing null and non-null partition values used to raise + # ``TypeError: '<' not supported between instances of 'NoneType' and 'str'`` + # before _null_safe_partition_key. Validate planning succeeds and + # both partitions produce splits. + entries = [ + _mock_entry(['p1'], 0, 'f1', 100), + _mock_entry([None], 0, 'f2', 100), + _mock_entry(['p2'], 0, 'f3', 100), + ] + gen = _make_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + partitions = {tuple(_split_signature(s)[0]) for s in splits} + self.assertEqual(partitions, {('p1',), ('p2',), (None,)}) + + def test_input_order_does_not_affect_output_when_same_files(self): + """Manifest read parallelism shouldn't bleed through — sorting is internal.""" + a = _mock_entry([], 0, 'f1', 100) + b = _mock_entry([], 0, 'f2', 100) + c = _mock_entry([], 0, 'f3', 100) + gen1 = _make_generator(seed=99, chunk_size=100) + gen2 = _make_generator(seed=99, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])] + sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])] + self.assertEqual(sigs1, sigs2) + + +class ChunkShuffleEndToEndTest(unittest.TestCase): + """Real append table → with_chunk_shuffle → multiple workers → union == original.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_append_table(self, name, partition_keys=None): + pa_schema = pa.schema([ + ('id', pa.int64()), + ('value', pa.string()), + ('part', pa.string()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, partition_keys=partition_keys or []) + identifier = f'default.{name}' + self.catalog.create_table(identifier, schema, False) + return self.catalog.get_table(identifier), pa_schema + + def _write_n_batches(self, table, pa_schema, batches): + wb = table.new_batch_write_builder() + for batch in batches: + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict(batch, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + def test_workers_union_equals_full_table(self): + table, pa_schema = self._create_append_table('cs_union') + # 4 commits × 50 rows = 200 rows across several files + batches = [] + for c in range(4): + base = c * 50 + batches.append({ + 'id': list(range(base, base + 50)), + 'value': [f'v{i}' for i in range(base, base + 50)], + 'part': ['p1'] * 50, + }) + self._write_n_batches(table, pa_schema, batches) + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + num_workers = 3 + worker_tables = [] + for w in range(num_workers): + scan = read_builder.new_scan() \ + .with_chunk_shuffle(seed=123, chunk_size=37) \ + .with_shard(w, num_workers) + splits = scan.plan().splits() + if splits: + worker_tables.append(table_read.to_arrow(splits)) + + actual = pa.concat_tables(worker_tables).sort_by('id') if worker_tables else None + self.assertIsNotNone(actual) + self.assertEqual(actual.num_rows, 200) + self.assertEqual(actual.column('id').to_pylist(), list(range(200))) + + def test_deterministic_plan_across_calls(self): + table, pa_schema = self._create_append_table('cs_determinism') + self._write_n_batches(table, pa_schema, [{ + 'id': list(range(100)), + 'value': [f'v{i}' for i in range(100)], + 'part': ['p'] * 100, + }]) + + def plan_files(worker): + scan = table.new_read_builder().new_scan() \ + .with_chunk_shuffle(seed=42, chunk_size=20) \ + .with_shard(worker, 3) + return [_split_signature(s) for s in scan.plan().splits()] + + for worker in range(3): + self.assertEqual(plan_files(worker), plan_files(worker)) + + +class ChunkShuffleCompatibilityTest(unittest.TestCase): + """Validates the reject-on-incompatible-combination matrix.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _append_table(self, name, options=None, partition_keys=None): + if partition_keys: + pa_schema = pa.schema([ + ('id', pa.int64()), + ('value', pa.string()), + ('part', pa.string()), + ]) + else: + pa_schema = pa.schema([('id', pa.int64()), ('value', pa.string())]) + schema = Schema.from_pyarrow_schema( + pa_schema, partition_keys=partition_keys, options=options or {}) + self.catalog.create_table(f'default.{name}', schema, False) + return self.catalog.get_table(f'default.{name}') + + def _pk_table(self, name): + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('value', pa.string()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, primary_keys=['id'], options={'bucket': '1'}) + self.catalog.create_table(f'default.{name}', schema, False) + return self.catalog.get_table(f'default.{name}') + + def test_pk_table_rejected(self): + table = self._pk_table('cs_pk') + scan = table.new_read_builder().new_scan() + scan.with_chunk_shuffle(seed=1, chunk_size=100) + with self.assertRaisesRegex(ValueError, "only supports append tables"): + scan.plan() + + def test_dv_table_rejected(self): + table = self._append_table('cs_dv', options={'deletion-vectors.enabled': 'true'}) + scan = table.new_read_builder().new_scan() + scan.with_chunk_shuffle(seed=1, chunk_size=100) + with self.assertRaisesRegex(ValueError, "deletion vectors"): + scan.plan() + + def test_with_slice_then_chunk_shuffle_rejected(self): + table = self._append_table('cs_slice') + scan = table.new_read_builder().new_scan() + scan.with_slice(0, 100).with_chunk_shuffle(seed=1, chunk_size=100) + with self.assertRaisesRegex(ValueError, "with_slice"): + scan.plan() + + def test_limit_with_chunk_shuffle_rejected(self): + table = self._append_table('cs_limit') + scan = table.new_read_builder().with_limit(50).new_scan() + scan.with_chunk_shuffle(seed=1, chunk_size=100) + with self.assertRaisesRegex(ValueError, "limit"): + scan.plan() + + def test_invalid_chunk_size(self): + table = self._append_table('cs_invalid') + scan = table.new_read_builder().new_scan() + with self.assertRaisesRegex(ValueError, "chunk_size"): + scan.with_chunk_shuffle(seed=1, chunk_size=0) + with self.assertRaisesRegex(ValueError, "chunk_size"): + scan.with_chunk_shuffle(seed=1, chunk_size=-5) + + def test_column_predicate_rejected(self): + # Non-partition predicate would silently shrink effective chunk + # row counts inside the reader → not allowed. + table = self._append_table('cs_col_pred', partition_keys=['part']) + rb = table.new_read_builder() + col_pred = rb.new_predicate_builder().equal('id', 5) + rb = rb.with_filter(col_pred) + scan = rb.new_scan().with_chunk_shuffle(seed=1, chunk_size=10) + with self.assertRaisesRegex(ValueError, "partition keys"): + scan.plan() + + def test_partition_predicate_allowed(self): + # Filter is partition-only → must succeed and read only the + # matching partition. + table, pa_schema = self._partitioned_table_with_data('cs_part_pred') + + rb = table.new_read_builder() + pred = rb.new_predicate_builder().equal('part', 'p1') + scan = rb.with_filter(pred).new_scan() \ + .with_chunk_shuffle(seed=1, chunk_size=10) + plan = scan.plan() + # All splits should be from partition 'p1' + for split in plan.splits(): + partition_values = split.partition.values + self.assertEqual(tuple(partition_values), ('p1',)) + + def _partitioned_table_with_data(self, name): + pa_schema = pa.schema([ + ('id', pa.int64()), + ('value', pa.string()), + ('part', pa.string()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['part']) + identifier = f'default.{name}' + self.catalog.create_table(identifier, schema, False) + table = self.catalog.get_table(identifier) + wb = table.new_batch_write_builder() + for part, ids in [('p1', range(50)), ('p2', range(50, 100))]: + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'id': list(ids), + 'value': [f'v{i}' for i in ids], + 'part': [part] * 50}, + schema=pa_schema, + )) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + return table, pa_schema + + +class DataEvolutionChunkShuffleAlgoTest(unittest.TestCase): + """Mock-based tests for the DE chunk slicer.""" + + def test_no_entries_returns_empty(self): + gen = _make_de_generator(seed=1, chunk_size=100) + self.assertEqual(gen.create_splits([]), []) + + def test_full_aligned_groups_one_per_chunk(self): + # Three commits of 100 rows each → three aligned groups. + # chunk_size = 100 → 3 chunks, each holding one group whole. + entries = [ + _mock_de_entry([], 0, 'g0.parquet', 0, 100), + _mock_de_entry([], 0, 'g1.parquet', 100, 100), + _mock_de_entry([], 0, 'g2.parquet', 200, 100), + ] + gen = _make_de_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + for s in splits: + self.assertIsInstance(s, IndexedSplit) + self.assertEqual(s.row_count, 100) + self.assertEqual(len(s.row_ranges()), 1) + + def test_aligned_group_split_across_chunks(self): + # One 250-row group, chunk_size=100 → 3 chunks (100, 100, 50). + # All three chunks reference the SAME aligned group's files but + # each carries a distinct row_range slice. + entries = [_mock_de_entry([], 0, 'g0.parquet', 1000, 250)] + gen = _make_de_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + + # Union of the three chunks' row_ranges must cover the whole group [1000, 1249]. + ranges = [] + for s in splits: + self.assertIsInstance(s, IndexedSplit) + ranges.extend((r.from_, r.to) for r in s.row_ranges()) + ranges.sort() + self.assertEqual(ranges, [(1000, 1099), (1100, 1199), (1200, 1249)]) + total = sum(r[1] - r[0] + 1 for r in ranges) + self.assertEqual(total, 250) + + def test_chunk_pulls_in_blob_siblings(self): + # One aligned group with a main parquet and a blob sibling sharing the + # row_id range. A single chunk must include BOTH files so the reader + # can union the columns. + entries = [ + _mock_de_entry([], 0, 'g0.parquet', 0, 100), + _mock_de_entry([], 0, 'g0.blob', 0, 100), # .blob ext → is_blob_file + ] + gen = _make_de_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 1) + files = sorted(f.file_name for f in splits[0].files) + self.assertEqual(files, ['g0.blob', 'g0.parquet']) + + def test_blob_propagates_when_group_split(self): + # Same scenario but chunk_size halves the group → the blob sibling + # must appear in BOTH chunk splits. + entries = [ + _mock_de_entry([], 0, 'g0.parquet', 0, 100), + _mock_de_entry([], 0, 'g0.blob', 0, 100), + ] + gen = _make_de_generator(seed=1, chunk_size=50) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 2) + for s in splits: + files = sorted(f.file_name for f in s.files) + self.assertEqual(files, ['g0.blob', 'g0.parquet']) + + def test_deterministic_same_seed(self): + entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(20)] + gen1 = _make_de_generator(seed=42, chunk_size=100) + gen2 = _make_de_generator(seed=42, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)] + sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)] + self.assertEqual(sigs1, sigs2) + + def test_different_seed_reorders(self): + entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(50)] + gen1 = _make_de_generator(seed=1, chunk_size=100) + gen2 = _make_de_generator(seed=2, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)] + sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)] + self.assertEqual(sorted(sigs1), sorted(sigs2)) + self.assertNotEqual(sigs1, sigs2) + + def test_input_order_does_not_affect_output(self): + a = _mock_de_entry([], 0, 'g0.parquet', 0, 100) + b = _mock_de_entry([], 0, 'g1.parquet', 100, 100) + c = _mock_de_entry([], 0, 'g2.parquet', 200, 100) + gen1 = _make_de_generator(seed=99, chunk_size=100) + gen2 = _make_de_generator(seed=99, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])] + sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])] + self.assertEqual(sigs1, sigs2) + + def test_shard_round_trip_no_overlap_no_loss(self): + # 13 aligned groups × 100 rows = 1300 rows. Shard across 4 workers. + entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(13)] + num_workers = 4 + + unsharded = _make_de_generator(seed=7, chunk_size=100).create_splits(list(entries)) + unsharded_sigs = sorted(_split_signature(s) for s in unsharded) + + sharded_sigs = [] + total_rows = 0 + for w in range(num_workers): + gen = _make_de_generator(seed=7, chunk_size=100) + gen.with_shard(w, num_workers) + for s in gen.create_splits(list(entries)): + sharded_sigs.append(_split_signature(s)) + total_rows += s.row_count + self.assertEqual(total_rows, 13 * 100) + # No duplicate splits across workers + self.assertEqual(len(sharded_sigs), len(set(sharded_sigs))) + self.assertEqual(sorted(sharded_sigs), unsharded_sigs) + + def test_multi_partition_no_chunk_crosses_partition(self): + entries = [ + _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100), + _mock_de_entry(['p1'], 0, 'g1.parquet', 100, 100), + _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100), + _mock_de_entry(['p2'], 0, 'g3.parquet', 300, 100), + ] + gen = _make_de_generator(seed=0, chunk_size=100) + splits = gen.create_splits(entries) + for s in splits: + data_split = s.data_split() if isinstance(s, IndexedSplit) else s + self.assertEqual(len({tuple(data_split.partition.values)}), 1) + + def test_null_and_non_null_partitions_sort_safely(self): + # Same null-vs-non-null sort guard, exercised on the DE path. + entries = [ + _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100), + _mock_de_entry([None], 0, 'g1.parquet', 100, 100), + _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100), + ] + gen = _make_de_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + partitions = {_split_signature(s)[0] for s in splits} + self.assertEqual(partitions, {('p1',), ('p2',), (None,)}) + + +class DataEvolutionChunkShuffleEndToEndTest(unittest.TestCase): + """Real DE table → with_chunk_shuffle → multi-worker → union == full table.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_de_table(self, name): + pa_schema = pa.schema([ + ('id', pa.int32()), + ('value', pa.string()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }, + ) + identifier = f'default.{name}' + self.catalog.create_table(identifier, schema, False) + return self.catalog.get_table(identifier), pa_schema + + def _commit_full_rows(self, table, pa_schema, ids): + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'id': ids, 'value': [f'v{i}' for i in ids]}, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + def test_workers_union_equals_full_table(self): + table, pa_schema = self._create_de_table('cs_de_union') + # 4 commits → 4 aligned groups, one file each (full-column writes). + for c in range(4): + base = c * 50 + self._commit_full_rows(table, pa_schema, list(range(base, base + 50))) + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + num_workers = 3 + worker_tables = [] + for w in range(num_workers): + scan = read_builder.new_scan() \ + .with_chunk_shuffle(seed=123, chunk_size=37) \ + .with_shard(w, num_workers) + splits = scan.plan().splits() + if splits: + worker_tables.append(table_read.to_arrow(splits)) + + actual = pa.concat_tables(worker_tables).sort_by('id') + self.assertEqual(actual.num_rows, 200) + self.assertEqual(actual.column('id').to_pylist(), list(range(200))) + + def test_deterministic_plan_across_calls(self): + table, pa_schema = self._create_de_table('cs_de_determinism') + for c in range(3): + base = c * 40 + self._commit_full_rows(table, pa_schema, list(range(base, base + 40))) + + def plan_sigs(worker): + scan = table.new_read_builder().new_scan() \ + .with_chunk_shuffle(seed=42, chunk_size=15) \ + .with_shard(worker, 4) + return [_split_signature(s) for s in scan.plan().splits()] + + for worker in range(4): + self.assertEqual(plan_sigs(worker), plan_sigs(worker)) + + +if __name__ == '__main__': + unittest.main()