@@ -22,8 +22,7 @@
 import warnings
 from typing import TYPE_CHECKING, Any, Protocol

-from datafusion.common import DataTypeMap
-from datafusion.types import ensure_pyarrow_type
+import pyarrow as pa

 try:
     from warnings import deprecated  # Python 3.13+
@@ -46,7 +45,6 @@

     import pandas as pd
     import polars as pl
-    import pyarrow as pa

     from datafusion.plan import ExecutionPlan, LogicalPlan

@@ -552,7 +550,7 @@ def register_listing_table(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
         file_sort_order: list[list[Expr | SortExpr]] | None = None,
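For context, a minimal sketch of the new call shape; the table name, path, and partition layout below are invented for illustration, not taken from the PR:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# Partition columns are now declared with pyarrow DataTypes directly,
# with no DataTypeMap indirection.
ctx.register_listing_table(
    "events",
    "/data/events/",
    table_partition_cols=[("year", pa.int32()), ("region", pa.string())],
    file_extension=".parquet",
)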
@@ -805,7 +803,7 @@ def register_parquet(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
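Judging by the `warn` flag in `_convert_table_partition_cols` further down, the old string shorthands remain accepted but are headed for deprecation. A hedged before/after sketch; the table names and path are hypothetical:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# Legacy shorthands still map to pyarrow types
# ("string" -> pa.string(), "int" -> pa.int32()).
ctx.register_parquet(
    "sales_legacy",
    "/data/sales/",
    table_partition_cols=[("year", "int")],  # shorthand, flagged for warning
)
ctx.register_parquet(
    "sales",
    "/data/sales/",
    table_partition_cols=[("year", pa.int32())],  # preferred typed form
)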
@@ -897,7 +895,7 @@ def register_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> None:
         """Register a JSON file as a table.
@@ -935,7 +933,7 @@ def register_avro(
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
         file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     ) -> None:
         """Register an Avro file as a table.

@@ -956,16 +954,12 @@ def register_avro(
             name, str(path), schema, file_extension, table_partition_cols
         )

-    def register_dataset(self, name: str, dataset: object) -> None:
-        """Register any ``__arrow_c_stream__`` source as a table.
-
-        Any Python object implementing the Arrow ``__arrow_c_stream__`` protocol
-        can be registered, including objects from libraries such as nanoarrow,
-        Polars, DuckDB, or :py:mod:`pyarrow.dataset`.
+    def register_dataset(self, name: str, dataset: pa.dataset.Dataset) -> None:
+        """Register a :py:class:`pa.dataset.Dataset` as a table.

         Args:
             name: Name of the table to register.
-            dataset: Object exposing ``__arrow_c_stream__``.
+            dataset: PyArrow dataset.
         """
         self.ctx.register_dataset(name, dataset)

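With the annotation narrowed to `pa.dataset.Dataset`, registration looks like the sketch below; the in-memory table and query are illustrative only:

import pyarrow as pa
import pyarrow.dataset as ds
from datafusion import SessionContext

ctx = SessionContext()

# ds.dataset() also accepts file paths; an in-memory Table keeps the
# example self-contained.
table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
ctx.register_dataset("my_dataset", ds.dataset(table))

df = ctx.sql("SELECT name FROM my_dataset WHERE id > 1")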
@@ -1015,7 +1009,7 @@ def read_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -1055,7 +1049,7 @@ def read_csv(
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
         file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -1100,7 +1094,7 @@ def read_csv(
     def read_parquet(
         self,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
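The same typed partition columns apply across the `read_*` helpers (`read_json` and `read_csv` take the identical parameter). A sketch assuming a hypothetical Hive-style layout:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# Directories like /data/logs/day=2024-01-01/ surface "day" as a typed
# column on the resulting DataFrame.
df = ctx.read_parquet(
    "/data/logs/",
    table_partition_cols=[("day", pa.string())],
)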
@@ -1151,7 +1145,7 @@ def read_avro(
         self,
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
-        file_partition_cols: list[tuple[str, str | DataTypeMap | Any]] | None = None,
+        file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".avro",
     ) -> DataFrame:
         """Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1187,27 +1181,26 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:

     @staticmethod
     def _convert_table_partition_cols(
-        table_partition_cols: list[tuple[str, str | DataTypeMap | Any]],
-    ) -> list[tuple[str, Any]]:
+        table_partition_cols: list[tuple[str, str | pa.DataType]],
+    ) -> list[tuple[str, pa.DataType]]:
         warn = False
         converted_table_partition_cols = []

         for col, data_type in table_partition_cols:
             if isinstance(data_type, str):
                 warn = True
                 if data_type == "string":
-                    mapped = DataTypeMap.py_map_from_arrow_type_str("utf8")
+                    converted_data_type = pa.string()
                 elif data_type == "int":
-                    mapped = DataTypeMap.py_map_from_arrow_type_str("int32")
+                    converted_data_type = pa.int32()
                 else:
                     message = (
                         f"Unsupported literal data type '{data_type}' for partition "
                         "column. Supported types are 'string' and 'int'"
                     )
                     raise ValueError(message)
-                converted_data_type = ensure_pyarrow_type(mapped)
             else:
-                converted_data_type = ensure_pyarrow_type(data_type)
+                converted_data_type = data_type

             converted_table_partition_cols.append((col, converted_data_type))

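To make the new mapping concrete, a rough sketch of what the converter now produces; note that `_convert_table_partition_cols` is a private static method, and the `warn` flag presumably feeds a deprecation warning emitted after the loop (not shown in this hunk):

import pyarrow as pa
from datafusion import SessionContext

cols = [("region", "string"), ("year", "int"), ("ts", pa.timestamp("us"))]
converted = SessionContext._convert_table_partition_cols(cols)
# Roughly: [("region", pa.string()), ("year", pa.int32()),
#           ("ts", pa.timestamp("us"))]; explicit pa.DataType values
# pass through unchanged, string shorthands are translated.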