1919
2020from __future__ import annotations
2121
22+ import warnings
2223from typing import TYPE_CHECKING , Any , Protocol
2324
25+ import pyarrow as pa
26+
2427try :
2528 from warnings import deprecated # Python 3.13+
2629except ImportError :
@@ -535,7 +538,7 @@ def register_listing_table(
535538 self ,
536539 name : str ,
537540 path : str | pathlib .Path ,
538- table_partition_cols : list [tuple [str , pa .DataType ]] | None = None ,
541+ table_partition_cols : list [tuple [str , str | pa .DataType ]] | None = None ,
539542 file_extension : str = ".parquet" ,
540543 schema : pa .Schema | None = None ,
541544 file_sort_order : list [list [Expr | SortExpr ]] | None = None ,
@@ -556,6 +559,7 @@ def register_listing_table(
556559 """
557560 if table_partition_cols is None :
558561 table_partition_cols = []
562+ table_partition_cols = self ._convert_table_partition_cols (table_partition_cols )
559563 file_sort_order_raw = (
560564 [sort_list_to_raw_sort_list (f ) for f in file_sort_order ]
561565 if file_sort_order is not None
@@ -774,7 +778,7 @@ def register_parquet(
774778 self ,
775779 name : str ,
776780 path : str | pathlib .Path ,
777- table_partition_cols : list [tuple [str , pa .DataType ]] | None = None ,
781+ table_partition_cols : list [tuple [str , str | pa .DataType ]] | None = None ,
778782 parquet_pruning : bool = True ,
779783 file_extension : str = ".parquet" ,
780784 skip_metadata : bool = True ,
@@ -802,6 +806,7 @@ def register_parquet(
802806 """
803807 if table_partition_cols is None :
804808 table_partition_cols = []
809+ table_partition_cols = self ._convert_table_partition_cols (table_partition_cols )
805810 self .ctx .register_parquet (
806811 name ,
807812 str (path ),
@@ -865,7 +870,7 @@ def register_json(
865870 schema : pa .Schema | None = None ,
866871 schema_infer_max_records : int = 1000 ,
867872 file_extension : str = ".json" ,
868- table_partition_cols : list [tuple [str , pa .DataType ]] | None = None ,
873+ table_partition_cols : list [tuple [str , str | pa .DataType ]] | None = None ,
869874 file_compression_type : str | None = None ,
870875 ) -> None :
871876 """Register a JSON file as a table.
@@ -886,6 +891,7 @@ def register_json(
886891 """
887892 if table_partition_cols is None :
888893 table_partition_cols = []
894+ table_partition_cols = self ._convert_table_partition_cols (table_partition_cols )
889895 self .ctx .register_json (
890896 name ,
891897 str (path ),
@@ -902,7 +908,7 @@ def register_avro(
902908 path : str | pathlib .Path ,
903909 schema : pa .Schema | None = None ,
904910 file_extension : str = ".avro" ,
905- table_partition_cols : list [tuple [str , pa .DataType ]] | None = None ,
911+ table_partition_cols : list [tuple [str , str | pa .DataType ]] | None = None ,
906912 ) -> None :
907913 """Register an Avro file as a table.
908914
@@ -918,6 +924,7 @@ def register_avro(
918924 """
919925 if table_partition_cols is None :
920926 table_partition_cols = []
927+ table_partition_cols = self ._convert_table_partition_cols (table_partition_cols )
921928 self .ctx .register_avro (
922929 name , str (path ), schema , file_extension , table_partition_cols
923930 )
@@ -977,7 +984,7 @@ def read_json(
977984 schema : pa .Schema | None = None ,
978985 schema_infer_max_records : int = 1000 ,
979986 file_extension : str = ".json" ,
980- table_partition_cols : list [tuple [str , pa .DataType ]] | None = None ,
987+ table_partition_cols : list [tuple [str , str | pa .DataType ]] | None = None ,
981988 file_compression_type : str | None = None ,
982989 ) -> DataFrame :
983990 """Read a line-delimited JSON data source.
@@ -997,6 +1004,7 @@ def read_json(
9971004 """
9981005 if table_partition_cols is None :
9991006 table_partition_cols = []
1007+ table_partition_cols = self ._convert_table_partition_cols (table_partition_cols )
10001008 return DataFrame (
10011009 self .ctx .read_json (
10021010 str (path ),
@@ -1016,7 +1024,7 @@ def read_csv(
10161024 delimiter : str = "," ,
10171025 schema_infer_max_records : int = 1000 ,
10181026 file_extension : str = ".csv" ,
1019- table_partition_cols : list [tuple [str , pa .DataType ]] | None = None ,
1027+ table_partition_cols : list [tuple [str , str | pa .DataType ]] | None = None ,
10201028 file_compression_type : str | None = None ,
10211029 ) -> DataFrame :
10221030 """Read a CSV data source.
@@ -1041,6 +1049,7 @@ def read_csv(
10411049 """
10421050 if table_partition_cols is None :
10431051 table_partition_cols = []
1052+ table_partition_cols = self ._convert_table_partition_cols (table_partition_cols )
10441053
10451054 path = [str (p ) for p in path ] if isinstance (path , list ) else str (path )
10461055
@@ -1060,7 +1069,7 @@ def read_csv(
10601069 def read_parquet (
10611070 self ,
10621071 path : str | pathlib .Path ,
1063- table_partition_cols : list [tuple [str , pa .DataType ]] | None = None ,
1072+ table_partition_cols : list [tuple [str , str | pa .DataType ]] | None = None ,
10641073 parquet_pruning : bool = True ,
10651074 file_extension : str = ".parquet" ,
10661075 skip_metadata : bool = True ,
@@ -1089,6 +1098,7 @@ def read_parquet(
10891098 """
10901099 if table_partition_cols is None :
10911100 table_partition_cols = []
1101+ table_partition_cols = self ._convert_table_partition_cols (table_partition_cols )
10921102 file_sort_order = (
10931103 [sort_list_to_raw_sort_list (f ) for f in file_sort_order ]
10941104 if file_sort_order is not None
@@ -1142,3 +1152,35 @@ def read_table(self, table: Table) -> DataFrame:
11421152 def execute (self , plan : ExecutionPlan , partitions : int ) -> RecordBatchStream :
11431153 """Execute the ``plan`` and return the results."""
11441154 return RecordBatchStream (self .ctx .execute (plan ._raw_plan , partitions ))
1155+
1156+ @staticmethod
1157+ def _convert_table_partition_cols (
1158+ table_partition_cols : list [tuple [str , str | pa .DataType ]],
1159+ ) -> list [tuple [str , pa .DataType ]]:
1160+ warn = False
1161+ converted_table_partition_cols = []
1162+
1163+ for col , data_type in table_partition_cols :
1164+ if isinstance (data_type , str ):
1165+ warn = True
1166+ if data_type == "string" :
1167+ converted_data_type = pa .string ()
1168+ elif data_type == "int" :
1169+ converted_data_type = pa .int32 ()
1170+ else :
1171+ raise ValueError (
1172+ f"Unsupported literal data type '{ data_type } ' for partition column. Supported types are 'string' and 'int'"
1173+ )
1174+ else :
1175+ converted_data_type = data_type
1176+
1177+ converted_table_partition_cols .append ((col , converted_data_type ))
1178+
1179+ if warn :
1180+ warnings .warn (
1181+ "using literals for table_partition_cols data types is deprecated, use pyarrow types instead" ,
1182+ category = DeprecationWarning ,
1183+ stacklevel = 2 ,
1184+ )
1185+
1186+ return converted_table_partition_cols
0 commit comments