Commit 191519b

impl

1 parent d6ef9bc commit 191519b

File tree

3 files changed: +66 -41 lines changed

python/datafusion/context.py

Lines changed: 7 additions & 7 deletions
@@ -535,7 +535,7 @@ def register_listing_table(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
         file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -774,7 +774,7 @@ def register_parquet(
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -865,7 +865,7 @@ def register_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> None:
         """Register a JSON file as a table.
@@ -902,7 +902,7 @@ def register_avro(
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
         file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
     ) -> None:
         """Register an Avro file as a table.
 
@@ -977,7 +977,7 @@ def read_json(
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -1016,7 +1016,7 @@ def read_csv(
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
         file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -1060,7 +1060,7 @@ def read_csv(
     def read_parquet(
         self,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
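
With this change, partition column types on the Python side are expressed as pyarrow types rather than the shorthand strings "string" and "int". A minimal usage sketch against the new signature (the table name, dataset path, and partition column names below are hypothetical):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# Partition columns are declared as (name, pyarrow.DataType) pairs,
# matching the new list[tuple[str, pa.DataType]] annotation.
ctx.register_parquet(
    "events",        # hypothetical table name
    "data/events/",  # hypothetical hive-partitioned directory
    table_partition_cols=[("year", pa.int32()), ("month", pa.int32())],
)
df = ctx.sql("SELECT COUNT(*) FROM events WHERE year = 2024")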

python/datafusion/io.py

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@
 
 def read_parquet(
     path: str | pathlib.Path,
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
     parquet_pruning: bool = True,
     file_extension: str = ".parquet",
     skip_metadata: bool = True,
@@ -83,7 +83,7 @@ def read_json(
     schema: pa.Schema | None = None,
     schema_infer_max_records: int = 1000,
     file_extension: str = ".json",
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
     file_compression_type: str | None = None,
 ) -> DataFrame:
     """Read a line-delimited JSON data source.
@@ -124,7 +124,7 @@ def read_csv(
     delimiter: str = ",",
     schema_infer_max_records: int = 1000,
     file_extension: str = ".csv",
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
     file_compression_type: str | None = None,
 ) -> DataFrame:
     """Read a CSV data source.

src/context.rs

Lines changed: 56 additions & 31 deletions
@@ -353,15 +353,20 @@ impl PySessionContext {
         &mut self,
         name: &str,
         path: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_extension: &str,
         schema: Option<PyArrowType<Schema>>,
         file_sort_order: Option<Vec<Vec<PySortExpr>>>,
         py: Python,
     ) -> PyDataFusionResult<()> {
         let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
             .with_file_extension(file_extension)
-            .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .with_table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>()
+            )
             .with_file_sort_order(
                 file_sort_order
                     .unwrap_or_default()
@@ -629,7 +634,7 @@ impl PySessionContext {
         &mut self,
         name: &str,
         path: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         parquet_pruning: bool,
         file_extension: &str,
         skip_metadata: bool,
@@ -638,7 +643,12 @@ impl PySessionContext {
         py: Python,
     ) -> PyDataFusionResult<()> {
         let mut options = ParquetReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>()
+            )
             .parquet_pruning(parquet_pruning)
             .skip_metadata(skip_metadata);
         options.file_extension = file_extension;
@@ -718,7 +728,7 @@ impl PySessionContext {
         schema: Option<PyArrowType<Schema>>,
         schema_infer_max_records: usize,
         file_extension: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_compression_type: Option<String>,
         py: Python,
     ) -> PyDataFusionResult<()> {
@@ -728,7 +738,12 @@ impl PySessionContext {
 
         let mut options = NdJsonReadOptions::default()
             .file_compression_type(parse_file_compression_type(file_compression_type)?)
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>()
+            );
         options.schema_infer_max_records = schema_infer_max_records;
         options.file_extension = file_extension;
         options.schema = schema.as_ref().map(|x| &x.0);
@@ -751,15 +766,20 @@ impl PySessionContext {
         path: PathBuf,
         schema: Option<PyArrowType<Schema>>,
         file_extension: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         py: Python,
     ) -> PyDataFusionResult<()> {
         let path = path
             .to_str()
             .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
 
         let mut options = AvroReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>()
+            );
         options.file_extension = file_extension;
         options.schema = schema.as_ref().map(|x| &x.0);
 
@@ -850,15 +870,20 @@ impl PySessionContext {
         schema: Option<PyArrowType<Schema>>,
         schema_infer_max_records: usize,
         file_extension: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_compression_type: Option<String>,
         py: Python,
     ) -> PyDataFusionResult<PyDataFrame> {
         let path = path
             .to_str()
             .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
         let mut options = NdJsonReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>()
+            )
             .file_compression_type(parse_file_compression_type(file_compression_type)?);
         options.schema_infer_max_records = schema_infer_max_records;
         options.file_extension = file_extension;
@@ -891,7 +916,7 @@ impl PySessionContext {
         delimiter: &str,
         schema_infer_max_records: usize,
         file_extension: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_compression_type: Option<String>,
         py: Python,
     ) -> PyDataFusionResult<PyDataFrame> {
@@ -907,7 +932,12 @@ impl PySessionContext {
             .delimiter(delimiter[0])
             .schema_infer_max_records(schema_infer_max_records)
             .file_extension(file_extension)
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>()
+            )
             .file_compression_type(parse_file_compression_type(file_compression_type)?);
         options.schema = schema.as_ref().map(|x| &x.0);
 
@@ -937,7 +967,7 @@ impl PySessionContext {
     pub fn read_parquet(
         &self,
         path: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         parquet_pruning: bool,
         file_extension: &str,
         skip_metadata: bool,
@@ -946,7 +976,12 @@ impl PySessionContext {
         py: Python,
     ) -> PyDataFusionResult<PyDataFrame> {
         let mut options = ParquetReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>()
+            )
             .parquet_pruning(parquet_pruning)
             .skip_metadata(skip_metadata);
         options.file_extension = file_extension;
@@ -968,12 +1003,17 @@ impl PySessionContext {
         &self,
         path: &str,
         schema: Option<PyArrowType<Schema>>,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_extension: &str,
         py: Python,
     ) -> PyDataFusionResult<PyDataFrame> {
         let mut options = AvroReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>()
+            );
         options.file_extension = file_extension;
         let df = if let Some(schema) = schema {
             options.schema = Some(&schema.0);
@@ -1072,21 +1112,6 @@ impl PySessionContext {
     }
 }
 
-pub fn convert_table_partition_cols(
-    table_partition_cols: Vec<(String, String)>,
-) -> PyDataFusionResult<Vec<(String, DataType)>> {
-    table_partition_cols
-        .into_iter()
-        .map(|(name, ty)| match ty.as_str() {
-            "string" => Ok((name, DataType::Utf8)),
-            "int" => Ok((name, DataType::Int32)),
-            _ => Err(crate::errors::PyDataFusionError::Common(format!(
-                "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
-            ))),
-        })
-        .collect::<Result<Vec<_>, _>>()
-}
-
 pub fn parse_file_compression_type(
     file_compression_type: Option<String>,
 ) -> Result<FileCompressionType, PyErr> {
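
With convert_table_partition_cols deleted, the Rust side no longer translates shorthand strings; PyArrowType<DataType> converts the pyarrow type at the FFI boundary, and the two mappings the old helper supported have direct pyarrow equivalents. A hedged migration sketch for existing callers:

import pyarrow as pa

# Before this commit the helper mapped "string" -> DataType::Utf8 and
# "int" -> DataType::Int32; the pyarrow types below are their equivalents.
old_style = [("region", "string"), ("year", "int")]          # rejected after this commit
new_style = [("region", pa.string()), ("year", pa.int32())]  # accepted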
