diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 4f7700251..0d245aada 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -46,7 +46,11 @@
     SessionContext,
     SQLOptions,
 )
-from .dataframe import DataFrame
+from .dataframe import (
+    DataFrame,
+    ParquetColumnOptions,
+    ParquetWriterOptions,
+)
 from .expr import (
     Expr,
     WindowFrame,
@@ -80,6 +84,8 @@
     "ExecutionPlan",
     "Expr",
     "LogicalPlan",
+    "ParquetColumnOptions",
+    "ParquetWriterOptions",
     "RecordBatch",
     "RecordBatchStream",
     "RuntimeEnvBuilder",
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index a1df7e080..ae3ecc653 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -53,6 +53,7 @@
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion._internal import expr as expr_internal
+from dataclasses import dataclass
 from enum import Enum
@@ -114,6 +115,19 @@ def get_default_level(self) -> Optional[int]:
         return None


+@dataclass
+class ParquetWriterOptions:
+    """Options for writing Parquet files."""
+
+    compression: str | Compression = Compression.ZSTD
+    compression_level: int | None = None
+
+
+@dataclass
+class ParquetColumnOptions:
+    """Placeholder for column-specific options."""
+
+
 class DataFrame:
     """Two dimensional table representation of data.
@@ -704,7 +718,7 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
     def write_parquet(
         self,
         path: str | pathlib.Path,
-        compression: Union[str, Compression] = Compression.ZSTD,
+        compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
         compression_level: int | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
@@ -725,7 +739,13 @@
                 recommended range is 1 to 22, with the default being 4. Higher levels
                 provide better compression but slower speed.
         """
-        # Convert string to Compression enum if necessary
+        if isinstance(compression, ParquetWriterOptions):
+            if compression_level is not None:
+                msg = "compression_level should be None when using ParquetWriterOptions"
+                raise ValueError(msg)
+            self.write_parquet_with_options(path, compression)
+            return
+
         if isinstance(compression, str):
             compression = Compression.from_str(compression)
@@ -737,6 +757,28 @@
         self.df.write_parquet(str(path), compression.value, compression_level)

+    def write_parquet_with_options(
+        self, path: str | pathlib.Path, options: ParquetWriterOptions
+    ) -> None:
+        """Execute the :py:class:`DataFrame` and write the results to Parquet.
+
+        Args:
+            path: Destination path.
+            options: Parquet writer options.
+        """
+        compression = options.compression
+        if isinstance(compression, str):
+            compression = Compression.from_str(compression)
+
+        level = options.compression_level
+        if (
+            compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}
+            and level is None
+        ):
+            level = compression.get_default_level()
+
+        self.df.write_parquet(str(path), compression.value, level)
+
     def write_json(self, path: str | pathlib.Path) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.
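For context, a minimal usage sketch of the API added above. The table contents and output file names here are illustrative assumptions, not part of the patch; the two calls exercise the two entry points the diff introduces.

# Usage sketch for the ParquetWriterOptions API introduced above.
# The data and file names are illustrative, not taken from the patch.
from datafusion import ParquetWriterOptions, SessionContext

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})

# Passing options in the compression slot delegates to
# write_parquet_with_options; compression_level must stay None here,
# otherwise write_parquet raises ValueError.
df.write_parquet(
    "data_gzip.parquet",
    ParquetWriterOptions(compression="gzip", compression_level=6),
)

# Calling the new method directly; with no explicit level,
# gzip/brotli/zstd fall back to Compression.get_default_level().
df.write_parquet_with_options(
    "data_zstd.parquet", ParquetWriterOptions(compression="zstd")
)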
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index 64220ce9c..ca5c2be4f 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -27,6 +27,7 @@
 import pytest
 from datafusion import (
     DataFrame,
+    ParquetWriterOptions,
     SessionContext,
     WindowFrame,
     column,
@@ -1632,6 +1633,22 @@ def test_write_compressed_parquet_default_compression_level(df, tmp_path, compre
     df.write_parquet(str(path), compression=compression)


+def test_write_parquet_options(df, tmp_path):
+    options = ParquetWriterOptions(compression="gzip", compression_level=6)
+    df.write_parquet(str(tmp_path), options)
+
+    result = pq.read_table(str(tmp_path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+def test_write_parquet_options_error(df, tmp_path):
+    options = ParquetWriterOptions(compression="gzip", compression_level=6)
+    with pytest.raises(ValueError):
+        df.write_parquet(str(tmp_path), options, compression_level=1)
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export