
Commit 41e0fed

support for old logic: accept legacy string literals ("string", "int") for table_partition_cols data types
1 parent a6500a5 commit 41e0fed

File tree: 2 files changed, +54 −10 lines

python/datafusion/context.py

Lines changed: 49 additions & 7 deletions
@@ -19,8 +19,11 @@
 
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Any, Protocol
 
+import pyarrow as pa
+
 try:
     from warnings import deprecated  # Python 3.13+
 except ImportError:
@@ -535,7 +538,7 @@ def register_listing_table(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
         file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -556,6 +559,7 @@ def register_listing_table(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         file_sort_order_raw = (
             [sort_list_to_raw_sort_list(f) for f in file_sort_order]
             if file_sort_order is not None
@@ -774,7 +778,7 @@ def register_parquet(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -802,6 +806,7 @@ def register_parquet(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_parquet(
             name,
             str(path),
@@ -865,7 +870,7 @@ def register_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> None:
         """Register a JSON file as a table.
@@ -886,6 +891,7 @@ def register_json(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_json(
             name,
             str(path),
@@ -902,7 +908,7 @@ def register_avro(
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
         file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     ) -> None:
         """Register an Avro file as a table.
 
@@ -918,6 +924,7 @@ def register_avro(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_avro(
             name, str(path), schema, file_extension, table_partition_cols
         )
@@ -977,7 +984,7 @@ def read_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -997,6 +1004,7 @@ def read_json(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         return DataFrame(
             self.ctx.read_json(
                 str(path),
@@ -1016,7 +1024,7 @@ def read_csv(
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
         file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -1041,6 +1049,7 @@ def read_csv(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
 
         path = [str(p) for p in path] if isinstance(path, list) else str(path)
 
@@ -1060,7 +1069,7 @@ def read_parquet(
     def read_parquet(
         self,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -1089,6 +1098,7 @@ def read_parquet(
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         file_sort_order = (
             [sort_list_to_raw_sort_list(f) for f in file_sort_order]
             if file_sort_order is not None
@@ -1142,3 +1152,35 @@ def read_table(self, table: Table) -> DataFrame:
     def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
         return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+    @staticmethod
+    def _convert_table_partition_cols(
+        table_partition_cols: list[tuple[str, str | pa.DataType]],
+    ) -> list[tuple[str, pa.DataType]]:
+        warn = False
+        converted_table_partition_cols = []
+
+        for col, data_type in table_partition_cols:
+            if isinstance(data_type, str):
+                warn = True
+                if data_type == "string":
+                    converted_data_type = pa.string()
+                elif data_type == "int":
+                    converted_data_type = pa.int32()
+                else:
+                    raise ValueError(
+                        f"Unsupported literal data type '{data_type}' for partition column. Supported types are 'string' and 'int'"
+                    )
+            else:
+                converted_data_type = data_type
+
+            converted_table_partition_cols.append((col, converted_data_type))
+
+        if warn:
+            warnings.warn(
+                "using literals for table_partition_cols data types is deprecated, use pyarrow types instead",
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+
+        return converted_table_partition_cols
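To make the new behavior concrete, here is a short illustrative snippet (not part of the commit) that calls the new _convert_table_partition_cols helper directly — a private static method, invoked here only for demonstration. Legacy "int"/"string" literals are coerced to pa.int32()/pa.string() and trigger a single DeprecationWarning; pyarrow types pass through unchanged; any other literal raises ValueError.

import warnings

import pyarrow as pa
from datafusion import SessionContext

# Legacy literals are coerced and flagged; pyarrow types pass through as-is.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    converted = SessionContext._convert_table_partition_cols(
        [("year", "int"), ("region", "string"), ("day", pa.date32())]
    )

assert converted == [
    ("year", pa.int32()),
    ("region", pa.string()),
    ("day", pa.date32()),
]
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Any literal other than "string" or "int" is rejected outright.
try:
    SessionContext._convert_table_partition_cols([("ts", "timestamp")])
except ValueError as err:
    print(err)  # Unsupported literal data type 'timestamp' for partition column...

In normal use the conversion runs inside the register_*/read_* methods, so with stacklevel=2 the DeprecationWarning is attributed to the frame that called the helper (the wrapping method) rather than to the helper itself.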

python/tests/test_sql.py

Lines changed: 5 additions & 3 deletions
@@ -157,8 +157,8 @@ def test_register_parquet(ctx, tmp_path):
     assert result.to_pydict() == {"cnt": [100]}
 
 
-@pytest.mark.parametrize("path_to_str", [True, False])
-def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
+@pytest.mark.parametrize("path_to_str,legacy_data_type", [(True, False), (False, False), (False, True)])
+def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_type):
     dir_root = tmp_path / "dataset_parquet_partitioned"
     dir_root.mkdir(exist_ok=False)
     (dir_root / "grp=a").mkdir(exist_ok=False)
@@ -177,10 +177,12 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
 
     dir_root = str(dir_root) if path_to_str else dir_root
 
+    partition_data_type = "string" if legacy_data_type else pa.string()
+
     ctx.register_parquet(
         "datapp",
         dir_root,
-        table_partition_cols=[("grp", pa.string())],
+        table_partition_cols=[("grp", partition_data_type)],
         parquet_pruning=True,
         file_extension=".parquet",
     )
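The committed test drives the legacy literal through register_parquet but does not assert that the deprecation actually fires. A possible companion test — a sketch, not part of this commit, assuming the module's existing ctx fixture:

import pyarrow as pa
import pyarrow.parquet as pq
import pytest


def test_register_parquet_legacy_partition_type_warns(ctx, tmp_path):
    # Build a minimal hive-partitioned dataset: one partition, one file.
    dir_root = tmp_path / "dataset_parquet_partitioned"
    (dir_root / "grp=a").mkdir(parents=True)
    pq.write_table(pa.table({"x": [1, 2, 3]}), dir_root / "grp=a" / "part.parquet")

    # The legacy string literal should emit the new DeprecationWarning.
    with pytest.warns(DeprecationWarning, match="table_partition_cols"):
        ctx.register_parquet(
            "datapp_legacy",
            str(dir_root),
            table_partition_cols=[("grp", "string")],
        )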
