|
22 | 22 | import warnings |
23 | 23 | from typing import TYPE_CHECKING, Any, Protocol |
24 | 24 |
|
25 | | -import pyarrow as pa |
| 25 | +from datafusion.types import ensure_pyarrow_type |
| 26 | +from datafusion.common import DataTypeMap |
26 | 27 |
|
27 | 28 | try: |
28 | 29 | from warnings import deprecated # Python 3.13+ |
|
45 | 46 |
|
46 | 47 | import pandas as pd |
47 | 48 | import polars as pl |
| 49 | + import pyarrow as pa |
48 | 50 |
|
49 | 51 | from datafusion.plan import ExecutionPlan, LogicalPlan |
50 | 52 |
|
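The hunk above moves pyarrow from a hard runtime import to a typing-only one. A minimal sketch of the resulting import layout, using the names from the hunk and assuming the module already defers annotation evaluation with `from __future__ import annotations` (so `pa.Schema | None` annotations are never evaluated at import time):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    from datafusion.common import DataTypeMap          # runtime: Arrow type mapping
    from datafusion.types import ensure_pyarrow_type   # runtime: normalize values to a pyarrow DataType

    if TYPE_CHECKING:
        import pyarrow as pa  # needed only by the type checker, not at import time
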
@@ -550,7 +552,7 @@ def register_listing_table( |
550 | 552 | self, |
551 | 553 | name: str, |
552 | 554 | path: str | pathlib.Path, |
553 | | - table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, |
| 555 | + table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None, |
554 | 556 | file_extension: str = ".parquet", |
555 | 557 | schema: pa.Schema | None = None, |
556 | 558 | file_sort_order: list[list[Expr | SortExpr]] | None = None, |
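A hypothetical call site for the widened `register_listing_table` signature (table name, path, and column names are invented for illustration); plain pyarrow types and the legacy "string"/"int" literals both still flow through `_convert_table_partition_cols`:

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    ctx.register_listing_table(
        "events",                   # hypothetical table name
        "data/events/",             # hypothetical directory of partitioned files
        table_partition_cols=[
            ("year", pa.int32()),   # pyarrow DataType, normalized via ensure_pyarrow_type
            ("country", "string"),  # legacy literal, still accepted on the deprecation path
        ],
    )
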
@@ -803,7 +805,7 @@ def register_parquet( |
803 | 805 | self, |
804 | 806 | name: str, |
805 | 807 | path: str | pathlib.Path, |
806 | | - table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, |
| 808 | + table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None, |
807 | 809 | parquet_pruning: bool = True, |
808 | 810 | file_extension: str = ".parquet", |
809 | 811 | skip_metadata: bool = True, |
@@ -895,7 +897,7 @@ def register_json( |
895 | 897 | schema: pa.Schema | None = None, |
896 | 898 | schema_infer_max_records: int = 1000, |
897 | 899 | file_extension: str = ".json", |
898 | | - table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, |
| 900 | + table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None, |
899 | 901 | file_compression_type: str | None = None, |
900 | 902 | ) -> None: |
901 | 903 | """Register a JSON file as a table. |
@@ -933,7 +935,7 @@ def register_avro( |
933 | 935 | path: str | pathlib.Path, |
934 | 936 | schema: pa.Schema | None = None, |
935 | 937 | file_extension: str = ".avro", |
936 | | - table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, |
| 938 | + table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None, |
937 | 939 | ) -> None: |
938 | 940 | """Register an Avro file as a table. |
939 | 941 |
|
@@ -1009,7 +1011,7 @@ def read_json( |
1009 | 1011 | schema: pa.Schema | None = None, |
1010 | 1012 | schema_infer_max_records: int = 1000, |
1011 | 1013 | file_extension: str = ".json", |
1012 | | - table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, |
| 1014 | + table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None, |
1013 | 1015 | file_compression_type: str | None = None, |
1014 | 1016 | ) -> DataFrame: |
1015 | 1017 | """Read a line-delimited JSON data source. |
@@ -1049,7 +1051,7 @@ def read_csv( |
1049 | 1051 | delimiter: str = ",", |
1050 | 1052 | schema_infer_max_records: int = 1000, |
1051 | 1053 | file_extension: str = ".csv", |
1052 | | - table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, |
| 1054 | + table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None, |
1053 | 1055 | file_compression_type: str | None = None, |
1054 | 1056 | ) -> DataFrame: |
1055 | 1057 | """Read a CSV data source. |
@@ -1094,7 +1096,7 @@ def read_csv( |
1094 | 1096 | def read_parquet( |
1095 | 1097 | self, |
1096 | 1098 | path: str | pathlib.Path, |
1097 | | - table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, |
| 1099 | + table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None, |
1098 | 1100 | parquet_pruning: bool = True, |
1099 | 1101 | file_extension: str = ".parquet", |
1100 | 1102 | skip_metadata: bool = True, |
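The same widened annotation applies to the read-side APIs, so a `DataTypeMap` should also be accepted directly, assuming `ensure_pyarrow_type` understands it as the annotation suggests. A hypothetical sketch, reusing the `ctx` from the earlier example and the `py_map_from_arrow_type_str` constructor this change uses below:

    from datafusion.common import DataTypeMap

    df = ctx.read_parquet(
        "data/sales/",  # hypothetical partitioned dataset
        table_partition_cols=[("region", DataTypeMap.py_map_from_arrow_type_str("utf8"))],
    )
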
@@ -1145,7 +1147,7 @@ def read_avro( |
1145 | 1147 | self, |
1146 | 1148 | path: str | pathlib.Path, |
1147 | 1149 | schema: pa.Schema | None = None, |
1148 | | - file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, |
| 1150 | + file_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None, |
1149 | 1151 | file_extension: str = ".avro", |
1150 | 1152 | ) -> DataFrame: |
1151 | 1153 | """Create a :py:class:`DataFrame` for reading Avro data source. |
@@ -1181,26 +1183,27 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: |
1181 | 1183 |
|
1182 | 1184 | @staticmethod |
1183 | 1185 | def _convert_table_partition_cols( |
1184 | | - table_partition_cols: list[tuple[str, str | pa.DataType]], |
1185 | | - ) -> list[tuple[str, pa.DataType]]: |
| 1186 | + table_partition_cols: list[tuple[str, str | DataTypeMap | Any]], |
| 1187 | + ) -> list[tuple[str, Any]]: |
1186 | 1188 | warn = False |
1187 | 1189 | converted_table_partition_cols = [] |
1188 | 1190 |
|
1189 | 1191 | for col, data_type in table_partition_cols: |
1190 | 1192 | if isinstance(data_type, str): |
1191 | 1193 | warn = True |
1192 | 1194 | if data_type == "string": |
1193 | | - converted_data_type = pa.string() |
| 1195 | + mapped = DataTypeMap.py_map_from_arrow_type_str("utf8") |
1194 | 1196 | elif data_type == "int": |
1195 | | - converted_data_type = pa.int32() |
| 1197 | + mapped = DataTypeMap.py_map_from_arrow_type_str("int32") |
1196 | 1198 | else: |
1197 | 1199 | message = ( |
1198 | 1200 | f"Unsupported literal data type '{data_type}' for partition " |
1199 | 1201 | "column. Supported types are 'string' and 'int'" |
1200 | 1202 | ) |
1201 | 1203 | raise ValueError(message) |
| 1204 | + converted_data_type = ensure_pyarrow_type(mapped) |
1202 | 1205 | else: |
1203 | | - converted_data_type = data_type |
| 1206 | + converted_data_type = ensure_pyarrow_type(data_type) |
1204 | 1207 |
|
1205 | 1208 | converted_table_partition_cols.append((col, converted_data_type)) |
1206 | 1209 |
|
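Assuming the helper behaves as the hunk shows, a quick sanity check of the literal-to-Arrow path might look like this (a sketch, not part of the change):

    cols = [("country", "string"), ("year", "int")]
    converted = SessionContext._convert_table_partition_cols(cols)
    # Expected: [("country", pa.string()), ("year", pa.int32())], with the string
    # literals setting `warn` so the deprecation handling that follows can fire.
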
|