Commit 280b65b (parent: e48c8bb)

feat: Add Parquet writer option autodetection

File tree: 3 files changed, +68 -3 lines

python/datafusion/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -46,7 +46,11 @@
     SessionContext,
     SQLOptions,
 )
-from .dataframe import DataFrame
+from .dataframe import (
+    DataFrame,
+    ParquetColumnOptions,
+    ParquetWriterOptions,
+)
 from .expr import (
     Expr,
     WindowFrame,
@@ -80,6 +84,8 @@
     "ExecutionPlan",
     "Expr",
     "LogicalPlan",
+    "ParquetColumnOptions",
+    "ParquetWriterOptions",
     "RecordBatch",
     "RecordBatchStream",
     "RuntimeEnvBuilder",

python/datafusion/dataframe.py

Lines changed: 44 additions & 2 deletions
@@ -53,6 +53,7 @@
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion._internal import expr as expr_internal
 
+from dataclasses import dataclass
 from enum import Enum
 
 
@@ -114,6 +115,19 @@ def get_default_level(self) -> Optional[int]:
         return None
 
 
+@dataclass
+class ParquetWriterOptions:
+    """Options for writing Parquet files."""
+
+    compression: str | Compression = Compression.ZSTD
+    compression_level: int | None = None
+
+
+@dataclass
+class ParquetColumnOptions:
+    """Placeholder for column-specific options."""
+
+
 class DataFrame:
     """Two dimensional table representation of data.
 
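Because ParquetWriterOptions is a plain dataclass, a writer configuration can be built
once and reused across writes. A small sketch; the string form is converted with
Compression.from_str at write time, per the hunks below:

    # Enum form with an explicit level.
    options = ParquetWriterOptions(compression=Compression.BROTLI, compression_level=5)

    # String form, parsed later by write_parquet_with_options.
    options = ParquetWriterOptions(compression="gzip", compression_level=6)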
@@ -704,7 +718,7 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
     def write_parquet(
         self,
         path: str | pathlib.Path,
-        compression: Union[str, Compression] = Compression.ZSTD,
+        compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
         compression_level: int | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
@@ -725,7 +739,13 @@
             recommended range is 1 to 22, with the default being 4. Higher levels
             provide better compression but slower speed.
         """
-        # Convert string to Compression enum if necessary
+        if isinstance(compression, ParquetWriterOptions):
+            if compression_level is not None:
+                msg = "compression_level should be None when using ParquetWriterOptions"
+                raise ValueError(msg)
+            self.write_parquet_with_options(path, compression)
+            return
+
         if isinstance(compression, str):
             compression = Compression.from_str(compression)
 
@@ -737,6 +757,28 @@
 
         self.df.write_parquet(str(path), compression.value, compression_level)
 
+    def write_parquet_with_options(
+        self, path: str | pathlib.Path, options: ParquetWriterOptions
+    ) -> None:
+        """Execute the :py:class:`DataFrame` and write the results to Parquet.
+
+        Args:
+            path: Destination path.
+            options: Parquet writer options.
+        """
+        compression = options.compression
+        if isinstance(compression, str):
+            compression = Compression.from_str(compression)
+
+        level = options.compression_level
+        if (
+            compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}
+            and level is None
+        ):
+            level = compression.get_default_level()
+
+        self.df.write_parquet(str(path), compression.value, level)
+
     def write_json(self, path: str | pathlib.Path) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.
 
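This is the autodetection named in the commit title: for codecs that take a level
(GZIP, BROTLI, ZSTD), a missing compression_level falls back to
Compression.get_default_level(). A sketch of the effective behavior, assuming df and
path as before (the snappy example also assumes Compression.from_str accepts "snappy"):

    # ZSTD with no level resolves to the enum's default (4, per the docstring above).
    df.write_parquet_with_options(path, ParquetWriterOptions())

    # Explicit levels pass through unchanged.
    df.write_parquet_with_options(path, ParquetWriterOptions("gzip", 6))

    # Codecs outside that set keep level=None.
    df.write_parquet_with_options(path, ParquetWriterOptions("snappy"))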

python/tests/test_dataframe.py

Lines changed: 17 additions & 0 deletions
@@ -27,6 +27,7 @@
 import pytest
 from datafusion import (
     DataFrame,
+    ParquetWriterOptions,
     SessionContext,
     WindowFrame,
     column,
@@ -1632,6 +1633,22 @@ def test_write_compressed_parquet_default_compression_level(df, tmp_path, compre
     df.write_parquet(str(path), compression=compression)
 
 
+def test_write_parquet_options(df, tmp_path):
+    options = ParquetWriterOptions(compression="gzip", compression_level=6)
+    df.write_parquet(str(tmp_path), options)
+
+    result = pq.read_table(str(tmp_path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+def test_write_parquet_options_error(df, tmp_path):
+    options = ParquetWriterOptions(compression="gzip", compression_level=6)
+    with pytest.raises(ValueError):
+        df.write_parquet(str(tmp_path), options, compression_level=1)
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
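
The new tests reuse the file's existing df and tmp_path fixtures and its pq
(pyarrow.parquet) import. To run just these two, something like:

    pytest python/tests/test_dataframe.py -k write_parquet_options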
