diff --git a/.github/workflows/paimon-python-checks.yml b/.github/workflows/paimon-python-checks.yml index 4fb7fe07e481..ff2929c0fa2a 100755 --- a/.github/workflows/paimon-python-checks.yml +++ b/.github/workflows/paimon-python-checks.yml @@ -46,7 +46,7 @@ jobs: container: "python:${{ matrix.python-version }}-slim" strategy: matrix: - python-version: ['3.6.15', '3.10'] + python-version: [ '3.6.15', '3.10' ] steps: - name: Checkout code @@ -70,6 +70,7 @@ jobs: build-essential \ git \ curl \ + && apt-get clean \ && rm -rf /var/lib/apt/lists/* - name: Verify Java and Maven installation @@ -88,21 +89,24 @@ jobs: - name: Install Python dependencies shell: bash run: | + df -h if [[ "${{ matrix.python-version }}" == "3.6.15" ]]; then python -m pip install --upgrade pip==21.3.1 python --version - python -m pip install -q pyroaring readerwriterlock==1.0.9 'fsspec==2021.10.1' 'cachetools==4.2.4' 'ossfs==2021.8.0' pyarrow==6.0.1 pandas==1.1.5 'polars==0.9.12' 'fastavro==1.4.7' zstandard==0.19.0 dataclasses==0.8.0 flake8 pytest py4j==0.10.9.9 requests parameterized==0.8.1 2>&1 >/dev/null + python -m pip install --no-cache-dir pyroaring readerwriterlock==1.0.9 'fsspec==2021.10.1' 'cachetools==4.2.4' 'ossfs==2021.8.0' pyarrow==6.0.1 pandas==1.1.5 'polars==0.9.12' 'fastavro==1.4.7' zstandard==0.19.0 dataclasses==0.8.0 flake8 pytest py4j==0.10.9.9 requests parameterized==0.8.1 2>&1 >/dev/null else python -m pip install --upgrade pip - python -m pip install -q pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 2>&1 >/dev/null + pip install torch --index-url https://download.pytorch.org/whl/cpu + python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 fi + df -h - name: Run lint-python.sh shell: bash run: | chmod +x paimon-python/dev/lint-python.sh - ./paimon-python/dev/lint-python.sh + ./paimon-python/dev/lint-python.sh -e pytest_torch - requirement_version_compatible_test: + torch_test: runs-on: ubuntu-latest container: "python:3.10-slim" @@ -110,17 +114,6 @@ jobs: - name: Checkout code uses: actions/checkout@v2 - - name: Set up JDK ${{ env.JDK_VERSION }} - uses: actions/setup-java@v4 - with: - java-version: ${{ env.JDK_VERSION }} - distribution: 'temurin' - - - name: Set up Maven - uses: stCarolas/setup-maven@v4.5 - with: - maven-version: 3.8.8 - - name: Install system dependencies shell: bash run: | @@ -128,26 +121,50 @@ jobs: build-essential \ git \ curl \ + && apt-get clean \ && rm -rf /var/lib/apt/lists/* - - name: Verify Java and Maven installation - run: | - java -version - mvn -version - - name: Verify Python version run: python --version - - name: Build Java + - name: Install Python dependencies + shell: bash run: | - echo "Start compiling modules" - mvn -T 2C -B clean install -DskipTests + python -m pip install --upgrade pip + pip install torch --index-url https://download.pytorch.org/whl/cpu + python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 
pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 + - name: Run lint-python.sh + shell: bash + run: | + chmod +x paimon-python/dev/lint-python.sh + ./paimon-python/dev/lint-python.sh -i pytest_torch + + requirement_version_compatible_test: + runs-on: ubuntu-latest + container: "python:3.10-slim" + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Install system dependencies + shell: bash + run: | + apt-get update && apt-get install -y \ + build-essential \ + git \ + curl \ + && rm -rf /var/lib/apt/lists/* + + - name: Verify Python version + run: python --version - name: Install base Python dependencies shell: bash run: | python -m pip install --upgrade pip - python -m pip install -q \ + pip install torch --index-url https://download.pytorch.org/whl/cpu + python -m pip install --no-cache-dir \ pyroaring \ readerwriterlock==1.0.9 \ fsspec==2024.3.1 \ @@ -165,36 +182,37 @@ jobs: requests \ parameterized==0.9.0 \ packaging + - name: Test requirement version compatibility shell: bash run: | cd paimon-python - + # Test Ray version compatibility echo "==========================================" echo "Testing Ray version compatibility" echo "==========================================" for ray_version in 2.44.0 2.48.0 2.53.0; do echo "Testing Ray version: $ray_version" - + # Install specific Ray version - python -m pip install -q ray==$ray_version - + python -m pip install --no-cache-dir -q ray==$ray_version + # Verify Ray version python -c "import ray; print(f'Ray version: {ray.__version__}')" python -c "from packaging.version import parse; import ray; assert parse(ray.__version__) == parse('$ray_version'), f'Expected Ray $ray_version, got {ray.__version__}'" - + # Run tests python -m pytest pypaimon/tests/ray_data_test.py::RayDataTest -v --tb=short || { echo "Tests failed for Ray $ray_version" exit 1 } - + # Uninstall Ray to avoid conflicts python -m pip uninstall -y ray done - + # Add other dependency version tests here in the future # Example: # echo "==========================================" diff --git a/docs/content/program-api/python-api.md b/docs/content/program-api/python-api.md index 406c8c1ef69f..aa7773e049d0 100644 --- a/docs/content/program-api/python-api.md +++ b/docs/content/program-api/python-api.md @@ -72,6 +72,7 @@ catalog_options = { } catalog = CatalogFactory.create(catalog_options) ``` + {{< /tab >}} {{< /tabs >}} @@ -473,6 +474,38 @@ ray_dataset = table_read.to_ray(splits) See [Ray Data API Documentation](https://docs.ray.io/en/latest/data/api/doc/ray.data.read_datasource.html) for more details. +### Read Pytorch Dataset + +This requires `torch` to be installed. + +You can read all the data into a `torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`: + +```python +from torch.utils.data import DataLoader + +table_read = read_builder.new_read() +dataset = table_read.to_torch(splits, streaming=True) +dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=2, # Concurrency to read data + shuffle=False +) + +# Collect all data from dataloader +for batch_idx, batch_data in enumerate(dataloader): + print(batch_data) + +# output: +# {'user_id': tensor([1, 2]), 'behavior': ['a', 'b']} +# {'user_id': tensor([3, 4]), 'behavior': ['c', 'd']} +# {'user_id': tensor([5, 6]), 'behavior': ['e', 'f']} +# {'user_id': tensor([7, 8]), 'behavior': ['g', 'h']} +``` + +When the `streaming` parameter is true, it will iteratively read; +when it is false, it will read the full amount of data into memory. 
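+
+For the non-streaming case (`streaming=False`), the returned object is a map-style
+`torch.utils.data.Dataset` that supports `len()` and index access, with each item being a
+dict keyed by column name. A minimal sketch, reusing `table_read` and `splits` from the
+example above (the printed values are illustrative):
+
+```python
+dataset = table_read.to_torch(splits, streaming=False)
+
+print(len(dataset))  # total number of rows, e.g. 8
+print(dataset[0])    # a single row as a dict, e.g. {'user_id': 1, 'behavior': 'a'}
+```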
+ ### Incremental Read This API allows reading data committed between two snapshot timestamps. The steps are as follows. @@ -671,22 +704,22 @@ Key points about shard read: The following shows the supported features of Python Paimon compared to Java Paimon: **Catalog Level** - - FileSystemCatalog - - RestCatalog + - FileSystemCatalog + - RestCatalog **Table Level** - - Append Tables - - `bucket = -1` (unaware) - - `bucket > 0` (fixed) - - Primary Key Tables - - only support deduplicate - - `bucket = -2` (postpone) - - `bucket > 0` (fixed) - - read with deletion vectors enabled - - Read/Write Operations - - Batch read and write for append tables and primary key tables - - Predicate filtering - - Overwrite semantics - - Incremental reading of Delta data - - Reading and writing blob data - - `with_shard` feature + - Append Tables + - `bucket = -1` (unaware) + - `bucket > 0` (fixed) + - Primary Key Tables + - only support deduplicate + - `bucket = -2` (postpone) + - `bucket > 0` (fixed) + - read with deletion vectors enabled + - Read/Write Operations + - Batch read and write for append tables and primary key tables + - Predicate filtering + - Overwrite semantics + - Incremental reading of Delta data + - Reading and writing blob data + - `with_shard` feature diff --git a/paimon-python/dev/lint-python.sh b/paimon-python/dev/lint-python.sh index d174b120ad4f..44be2871493e 100755 --- a/paimon-python/dev/lint-python.sh +++ b/paimon-python/dev/lint-python.sh @@ -107,7 +107,7 @@ function collect_checks() { function get_all_supported_checks() { _OLD_IFS=$IFS IFS=$'\n' - SUPPORT_CHECKS=("flake8_check" "pytest_check" "mixed_check") # control the calling sequence + SUPPORT_CHECKS=("flake8_check" "pytest_torch_check" "pytest_check" "mixed_check") # control the calling sequence for fun in $(declare -F); do if [[ `regexp_match "$fun" "_check$"` = true ]]; then check_name="${fun:11}" @@ -179,7 +179,7 @@ function pytest_check() { TEST_DIR="pypaimon/tests/py36" echo "Running tests for Python 3.6: $TEST_DIR" else - TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36 --ignore=pypaimon/tests/e2e" + TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36 --ignore=pypaimon/tests/e2e --ignore=pypaimon/tests/torch_read_test.py" echo "Running tests for Python $PYTHON_VERSION (excluding py36): pypaimon/tests --ignore=pypaimon/tests/py36" fi @@ -197,7 +197,32 @@ function pytest_check() { print_function "STAGE" "pytest checks... [SUCCESS]" fi } +function pytest_torch_check() { + print_function "STAGE" "pytest torch checks" + if [ ! -f "$PYTEST_PATH" ]; then + echo "For some unknown reasons, the pytest package is not complete." + fi + # Get Python version + PYTHON_VERSION=$(python -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") + echo "Detected Python version: $PYTHON_VERSION" + TEST_DIR="pypaimon/tests/torch_read_test.py" + echo "Running tests for Python $PYTHON_VERSION: pypaimon/tests/torch_read_test.py" + + # the return value of a pipeline is the status of the last command to exit + # with a non-zero status or zero if no command exited with a non-zero status + set -o pipefail + ($PYTEST_PATH $TEST_DIR) 2>&1 | tee -a $LOG_FILE + + PYCODESTYLE_STATUS=$? + if [ $PYCODESTYLE_STATUS -ne 0 ]; then + print_function "STAGE" "pytest checks... [FAILED]" + # Stop the running script. + exit 1; + else + print_function "STAGE" "pytest checks... 
[SUCCESS]" + fi +} # Mixed tests check - runs Java-Python interoperability tests function mixed_check() { # Get Python version @@ -279,7 +304,7 @@ usage: $0 [options] -l list all checks supported. Examples: ./lint-python.sh => exec all checks. - ./lint-python.sh -e tox,flake8 => exclude checks tox,flake8. + ./lint-python.sh -e flake8 => exclude checks flake8. ./lint-python.sh -i flake8 => include checks flake8. ./lint-python.sh -i mixed => include checks mixed. ./lint-python.sh -l => list all checks supported. diff --git a/paimon-python/dev/requirements.txt b/paimon-python/dev/requirements.txt index 703adec8e8e5..e76827db3e59 100644 --- a/paimon-python/dev/requirements.txt +++ b/paimon-python/dev/requirements.txt @@ -19,27 +19,23 @@ cachetools>=4.2,<6; python_version=="3.6" cachetools>=5,<6; python_version>"3.6" dataclasses>=0.8; python_version < "3.7" -fastavro>=1.4,<2; python_version<"3.9" -fastavro>=1.4,<2; python_version>="3.9" +fastavro>=1.4,<2 fsspec>=2021.10,<2026; python_version<"3.8" fsspec>=2023,<2026; python_version>="3.8" ossfs>=2021.8; python_version<"3.8" ossfs>=2023; python_version>="3.8" -packaging>=21,<26; python_version<"3.8" -packaging>=21,<26; python_version>="3.8" +packaging>=21,<26 pandas>=1.1,<2; python_version < "3.7" pandas>=1.3,<3; python_version >= "3.7" and python_version < "3.9" pandas>=1.5,<3; python_version >= "3.9" polars>=0.9,<1; python_version<"3.8" -polars>=1,<2; python_version=="3.8" -polars>=1,<2; python_version>"3.8" +polars>=1,<2; python_version>="3.8" pyarrow>=6,<7; python_version < "3.8" -pyarrow>=16,<20; python_version >= "3.8" and python_version < "3.13" -pyarrow>=16,<20; python_version >= "3.13" +pyarrow>=16,<20; python_version >= "3.8" +pylance>=0.20,<1; python_version>="3.9" +pylance>=0.10,<1; python_version>="3.8" and python_version<"3.9" pyroaring ray>=2.10,<3 readerwriterlock>=1,<2 -zstandard>=0.19,<1; python_version<"3.9" -zstandard>=0.19,<1; python_version>="3.9" -pylance>=0.20,<1; python_version>="3.9" -pylance>=0.10,<1; python_version>="3.8" and python_version<"3.9" +torch +zstandard>=0.19,<1 \ No newline at end of file diff --git a/paimon-python/pypaimon/read/ray_datasource.py b/paimon-python/pypaimon/read/datasource.py similarity index 67% rename from paimon-python/pypaimon/read/ray_datasource.py rename to paimon-python/pypaimon/read/datasource.py index 905c8bddefdb..835effbf0b17 100644 --- a/paimon-python/pypaimon/read/ray_datasource.py +++ b/paimon-python/pypaimon/read/datasource.py @@ -27,6 +27,7 @@ import pyarrow from packaging.version import parse import ray +import torch from pypaimon.read.split import Split from pypaimon.read.table_read import TableRead @@ -40,8 +41,10 @@ from ray.data.datasource import Datasource +from torch.utils.data import Dataset, IterableDataset -class PaimonDatasource(Datasource): + +class RayDatasource(Datasource): """ Ray Data Datasource implementation for reading Paimon tables. 
@@ -76,7 +79,7 @@ def estimate_inmemory_data_size(self) -> Optional[int]: @staticmethod def _distribute_splits_into_equal_chunks( - splits: Iterable[Split], n_chunks: int + splits: Iterable[Split], n_chunks: int ) -> List[List[Split]]: """ Implement a greedy knapsack algorithm to distribute the splits across tasks, @@ -88,7 +91,7 @@ def _distribute_splits_into_equal_chunks( # From largest to smallest, add the splits to the smallest chunk one at a time for split in sorted( - splits, key=lambda s: s.file_size if hasattr(s, 'file_size') and s.file_size > 0 else 0, reverse=True + splits, key=lambda s: s.file_size if hasattr(s, 'file_size') and s.file_size > 0 else 0, reverse=True ): smallest_chunk = heapq.heappop(chunk_sizes) chunks[smallest_chunk[1]].append(split) @@ -132,11 +135,11 @@ def get_read_tasks(self, parallelism: int, **kwargs) -> List: # Create a partial function to avoid capturing self in closure # This reduces serialization overhead (see https://github.com/ray-project/ray/issues/49107) def _get_read_task( - splits: List[Split], - table=table, - predicate=predicate, - read_type=read_type, - schema=schema, + splits: List[Split], + table=table, + predicate=predicate, + read_type=read_type, + schema=schema, ) -> Iterable[pyarrow.Table]: """Read function that will be executed by Ray workers.""" from pypaimon.read.table_read import TableRead @@ -216,13 +219,128 @@ def _get_read_task( 'read_fn': read_fn, 'metadata': metadata, } - + if parse(ray.__version__) >= parse(RAY_VERSION_SCHEMA_IN_READ_TASK): read_task_kwargs['schema'] = schema - + if parse(ray.__version__) >= parse(RAY_VERSION_PER_TASK_ROW_LIMIT) and per_task_row_limit is not None: read_task_kwargs['per_task_row_limit'] = per_task_row_limit read_tasks.append(ReadTask(**read_task_kwargs)) return read_tasks + + +class TorchDataset(Dataset): + """ + PyTorch Dataset implementation for reading Paimon table data. + + This class enables Paimon table data to be used directly with PyTorch's + training pipeline, allowing for efficient data loading and batching. + """ + + def __init__(self, table_read: TableRead, splits: List[Split]): + """ + Initialize TorchDataset. + + Args: + table_read: TableRead instance for reading data + splits: List of splits to read + """ + arrow_table = table_read.to_arrow(splits) + if arrow_table is None or arrow_table.num_rows == 0: + self._data = [] + else: + self._data = arrow_table.to_pylist() + + def __len__(self) -> int: + """ + Return the total number of rows in the dataset. + + Returns: + Total number of rows across all splits + """ + return len(self._data) + + def __getitem__(self, index: int): + """ + Get a single item from the dataset. + + Args: + index: Index of the item to retrieve + + Returns: + Dictionary containing the row data + """ + if not self._data: + return None + + return self._data[index] + + +class TorchIterDataset(IterableDataset): + """ + PyTorch IterableDataset implementation for reading Paimon table data. + + This class enables streaming data loading from Paimon tables, which is more + memory-efficient for large datasets. Data is read on-the-fly as needed, + rather than loading everything into memory upfront. + """ + + def __init__(self, table_read: TableRead, splits: List[Split]): + """ + Initialize TorchIterDataset. 
+ + Args: + table_read: TableRead instance for reading data + splits: List of splits to read + """ + self.table_read = table_read + self.splits = splits + # Get field names from read_type + self.field_names = [field.name for field in table_read.read_type] + + def __iter__(self): + """ + Iterate over the dataset, converting each OffsetRow to a dictionary. + + Supports multi-worker data loading by partitioning splits across workers. + When num_workers > 0 in DataLoader, each worker will process a subset of splits. + + Yields: + row data of dict type, where keys are column names + """ + worker_info = torch.utils.data.get_worker_info() + + if worker_info is None: + # Single-process data loading, iterate over all splits + splits_to_process = self.splits + else: + # Multi-process data loading, partition splits across workers + worker_id = worker_info.id + num_workers = worker_info.num_workers + + # Calculate start and end indices for this worker + # Distribute splits evenly by slicing + total_splits = len(self.splits) + splits_per_worker = total_splits // num_workers + remainder = total_splits % num_workers + + # Workers with id < remainder get one extra split + if worker_id < remainder: + start_idx = worker_id * (splits_per_worker + 1) + end_idx = start_idx + splits_per_worker + 1 + else: + start_idx = worker_id * splits_per_worker + remainder + end_idx = start_idx + splits_per_worker + + splits_to_process = self.splits[start_idx:end_idx] + + worker_iterator = self.table_read.to_iterator(splits_to_process) + + for offset_row in worker_iterator: + row_dict = {} + for i, field_name in enumerate(self.field_names): + value = offset_row.get_field(i) + row_dict[field_name] = value + yield row_dict diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index 953384cc7dc1..7e8dbda41206 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -165,8 +165,8 @@ def to_ray( if override_num_blocks is not None and override_num_blocks < 1: raise ValueError(f"override_num_blocks must be at least 1, got {override_num_blocks}") - from pypaimon.read.ray_datasource import PaimonDatasource - datasource = PaimonDatasource(self, splits) + from pypaimon.read.datasource import RayDatasource + datasource = RayDatasource(self, splits) return ray.data.read_datasource( datasource, ray_remote_args=ray_remote_args, @@ -175,6 +175,17 @@ def to_ray( **read_args ) + def to_torch(self, splits: List[Split], streaming: bool = False) -> "torch.utils.data.Dataset": + """Wrap Paimon table data to PyTorch Dataset.""" + if streaming: + from pypaimon.read.datasource import TorchIterDataset + dataset = TorchIterDataset(self, splits) + return dataset + else: + from pypaimon.read.datasource import TorchDataset + dataset = TorchDataset(self, splits) + return dataset + def _create_split_read(self, split: Split) -> SplitRead: if self.table.is_primary_key_table and not split.raw_convertible: return MergeFileSplitRead( diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py index f87f73ded7e4..9925e21be54d 100755 --- a/paimon-python/pypaimon/tests/blob_table_test.py +++ b/paimon-python/pypaimon/tests/blob_table_test.py @@ -2644,7 +2644,7 @@ def write_blob_data(thread_id, start_id): # Create and start multiple threads threads = [] - num_threads = 100 + num_threads = 10 for i in range(num_threads): thread = threading.Thread( target=write_blob_data, diff --git 
a/paimon-python/pypaimon/tests/reader_append_only_test.py b/paimon-python/pypaimon/tests/reader_append_only_test.py index d65658ef5c55..adb0ff4f2562 100644 --- a/paimon-python/pypaimon/tests/reader_append_only_test.py +++ b/paimon-python/pypaimon/tests/reader_append_only_test.py @@ -438,44 +438,6 @@ def test_incremental_read_multi_snapshots(self): }, schema=self.pa_schema).sort_by('user_id') self.assertEqual(expected, actual) - def _write_test_table(self, table): - write_builder = table.new_batch_write_builder() - - # first write - table_write = write_builder.new_write() - table_commit = write_builder.new_commit() - data1 = { - 'user_id': [1, 2, 3, 4], - 'item_id': [1001, 1002, 1003, 1004], - 'behavior': ['a', 'b', 'c', None], - 'dt': ['p1', 'p1', 'p2', 'p1'], - } - pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) - table_write.write_arrow(pa_table) - table_commit.commit(table_write.prepare_commit()) - table_write.close() - table_commit.close() - - # second write - table_write = write_builder.new_write() - table_commit = write_builder.new_commit() - data2 = { - 'user_id': [5, 6, 7, 8], - 'item_id': [1005, 1006, 1007, 1008], - 'behavior': ['e', 'f', 'g', 'h'], - 'dt': ['p2', 'p1', 'p2', 'p2'], - } - pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) - table_write.write_arrow(pa_table) - table_commit.commit(table_write.prepare_commit()) - table_write.close() - table_commit.close() - - def _read_test_table(self, read_builder): - table_read = read_builder.new_read() - splits = read_builder.new_scan().plan().splits() - return table_read.to_arrow(splits) - def test_concurrent_writes_with_retry(self): """Test concurrent writes to verify retry mechanism works correctly.""" import threading @@ -529,7 +491,7 @@ def write_data(thread_id, start_user_id): # Create and start multiple threads threads = [] - num_threads = 100 + num_threads = 10 for i in range(num_threads): thread = threading.Thread( target=write_data, @@ -576,3 +538,41 @@ def write_data(thread_id, start_user_id): f"got {latest_snapshot.id}") print(f"✓ Iteration {test_iteration + 1}/{iter_num} completed successfully") + + def _write_test_table(self, table): + write_builder = table.new_batch_write_builder() + + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 4], + 'item_id': [1001, 1002, 1003, 1004], + 'behavior': ['a', 'b', 'c', None], + 'dt': ['p1', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # second write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = { + 'user_id': [5, 6, 7, 8], + 'item_id': [1005, 1006, 1007, 1008], + 'behavior': ['e', 'f', 'g', 'h'], + 'dt': ['p2', 'p1', 'p2', 'p2'], + } + pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + def _read_test_table(self, read_builder): + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + return table_read.to_arrow(splits) diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py b/paimon-python/pypaimon/tests/reader_primary_key_test.py index 731203385d2a..c22346afe739 100644 --- a/paimon-python/pypaimon/tests/reader_primary_key_test.py +++ 
b/paimon-python/pypaimon/tests/reader_primary_key_test.py @@ -479,7 +479,7 @@ def write_data(thread_id, start_user_id): # Create and start multiple threads threads = [] - num_threads = 100 + num_threads = 10 for i in range(num_threads): thread = threading.Thread( target=write_data, diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py new file mode 100644 index 000000000000..b6862c6cb127 --- /dev/null +++ b/paimon-python/pypaimon/tests/torch_read_test.py @@ -0,0 +1,635 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import shutil +import tempfile +import unittest + +import pyarrow as pa +from parameterized import parameterized +from torch.utils.data import DataLoader + +from pypaimon import CatalogFactory, Schema + +from pypaimon.table.file_store_table import FileStoreTable + + +class TorchReadTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({ + 'warehouse': cls.warehouse + }) + cls.catalog.create_database('default', True) + + cls.pa_schema = pa.schema([ + ('user_id', pa.int32()), + ('item_id', pa.int64()), + ('behavior', pa.string()), + ('dt', pa.string()) + ]) + cls.expected = pa.Table.from_pydict({ + 'user_id': [1, 2, 3, 4, 5, 6, 7, 8], + 'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008], + 'behavior': ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'], + 'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p2'], + }, schema=cls.pa_schema) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + @parameterized.expand([True, False]) + def test_torch_read(self, is_streaming: bool = False): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id']) + self.catalog.create_table(f'default.test_torch_read_{str(is_streaming)}', schema, False) + table = self.catalog.get_table(f'default.test_torch_read_{str(is_streaming)}') + self._write_test_table(table) + + read_builder = table.new_read_builder().with_projection(['user_id', 'behavior']) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + dataset = table_read.to_torch(splits, streaming=is_streaming) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=2, + shuffle=False + ) + + # Collect all data from dataloader + all_user_ids = [] + all_behaviors = [] + for batch_idx, batch_data in enumerate(dataloader): + user_ids = batch_data['user_id'].tolist() + behaviors = batch_data['behavior'] + all_user_ids.extend(user_ids) + all_behaviors.extend(behaviors) + + # Sort by user_id for comparison + sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0]) + 
sorted_user_ids = [x[0] for x in sorted_data] + sorted_behaviors = [x[1] for x in sorted_data] + + # Expected data (sorted by user_id) + expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8] + expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] + + # Verify results + self.assertEqual(sorted_user_ids, expected_user_ids, + f"User IDs mismatch. Expected {expected_user_ids}, got {sorted_user_ids}") + self.assertEqual(sorted_behaviors, expected_behaviors, + f"Behaviors mismatch. Expected {expected_behaviors}, got {sorted_behaviors}") + + print(f"✓ Test passed: Successfully read {len(all_user_ids)} rows with correct data") + + def test_blob_torch_read(self): + """Test end-to-end blob functionality using blob descriptors.""" + import random + from pypaimon import Schema + from pypaimon.table.row.blob import BlobDescriptor + + # Create schema with blob column + pa_schema = pa.schema([ + ('id', pa.int32()), + ('picture', pa.large_binary()), + ]) + + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob-as-descriptor': 'true' + } + ) + + # Create table + self.catalog.create_table('default.test_blob_torch_read', schema, False) + table: FileStoreTable = self.catalog.get_table('default.test_blob_torch_read') + + # Create test blob data (1MB) + blob_data = bytearray(1024 * 1024) + random.seed(42) # For reproducible tests + for i in range(len(blob_data)): + blob_data[i] = random.randint(0, 255) + blob_data = bytes(blob_data) + + # Create external blob file + external_blob_path = os.path.join(self.tempdir, 'external_blob') + with open(external_blob_path, 'wb') as f: + f.write(blob_data) + + # Create blob descriptor pointing to external file + blob_descriptor = BlobDescriptor(external_blob_path, 0, len(blob_data)) + + # Create test data with blob descriptor + test_data = pa.Table.from_pydict({ + 'id': [1], + 'picture': [blob_descriptor.serialize()] + }, schema=pa_schema) + + # Write data using table API + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.write_arrow(test_data) + + # Commit the data + commit_messages = writer.prepare_commit() + commit = write_builder.new_commit() + commit.commit(commit_messages) + + # Read data back + read_builder = table.new_read_builder() + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + result = table_read.to_torch(table_scan.plan().splits()) + + dataloader = DataLoader( + result, + batch_size=1, + num_workers=0, + shuffle=False + ) + + # Collect and verify data + all_ids = [] + all_pictures = [] + for batch_idx, batch_data in enumerate(dataloader): + ids = batch_data['id'].tolist() + pictures = batch_data['picture'] + all_ids.extend(ids) + all_pictures.extend(pictures) + + # Verify results + self.assertEqual(len(all_ids), 1, "Should have exactly 1 row") + self.assertEqual(all_ids[0], 1, "ID should be 1") + + # Verify blob descriptor + picture_bytes = all_pictures[0] + self.assertIsInstance(picture_bytes, bytes, "Picture should be bytes") + + # Deserialize and verify blob descriptor + from pypaimon.table.row.blob import BlobDescriptor + read_blob_descriptor = BlobDescriptor.deserialize(picture_bytes) + self.assertEqual(read_blob_descriptor.length, len(blob_data), + f"Blob length mismatch. 
Expected {len(blob_data)}, got {read_blob_descriptor.length}") + self.assertGreaterEqual(read_blob_descriptor.offset, 0, "Offset should be non-negative") + + # Read and verify blob content + from pypaimon.common.uri_reader import UriReaderFactory + from pypaimon.common.options.config import CatalogOptions + from pypaimon.table.row.blob import Blob + + catalog_options = {CatalogOptions.WAREHOUSE.key(): self.warehouse} + uri_reader_factory = UriReaderFactory(catalog_options) + uri_reader = uri_reader_factory.create(read_blob_descriptor.uri) + blob = Blob.from_descriptor(uri_reader, read_blob_descriptor) + + # Verify blob data matches original + read_blob_data = blob.to_data() + self.assertEqual(len(read_blob_data), len(blob_data), + f"Blob data length mismatch. Expected {len(blob_data)}, got {len(read_blob_data)}") + self.assertEqual(read_blob_data, blob_data, "Blob data content should match original") + + print(f"✓ Blob torch read test passed: Successfully read and verified {len(blob_data)} bytes of blob data") + + def test_torch_read_pk_table(self): + """Test torch read with primary key table.""" + # Create PK table with user_id as primary key and behavior as partition key + schema = Schema.from_pyarrow_schema( + self.pa_schema, + primary_keys=['user_id', 'behavior'], + partition_keys=['behavior'], + options={'bucket': 2} + ) + self.catalog.create_table('default.test_pk_table', schema, False) + table = self.catalog.get_table('default.test_pk_table') + self._write_test_table(table) + + read_builder = table.new_read_builder().with_projection(['user_id', 'behavior']) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + dataset = table_read.to_torch(splits, streaming=True) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=3, + shuffle=False + ) + + # Collect all data from dataloader + all_user_ids = [] + all_behaviors = [] + for batch_idx, batch_data in enumerate(dataloader): + user_ids = batch_data['user_id'].tolist() + behaviors = batch_data['behavior'] + all_user_ids.extend(user_ids) + all_behaviors.extend(behaviors) + + # Sort by user_id for comparison + sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0]) + sorted_user_ids = [x[0] for x in sorted_data] + sorted_behaviors = [x[1] for x in sorted_data] + + # Expected data (sorted by user_id) + expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8] + expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] + + # Verify results + self.assertEqual(sorted_user_ids, expected_user_ids, + f"User IDs mismatch. Expected {expected_user_ids}, got {sorted_user_ids}") + self.assertEqual(sorted_behaviors, expected_behaviors, + f"Behaviors mismatch. 
Expected {expected_behaviors}, got {sorted_behaviors}") + + print(f"✓ PK table test passed: Successfully read {len(all_user_ids)} rows with correct data") + + def test_torch_read_large_append_table(self): + """Test torch read with large data volume on append-only table.""" + # Create append-only table + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt']) + self.catalog.create_table('default.test_large_append', schema, False) + table = self.catalog.get_table('default.test_large_append') + + # Write large amount of data + write_builder = table.new_batch_write_builder() + total_rows = 100000 + batch_size = 10000 + num_batches = total_rows // batch_size + + print(f"\n{'=' * 60}") + print(f"Writing {total_rows} rows to append-only table...") + print(f"{'=' * 60}") + + for batch_idx in range(num_batches): + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + start_id = batch_idx * batch_size + 1 + end_id = start_id + batch_size + + data = { + 'user_id': list(range(start_id, end_id)), + 'item_id': [1000 + i for i in range(start_id, end_id)], + 'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_size)], + 'dt': [f'p{i % 4}' for i in range(batch_size)], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + if (batch_idx + 1) % 2 == 0: + print(f" Written {(batch_idx + 1) * batch_size} rows...") + + # Read data using torch + print(f"\nReading {total_rows} rows using Torch DataLoader...") + + read_builder = table.new_read_builder().with_projection(['user_id', 'behavior']) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + + print(f"Total splits: {len(splits)}") + + dataset = table_read.to_torch(splits, streaming=True) + dataloader = DataLoader( + dataset, + batch_size=1000, + num_workers=4, + shuffle=False + ) + + # Collect all data + all_user_ids = [] + batch_count = 0 + for batch_idx, batch_data in enumerate(dataloader): + batch_count += 1 + user_ids = batch_data['user_id'].tolist() + all_user_ids.extend(user_ids) + + if (batch_idx + 1) % 20 == 0: + print(f" Read {len(all_user_ids)} rows...") + + all_user_ids.sort() + # Verify data + self.assertEqual(len(all_user_ids), total_rows, + f"Row count mismatch. Expected {total_rows}, got {len(all_user_ids)}") + self.assertEqual(all_user_ids, list(range(1, total_rows + 1)), + f"Row count mismatch. 
Expected {total_rows}, got {len(all_user_ids)}") + print(f"\n{'=' * 60}") + print("✓ Large append table test passed!") + print(f" Total rows: {total_rows}") + print(f" Total batches: {batch_count}") + print(f"{'=' * 60}\n") + + def test_torch_read_large_pk_table(self): + """Test torch read with large data volume on primary key table.""" + + # Create PK table + schema = Schema.from_pyarrow_schema( + self.pa_schema, + primary_keys=['user_id'], + partition_keys=['dt'], + options={'bucket': '4'} + ) + self.catalog.create_table('default.test_large_pk', schema, False) + table = self.catalog.get_table('default.test_large_pk') + + # Write large amount of data + write_builder = table.new_batch_write_builder() + total_rows = 100000 + batch_size = 10000 + num_batches = total_rows // batch_size + + print(f"\n{'=' * 60}") + print(f"Writing {total_rows} rows to PK table...") + print(f"{'=' * 60}") + + for batch_idx in range(num_batches): + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + start_id = batch_idx * batch_size + 1 + end_id = start_id + batch_size + + data = { + 'user_id': list(range(start_id, end_id)), + 'item_id': [1000 + i for i in range(start_id, end_id)], + 'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_size)], + 'dt': [f'p{i % 4}' for i in range(batch_size)], + } + pa_table = pa.Table.from_pydict(data, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + if (batch_idx + 1) % 2 == 0: + print(f" Written {(batch_idx + 1) * batch_size} rows...") + + # Read data using torch + print(f"\nReading {total_rows} rows using Torch DataLoader...") + + read_builder = table.new_read_builder() + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + + print(f"Total splits: {len(splits)}") + + dataset = table_read.to_torch(splits, streaming=True) + dataloader = DataLoader( + dataset, + batch_size=1000, + num_workers=8, + shuffle=False + ) + + # Collect all data + all_user_ids = [] + batch_count = 0 + for batch_idx, batch_data in enumerate(dataloader): + batch_count += 1 + user_ids = batch_data['user_id'].tolist() + all_user_ids.extend(user_ids) + + if (batch_idx + 1) % 20 == 0: + print(f" Read {len(all_user_ids)} rows...") + + all_user_ids.sort() + # Verify data + self.assertEqual(len(all_user_ids), total_rows, + f"Row count mismatch. Expected {total_rows}, got {len(all_user_ids)}") + + self.assertEqual(all_user_ids, list(range(1, total_rows + 1)), + f"Row count mismatch. 
Expected {total_rows}, got {len(all_user_ids)}") + + print(f"\n{'=' * 60}") + print("✓ Large PK table test passed!") + print(f" Total rows: {total_rows}") + print(f" Total batches: {batch_count}") + print(" Primary key uniqueness: ✓") + print(f"{'=' * 60}\n") + + def test_torch_read_with_predicate(self): + """Test torch read with predicate filtering.""" + + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id']) + self.catalog.create_table('default.test_predicate', schema, False) + table = self.catalog.get_table('default.test_predicate') + self._write_test_table(table) + + # Test case 1: Filter by user_id > 4 + print(f"\n{'=' * 60}") + print("Test Case 1: user_id > 4") + print(f"{'=' * 60}") + predicate_builder = table.new_read_builder().new_predicate_builder() + + predicate = predicate_builder.greater_than('user_id', 4) + read_builder = table.new_read_builder().with_filter(predicate) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + dataset = table_read.to_torch(splits, streaming=True) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=0, + shuffle=False + ) + + all_user_ids = [] + for batch_idx, batch_data in enumerate(dataloader): + user_ids = batch_data['user_id'].tolist() + all_user_ids.extend(user_ids) + + all_user_ids.sort() + expected_user_ids = [5, 6, 7, 8] + self.assertEqual(all_user_ids, expected_user_ids, + f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}") + print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}") + + # Test case 2: Filter by user_id <= 3 + print(f"\n{'=' * 60}") + print("Test Case 2: user_id <= 3") + print(f"{'=' * 60}") + + predicate = predicate_builder.less_or_equal('user_id', 3) + read_builder = table.new_read_builder().with_filter(predicate) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + dataset = table_read.to_torch(splits, streaming=True) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=0, + shuffle=False + ) + + all_user_ids = [] + for batch_idx, batch_data in enumerate(dataloader): + user_ids = batch_data['user_id'].tolist() + all_user_ids.extend(user_ids) + + all_user_ids.sort() + expected_user_ids = [1, 2, 3] + self.assertEqual(all_user_ids, expected_user_ids, + f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}") + print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}") + + # Test case 3: Filter by behavior = 'a' + print(f"\n{'=' * 60}") + print("Test Case 3: behavior = 'a'") + print(f"{'=' * 60}") + + predicate = predicate_builder.equal('behavior', 'a') + read_builder = table.new_read_builder().with_filter(predicate) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + dataset = table_read.to_torch(splits, streaming=True) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=0, + shuffle=False + ) + + all_user_ids = [] + all_behaviors = [] + for batch_idx, batch_data in enumerate(dataloader): + user_ids = batch_data['user_id'].tolist() + behaviors = batch_data['behavior'] + all_user_ids.extend(user_ids) + all_behaviors.extend(behaviors) + + expected_user_ids = [1] + expected_behaviors = ['a'] + self.assertEqual(all_user_ids, expected_user_ids, + f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}") + self.assertEqual(all_behaviors, expected_behaviors, + f"Behaviors mismatch. 
Expected {expected_behaviors}, got {all_behaviors}") + print(f"✓ Filtered {len(all_user_ids)} rows: user_ids={all_user_ids}, behaviors={all_behaviors}") + + # Test case 4: Filter by user_id IN (2, 4, 6) + print(f"\n{'=' * 60}") + print("Test Case 4: user_id IN (2, 4, 6)") + print(f"{'=' * 60}") + + predicate = predicate_builder.is_in('user_id', [2, 4, 6]) + read_builder = table.new_read_builder().with_filter(predicate) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + dataset = table_read.to_torch(splits, streaming=True) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=0, + shuffle=False + ) + + all_user_ids = [] + for batch_idx, batch_data in enumerate(dataloader): + user_ids = batch_data['user_id'].tolist() + all_user_ids.extend(user_ids) + + all_user_ids.sort() + expected_user_ids = [2, 4, 6] + self.assertEqual(all_user_ids, expected_user_ids, + f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}") + print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}") + + # Test case 5: Combined filter (user_id > 2 AND user_id < 7) + print(f"\n{'=' * 60}") + print("Test Case 5: user_id > 2 AND user_id < 7") + print(f"{'=' * 60}") + + predicate1 = predicate_builder.greater_than('user_id', 2) + predicate2 = predicate_builder.less_than('user_id', 7) + combined_predicate = predicate_builder.and_predicates([predicate1, predicate2]) + read_builder = table.new_read_builder().with_filter(combined_predicate) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + dataset = table_read.to_torch(splits, streaming=True) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=0, + shuffle=False + ) + + all_user_ids = [] + for batch_idx, batch_data in enumerate(dataloader): + user_ids = batch_data['user_id'].tolist() + all_user_ids.extend(user_ids) + + all_user_ids.sort() + expected_user_ids = [3, 4, 5, 6] + self.assertEqual(all_user_ids, expected_user_ids, + f"User IDs mismatch. Expected {expected_user_ids}, got {all_user_ids}") + print(f"✓ Filtered {len(all_user_ids)} rows: {all_user_ids}") + + print(f"\n{'=' * 60}") + print("✓ All predicate test cases passed!") + print(f"{'=' * 60}\n") + + def _write_test_table(self, table): + write_builder = table.new_batch_write_builder() + + # first write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data1 = { + 'user_id': [1, 2, 3, 4], + 'item_id': [1001, 1002, 1003, 1004], + 'behavior': ['a', 'b', 'c', 'd'], + 'dt': ['p1', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + # second write + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + data2 = { + 'user_id': [5, 6, 7, 8], + 'item_id': [1005, 1006, 1007, 1008], + 'behavior': ['e', 'f', 'g', 'h'], + 'dt': ['p2', 'p1', 'p2', 'p1'], + } + pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + def _read_test_table(self, read_builder): + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + return table_read.to_arrow(splits)