@@ -22,7 +22,8 @@
 import warnings
 from typing import TYPE_CHECKING, Any, Protocol

-import pyarrow as pa
+from datafusion.common import DataTypeMap
+from datafusion.types import ensure_pyarrow_type

 try:
     from warnings import deprecated  # Python 3.13+
@@ -45,6 +46,7 @@

     import pandas as pd
     import polars as pl
+    import pyarrow as pa

     from datafusion.plan import ExecutionPlan, LogicalPlan

@@ -550,7 +552,7 @@ def register_listing_table(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
         file_sort_order: list[list[Expr | SortExpr]] | None = None,
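With the widened annotation, each partition column can be declared with the deprecated string shorthand, a `DataTypeMap`, or a PyArrow type. A minimal usage sketch, where the `data/` directory and the `year` column are hypothetical names for illustration only:

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# "events", "data/", and "year" are placeholders for this sketch.
# Passing a concrete pyarrow type avoids the deprecated
# "string"/"int" shorthand entirely.
ctx.register_listing_table(
    "events",
    "data/",
    table_partition_cols=[("year", pa.int32())],
)
```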
@@ -803,7 +805,7 @@ def register_parquet(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -895,7 +897,7 @@ def register_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_compression_type: str | None = None,
     ) -> None:
         """Register a JSON file as a table.
@@ -933,7 +935,7 @@ def register_avro(
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
         file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
     ) -> None:
         """Register an Avro file as a table.

@@ -954,12 +956,16 @@ def register_avro(
             name, str(path), schema, file_extension, table_partition_cols
         )

-    def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
-        """Register a :py:class:`pa.dataset.Dataset` as a table.
+    def register_dataset(self, name: str, dataset: object) -> None:
+        """Register any ``__arrow_c_stream__`` source as a table.
+
+        Any Python object implementing the Arrow ``__arrow_c_stream__`` protocol
+        can be registered, including objects from libraries such as nanoarrow,
+        Polars, DuckDB, or :py:mod:`pyarrow.dataset`.

         Args:
             name: Name of the table to register.
-            dataset: PyArrow dataset.
+            dataset: Object exposing ``__arrow_c_stream__``.
         """
         self.ctx.register_dataset(name, dataset)

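Despite its name, the method no longer requires a :py:mod:`pyarrow.dataset` object. A minimal sketch of the widened behavior, assuming a pyarrow build recent enough that `Table` implements `__arrow_c_stream__` (pyarrow 14+):

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# pa.Table implements __arrow_c_stream__, so under this PR it qualifies
# as a "dataset" here; any other protocol-compliant object should too.
table = pa.table({"a": [1, 2, 3]})
ctx.register_dataset("t", table)
ctx.sql("SELECT SUM(a) FROM t").show()
```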
@@ -1009,7 +1015,7 @@ def read_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -1049,7 +1055,7 @@ def read_csv(
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
         file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -1094,7 +1100,7 @@ def read_csv(
     def read_parquet(
         self,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -1145,7 +1151,7 @@ def read_avro(
         self,
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
-        file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
+        file_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
         file_extension: str = ".avro",
     ) -> DataFrame:
         """Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1181,26 +1187,27 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:

     @staticmethod
     def _convert_table_partition_cols(
-        table_partition_cols: list[tuple[str, str | pa.DataType]],
-    ) -> list[tuple[str, pa.DataType]]:
+        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]],
+    ) -> list[tuple[str, Any]]:
         warn = False
         converted_table_partition_cols = []

         for col, data_type in table_partition_cols:
             if isinstance(data_type, str):
                 warn = True
                 if data_type == "string":
-                    converted_data_type = pa.string()
+                    mapped = DataTypeMap.py_map_from_arrow_type_str("utf8")
                 elif data_type == "int":
-                    converted_data_type = pa.int32()
+                    mapped = DataTypeMap.py_map_from_arrow_type_str("int32")
                 else:
                     message = (
                         f"Unsupported literal data type '{data_type}' for partition "
                         "column. Supported types are 'string' and 'int'"
                     )
                     raise ValueError(message)
+                converted_data_type = ensure_pyarrow_type(mapped)
             else:
-                converted_data_type = data_type
+                converted_data_type = ensure_pyarrow_type(data_type)

             converted_table_partition_cols.append((col, converted_data_type))

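For reference, the converted helper should behave as before for the legacy shorthand, assuming `ensure_pyarrow_type` resolves a `DataTypeMap` to its concrete `pyarrow.DataType`. A small sketch, calling the private helper directly purely for illustration (the column names are placeholders; the string shorthand also still sets the deprecation flag):

```python
import pyarrow as pa
from datafusion import SessionContext

# "region" and "year" are hypothetical partition columns. The shorthand
# now routes through DataTypeMap before being resolved to a pyarrow type;
# under the assumption above, the result is unchanged from the old code.
cols = SessionContext._convert_table_partition_cols(
    [("region", "string"), ("year", "int")]
)
# Expected: [("region", pa.string()), ("year", pa.int32())]
```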