Commit aa23651

refactor: enhance type handling in SessionContext and add pyarrow type helpers
1 parent: f084e56

2 files changed: +92 -14 lines

python/datafusion/context.py

Lines changed: 17 additions & 14 deletions
@@ -22,7 +22,8 @@
 import warnings
 from typing import TYPE_CHECKING, Any, Protocol
 
-import pyarrow as pa
+from datafusion.types import ensure_pyarrow_type
+from datafusion.common import DataTypeMap
 
 try:
     from warnings import deprecated  # Python 3.13+
@@ -45,6 +46,7 @@
 
     import pandas as pd
     import polars as pl
+    import pyarrow as pa
 
     from datafusion.plan import ExecutionPlan, LogicalPlan
 
@@ -550,7 +552,7 @@ def register_listing_table(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
         file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -803,7 +805,7 @@ def register_parquet(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -895,7 +897,7 @@ def register_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_compression_type: str | None = None,
     ) -> None:
         """Register a JSON file as a table.
@@ -933,7 +935,7 @@ def register_avro(
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
         file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
     ) -> None:
         """Register an Avro file as a table.
 
@@ -1009,7 +1011,7 @@ def read_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
        file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -1049,7 +1051,7 @@ def read_csv(
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
         file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -1094,7 +1096,7 @@ def read_csv(
     def read_parquet(
         self,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -1145,7 +1147,7 @@ def read_avro(
         self,
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
-        file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        file_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_extension: str = ".avro",
     ) -> DataFrame:
         """Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1181,26 +1183,27 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
 
     @staticmethod
     def _convert_table_partition_cols(
-        table_partition_cols: list[tuple[str, str | pa.DataType]],
-    ) -> list[tuple[str, pa.DataType]]:
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]],
+    ) -> list[tuple[str, Any]]:
         warn = False
         converted_table_partition_cols = []
 
         for col, data_type in table_partition_cols:
             if isinstance(data_type, str):
                 warn = True
                 if data_type == "string":
-                    converted_data_type = pa.string()
+                    mapped = DataTypeMap.py_map_from_arrow_type_str("utf8")
                 elif data_type == "int":
-                    converted_data_type = pa.int32()
+                    mapped = DataTypeMap.py_map_from_arrow_type_str("int32")
                 else:
                     message = (
                        f"Unsupported literal data type '{data_type}' for partition "
                         "column. Supported types are 'string' and 'int'"
                     )
                     raise ValueError(message)
+                converted_data_type = ensure_pyarrow_type(mapped)
             else:
-                converted_data_type = data_type
+                converted_data_type = ensure_pyarrow_type(data_type)
 
             converted_table_partition_cols.append((col, converted_data_type))
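With this change, table_partition_cols accepts a legacy string literal ("string"/"int"), a DataTypeMap, or a native pyarrow type, and _convert_table_partition_cols normalizes all three through ensure_pyarrow_type. A minimal usage sketch of the widened API (the table name, dataset path, and partition column names are hypothetical):

    import pyarrow as pa

    from datafusion import SessionContext
    from datafusion.common import DataTypeMap

    ctx = SessionContext()
    ctx.register_parquet(
        "events",                    # hypothetical table name
        "data/events/",              # hypothetical partitioned dataset
        table_partition_cols=[
            ("year", pa.int32()),    # pyarrow type: passes through unchanged
            ("region", DataTypeMap.py_map_from_arrow_type_str("utf8")),
            ("day", "int"),          # legacy literal, still converted (flagged for deprecation)
        ],
    )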

python/datafusion/types.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+"""Internal Arrow type helpers with optional PyArrow conversion."""
+
+from __future__ import annotations
+
+from typing import Any
+
+try:  # pragma: no cover - optional dependency
+    import pyarrow as pa
+except Exception:  # pragma: no cover - optional dependency
+    pa = None  # type: ignore
+
+from datafusion.common import DataTypeMap
+
+_PYARROW_TYPE_FACTORIES = {
+    "Null": lambda: pa.null() if pa else None,
+    "Boolean": lambda: pa.bool_() if pa else None,
+    "Int8": lambda: pa.int8() if pa else None,
+    "Int16": lambda: pa.int16() if pa else None,
+    "Int32": lambda: pa.int32() if pa else None,
+    "Int64": lambda: pa.int64() if pa else None,
+    "UInt8": lambda: pa.uint8() if pa else None,
+    "UInt16": lambda: pa.uint16() if pa else None,
+    "UInt32": lambda: pa.uint32() if pa else None,
+    "UInt64": lambda: pa.uint64() if pa else None,
+    "Float16": lambda: pa.float16() if pa else None,
+    "Float32": lambda: pa.float32() if pa else None,
+    "Float64": lambda: pa.float64() if pa else None,
+    "Utf8": lambda: pa.string() if pa else None,
+}
+
+
+def pyarrow_available() -> bool:
+    """Return ``True`` if :mod:`pyarrow` can be imported."""
+
+    return pa is not None
+
+
+def to_pyarrow(data_type: DataTypeMap) -> "pa.DataType":
+    """Convert a :class:`DataTypeMap` to a :mod:`pyarrow` data type.
+
+    Raises ``ModuleNotFoundError`` if :mod:`pyarrow` is not installed.
+    """
+
+    if pa is None:  # pragma: no cover - optional dependency
+        raise ModuleNotFoundError("pyarrow is not installed")
+    name = str(data_type.arrow_type)
+    factory = _PYARROW_TYPE_FACTORIES.get(name)
+    if factory is None:
+        msg = f"Conversion to pyarrow for '{name}' is not implemented"
+        raise NotImplementedError(msg)
+    return factory()
+
+
+def from_pyarrow(pa_type: "pa.DataType") -> DataTypeMap:
+    """Convert a :mod:`pyarrow` data type to :class:`DataTypeMap`.
+
+    Raises ``ModuleNotFoundError`` if :mod:`pyarrow` is not installed.
+    """
+
+    if pa is None:  # pragma: no cover - optional dependency
+        raise ModuleNotFoundError("pyarrow is not installed")
+    return DataTypeMap.py_map_from_arrow_type_str(str(pa_type))
+
+
+def ensure_pyarrow_type(value: DataTypeMap | Any) -> Any:
+    """Ensure ``value`` is a :mod:`pyarrow` data type if available.
+
+    If ``value`` is a :class:`DataTypeMap` and :mod:`pyarrow` is installed,
+    it will be converted to the corresponding :mod:`pyarrow` data type.
+    Otherwise ``value`` is returned unchanged.
+    """
+
+    if isinstance(value, DataTypeMap):
+        return to_pyarrow(value) if pyarrow_available() else value
+    return value
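The helpers degrade gracefully when pyarrow is absent: ensure_pyarrow_type returns a DataTypeMap unchanged instead of raising. A short round-trip sketch, assuming pyarrow is installed, that py_map_from_arrow_type_str accepts lowercase names like "int64" (the diff uses "utf8" and "int32" the same way), and that DataTypeMap.arrow_type stringifies to the names keyed in _PYARROW_TYPE_FACTORIES (e.g. "Int64"):

    import pyarrow as pa

    from datafusion.common import DataTypeMap
    from datafusion.types import ensure_pyarrow_type, from_pyarrow, to_pyarrow

    # pyarrow -> DataTypeMap -> pyarrow round trip
    mapped = from_pyarrow(pa.int64())        # DataTypeMap via str(pa.int64()) == "int64"
    assert to_pyarrow(mapped) == pa.int64()  # back through the factory table

    # Anything that is not a DataTypeMap passes through untouched
    assert ensure_pyarrow_type(pa.string()) == pa.string()
    assert ensure_pyarrow_type(mapped) == pa.int64()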
