8 changes: 7 additions & 1 deletion python/datafusion/__init__.py
@@ -46,7 +46,11 @@
     SessionContext,
     SQLOptions,
 )
-from .dataframe import DataFrame
+from .dataframe import (
+    DataFrame,
+    ParquetColumnOptions,
+    ParquetWriterOptions,
+)
 from .expr import (
     Expr,
     WindowFrame,
@@ -80,6 +84,8 @@
     "ExecutionPlan",
     "Expr",
     "LogicalPlan",
+    "ParquetColumnOptions",
+    "ParquetWriterOptions",
     "RecordBatch",
     "RecordBatchStream",
     "RuntimeEnvBuilder",
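With the re-exports above, both new option classes resolve from the package root as well as from the dataframe submodule. A minimal sketch of the import paths this hunk enables:

from datafusion import ParquetColumnOptions, ParquetWriterOptions
from datafusion.dataframe import ParquetWriterOptions as SubmoduleOptions

# Both import paths name the same class after this change.
assert ParquetWriterOptions is SubmoduleOptions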
46 changes: 44 additions & 2 deletions python/datafusion/dataframe.py
@@ -53,6 +53,7 @@
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion._internal import expr as expr_internal
 
+from dataclasses import dataclass
 from enum import Enum
 
 
@@ -114,6 +115,19 @@ def get_default_level(self) -> Optional[int]:
         return None
 
 
+@dataclass
+class ParquetWriterOptions:
+    """Options for writing Parquet files."""
+
+    compression: str | Compression = Compression.ZSTD
+    compression_level: int | None = None
+
+
+@dataclass
+class ParquetColumnOptions:
+    """Placeholder for column-specific options."""
+
+
 class DataFrame:
     """Two dimensional table representation of data.
 
@@ -704,7 +718,7 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None:
     def write_parquet(
         self,
         path: str | pathlib.Path,
-        compression: Union[str, Compression] = Compression.ZSTD,
+        compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
         compression_level: int | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
@@ -725,7 +739,13 @@ def write_parquet(
                 recommended range is 1 to 22, with the default being 4. Higher levels
                 provide better compression but slower speed.
         """
-        # Convert string to Compression enum if necessary
+        if isinstance(compression, ParquetWriterOptions):
+            if compression_level is not None:
+                msg = "compression_level should be None when using ParquetWriterOptions"
+                raise ValueError(msg)
+            self.write_parquet_with_options(path, compression)
+            return
+
         if isinstance(compression, str):
             compression = Compression.from_str(compression)
 
@@ -737,6 +757,28 @@
 
         self.df.write_parquet(str(path), compression.value, compression_level)
 
+    def write_parquet_with_options(
+        self, path: str | pathlib.Path, options: ParquetWriterOptions
+    ) -> None:
+        """Execute the :py:class:`DataFrame` and write the results to Parquet.
+
+        Args:
+            path: Destination path.
+            options: Parquet writer options.
+        """
+        compression = options.compression
+        if isinstance(compression, str):
+            compression = Compression.from_str(compression)
+
+        level = options.compression_level
+        if (
+            compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}
+            and level is None
+        ):
+            level = compression.get_default_level()
+
+        self.df.write_parquet(str(path), compression.value, level)
+
     def write_json(self, path: str | pathlib.Path) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.
 
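A minimal usage sketch of the two entry points added above, assuming an in-memory SessionContext; the table contents and output paths are illustrative only:

from datafusion import ParquetWriterOptions, SessionContext

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# Bundled form: write_parquet dispatches to write_parquet_with_options,
# so the legacy compression_level argument must be left as None.
df.write_parquet("out_zstd.parquet", ParquetWriterOptions(compression="zstd", compression_level=6))

# Explicit form: with no level set, GZIP, BROTLI, and ZSTD fall back to
# their codec defaults via Compression.get_default_level().
df.write_parquet_with_options("out_gzip.parquet", ParquetWriterOptions(compression="gzip"))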
17 changes: 17 additions & 0 deletions python/tests/test_dataframe.py
@@ -27,6 +27,7 @@
 import pytest
 from datafusion import (
     DataFrame,
+    ParquetWriterOptions,
     SessionContext,
     WindowFrame,
     column,
@@ -1632,6 +1633,22 @@ def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
     df.write_parquet(str(path), compression=compression)
 
 
+def test_write_parquet_options(df, tmp_path):
+    options = ParquetWriterOptions(compression="gzip", compression_level=6)
+    df.write_parquet(str(tmp_path), options)
+
+    result = pq.read_table(str(tmp_path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+def test_write_parquet_options_error(df, tmp_path):
+    options = ParquetWriterOptions(compression="gzip", compression_level=6)
+    with pytest.raises(ValueError):
+        df.write_parquet(str(tmp_path), options, compression_level=1)
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
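To exercise just the new tests locally, an invocation along these lines should work, assuming the repo's usual pytest setup; the -k expression matches both new test functions:

pytest python/tests/test_dataframe.py -k "write_parquet_options"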