Skip to content

Commit 39e190f

Browse files
committed
refactor(context): deduplicate register/read option-building logic
Extract shared helpers (convert_partition_cols, convert_file_sort_order, build_parquet/json/avro_options, convert_csv_options), standardize path types to &str, and remove redundant intermediate variables.
1 parent 16feeb1 commit 39e190f

File tree

1 file changed

+131
-150
lines changed

1 file changed

+131
-150
lines changed

crates/core/src/context.rs

Lines changed: 131 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
// under the License.
1717

1818
use std::collections::{HashMap, HashSet};
19-
use std::path::PathBuf;
2019
use std::ptr::NonNull;
2120
use std::str::FromStr;
2221
use std::sync::Arc;
@@ -456,19 +455,8 @@ impl PySessionContext {
456455
) -> PyDataFusionResult<()> {
457456
let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
458457
.with_file_extension(file_extension)
459-
.with_table_partition_cols(
460-
table_partition_cols
461-
.into_iter()
462-
.map(|(name, ty)| (name, ty.0))
463-
.collect::<Vec<(String, DataType)>>(),
464-
)
465-
.with_file_sort_order(
466-
file_sort_order
467-
.unwrap_or_default()
468-
.into_iter()
469-
.map(|e| e.into_iter().map(|f| f.into()).collect())
470-
.collect(),
471-
);
458+
.with_table_partition_cols(convert_partition_cols(table_partition_cols))
459+
.with_file_sort_order(convert_file_sort_order(file_sort_order));
472460
let table_path = ListingTableUrl::parse(path)?;
473461
let resolved_schema: SchemaRef = match schema {
474462
Some(s) => Arc::new(s.0),
@@ -831,25 +819,15 @@ impl PySessionContext {
831819
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
832820
py: Python,
833821
) -> PyDataFusionResult<()> {
834-
let mut options = ParquetReadOptions::default()
835-
.table_partition_cols(
836-
table_partition_cols
837-
.into_iter()
838-
.map(|(name, ty)| (name, ty.0))
839-
.collect::<Vec<(String, DataType)>>(),
840-
)
841-
.parquet_pruning(parquet_pruning)
842-
.skip_metadata(skip_metadata);
843-
options.file_extension = file_extension;
844-
options.schema = schema.as_ref().map(|x| &x.0);
845-
options.file_sort_order = file_sort_order
846-
.unwrap_or_default()
847-
.into_iter()
848-
.map(|e| e.into_iter().map(|f| f.into()).collect())
849-
.collect();
850-
851-
let result = self.ctx.register_parquet(name, path, options);
852-
wait_for_future(py, result)??;
822+
let options = build_parquet_options(
823+
table_partition_cols,
824+
parquet_pruning,
825+
file_extension,
826+
skip_metadata,
827+
&schema,
828+
file_sort_order,
829+
);
830+
wait_for_future(py, self.ctx.register_parquet(name, path, options))??;
853831
Ok(())
854832
}
855833

@@ -863,19 +841,17 @@ impl PySessionContext {
863841
options: Option<&PyCsvReadOptions>,
864842
py: Python,
865843
) -> PyDataFusionResult<()> {
866-
let options = options
867-
.map(|opts| opts.try_into())
868-
.transpose()?
869-
.unwrap_or_default();
844+
let options = convert_csv_options(options)?;
870845

871846
if path.is_instance_of::<PyList>() {
872847
let paths = path.extract::<Vec<String>>()?;
873-
let result = self.register_csv_from_multiple_paths(name, paths, options);
874-
wait_for_future(py, result)??;
848+
wait_for_future(
849+
py,
850+
self.register_csv_from_multiple_paths(name, paths, options),
851+
)??;
875852
} else {
876853
let path = path.extract::<String>()?;
877-
let result = self.ctx.register_csv(name, &path, options);
878-
wait_for_future(py, result)??;
854+
wait_for_future(py, self.ctx.register_csv(name, &path, options))??;
879855
}
880856

881857
Ok(())
@@ -892,33 +868,22 @@ impl PySessionContext {
892868
pub fn register_json(
893869
&self,
894870
name: &str,
895-
path: PathBuf,
871+
path: &str,
896872
schema: Option<PyArrowType<Schema>>,
897873
schema_infer_max_records: usize,
898874
file_extension: &str,
899875
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
900876
file_compression_type: Option<String>,
901877
py: Python,
902878
) -> PyDataFusionResult<()> {
903-
let path = path
904-
.to_str()
905-
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
906-
907-
let mut options = JsonReadOptions::default()
908-
.file_compression_type(parse_file_compression_type(file_compression_type)?)
909-
.table_partition_cols(
910-
table_partition_cols
911-
.into_iter()
912-
.map(|(name, ty)| (name, ty.0))
913-
.collect::<Vec<(String, DataType)>>(),
914-
);
915-
options.schema_infer_max_records = schema_infer_max_records;
916-
options.file_extension = file_extension;
917-
options.schema = schema.as_ref().map(|x| &x.0);
918-
919-
let result = self.ctx.register_json(name, path, options);
920-
wait_for_future(py, result)??;
921-
879+
let options = build_json_options(
880+
table_partition_cols,
881+
file_compression_type,
882+
schema_infer_max_records,
883+
file_extension,
884+
&schema,
885+
)?;
886+
wait_for_future(py, self.ctx.register_json(name, path, options))??;
922887
Ok(())
923888
}
924889

@@ -931,28 +896,14 @@ impl PySessionContext {
931896
pub fn register_avro(
932897
&self,
933898
name: &str,
934-
path: PathBuf,
899+
path: &str,
935900
schema: Option<PyArrowType<Schema>>,
936901
file_extension: &str,
937902
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
938903
py: Python,
939904
) -> PyDataFusionResult<()> {
940-
let path = path
941-
.to_str()
942-
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
943-
944-
let mut options = AvroReadOptions::default().table_partition_cols(
945-
table_partition_cols
946-
.into_iter()
947-
.map(|(name, ty)| (name, ty.0))
948-
.collect::<Vec<(String, DataType)>>(),
949-
);
950-
options.file_extension = file_extension;
951-
options.schema = schema.as_ref().map(|x| &x.0);
952-
953-
let result = self.ctx.register_avro(name, path, options);
954-
wait_for_future(py, result)??;
955-
905+
let options = build_avro_options(table_partition_cols, file_extension, &schema);
906+
wait_for_future(py, self.ctx.register_avro(name, path, options))??;
956907
Ok(())
957908
}
958909

@@ -1054,35 +1005,22 @@ impl PySessionContext {
10541005
#[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
10551006
pub fn read_json(
10561007
&self,
1057-
path: PathBuf,
1008+
path: &str,
10581009
schema: Option<PyArrowType<Schema>>,
10591010
schema_infer_max_records: usize,
10601011
file_extension: &str,
10611012
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
10621013
file_compression_type: Option<String>,
10631014
py: Python,
10641015
) -> PyDataFusionResult<PyDataFrame> {
1065-
let path = path
1066-
.to_str()
1067-
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
1068-
let mut options = JsonReadOptions::default()
1069-
.table_partition_cols(
1070-
table_partition_cols
1071-
.into_iter()
1072-
.map(|(name, ty)| (name, ty.0))
1073-
.collect::<Vec<(String, DataType)>>(),
1074-
)
1075-
.file_compression_type(parse_file_compression_type(file_compression_type)?);
1076-
options.schema_infer_max_records = schema_infer_max_records;
1077-
options.file_extension = file_extension;
1078-
let df = if let Some(schema) = schema {
1079-
options.schema = Some(&schema.0);
1080-
let result = self.ctx.read_json(path, options);
1081-
wait_for_future(py, result)??
1082-
} else {
1083-
let result = self.ctx.read_json(path, options);
1084-
wait_for_future(py, result)??
1085-
};
1016+
let options = build_json_options(
1017+
table_partition_cols,
1018+
file_compression_type,
1019+
schema_infer_max_records,
1020+
file_extension,
1021+
&schema,
1022+
)?;
1023+
let df = wait_for_future(py, self.ctx.read_json(path, options))??;
10861024
Ok(PyDataFrame::new(df))
10871025
}
10881026

@@ -1095,23 +1033,15 @@ impl PySessionContext {
10951033
options: Option<&PyCsvReadOptions>,
10961034
py: Python,
10971035
) -> PyDataFusionResult<PyDataFrame> {
1098-
let options = options
1099-
.map(|opts| opts.try_into())
1100-
.transpose()?
1101-
.unwrap_or_default();
1036+
let options = convert_csv_options(options)?;
11021037

1103-
if path.is_instance_of::<PyList>() {
1104-
let paths = path.extract::<Vec<String>>()?;
1105-
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
1106-
let result = self.ctx.read_csv(paths, options);
1107-
let df = PyDataFrame::new(wait_for_future(py, result)??);
1108-
Ok(df)
1038+
let paths: Vec<String> = if path.is_instance_of::<PyList>() {
1039+
path.extract()?
11091040
} else {
1110-
let path = path.extract::<String>()?;
1111-
let result = self.ctx.read_csv(path, options);
1112-
let df = PyDataFrame::new(wait_for_future(py, result)??);
1113-
Ok(df)
1114-
}
1041+
vec![path.extract()?]
1042+
};
1043+
let df = wait_for_future(py, self.ctx.read_csv(paths, options))??;
1044+
Ok(PyDataFrame::new(df))
11151045
}
11161046

11171047
#[allow(clippy::too_many_arguments)]
@@ -1134,25 +1064,15 @@ impl PySessionContext {
11341064
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
11351065
py: Python,
11361066
) -> PyDataFusionResult<PyDataFrame> {
1137-
let mut options = ParquetReadOptions::default()
1138-
.table_partition_cols(
1139-
table_partition_cols
1140-
.into_iter()
1141-
.map(|(name, ty)| (name, ty.0))
1142-
.collect::<Vec<(String, DataType)>>(),
1143-
)
1144-
.parquet_pruning(parquet_pruning)
1145-
.skip_metadata(skip_metadata);
1146-
options.file_extension = file_extension;
1147-
options.schema = schema.as_ref().map(|x| &x.0);
1148-
options.file_sort_order = file_sort_order
1149-
.unwrap_or_default()
1150-
.into_iter()
1151-
.map(|e| e.into_iter().map(|f| f.into()).collect())
1152-
.collect();
1153-
1154-
let result = self.ctx.read_parquet(path, options);
1155-
let df = PyDataFrame::new(wait_for_future(py, result)??);
1067+
let options = build_parquet_options(
1068+
table_partition_cols,
1069+
parquet_pruning,
1070+
file_extension,
1071+
skip_metadata,
1072+
&schema,
1073+
file_sort_order,
1074+
);
1075+
let df = PyDataFrame::new(wait_for_future(py, self.ctx.read_parquet(path, options))??);
11561076
Ok(df)
11571077
}
11581078

@@ -1166,21 +1086,8 @@ impl PySessionContext {
11661086
file_extension: &str,
11671087
py: Python,
11681088
) -> PyDataFusionResult<PyDataFrame> {
1169-
let mut options = AvroReadOptions::default().table_partition_cols(
1170-
table_partition_cols
1171-
.into_iter()
1172-
.map(|(name, ty)| (name, ty.0))
1173-
.collect::<Vec<(String, DataType)>>(),
1174-
);
1175-
options.file_extension = file_extension;
1176-
let df = if let Some(schema) = schema {
1177-
options.schema = Some(&schema.0);
1178-
let read_future = self.ctx.read_avro(path, options);
1179-
wait_for_future(py, read_future)??
1180-
} else {
1181-
let read_future = self.ctx.read_avro(path, options);
1182-
wait_for_future(py, read_future)??
1183-
};
1089+
let options = build_avro_options(table_partition_cols, file_extension, &schema);
1090+
let df = wait_for_future(py, self.ctx.read_avro(path, options))??;
11841091
Ok(PyDataFrame::new(df))
11851092
}
11861093

@@ -1280,7 +1187,7 @@ impl PySessionContext {
12801187
// check if the file extension matches the expected extension
12811188
for path in &table_paths {
12821189
let file_path = path.as_str();
1283-
if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
1190+
if !file_path.ends_with(option_extension.as_str()) && !path.is_collection() {
12841191
return exec_err!(
12851192
"File path '{file_path}' does not match the expected extension '{option_extension}'"
12861193
);
@@ -1321,6 +1228,80 @@ pub fn parse_file_compression_type(
13211228
})
13221229
}
13231230

1231+
fn convert_csv_options(
1232+
options: Option<&PyCsvReadOptions>,
1233+
) -> PyDataFusionResult<CsvReadOptions<'_>> {
1234+
Ok(options
1235+
.map(|opts| opts.try_into())
1236+
.transpose()?
1237+
.unwrap_or_default())
1238+
}
1239+
1240+
fn convert_partition_cols(
1241+
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1242+
) -> Vec<(String, DataType)> {
1243+
table_partition_cols
1244+
.into_iter()
1245+
.map(|(name, ty)| (name, ty.0))
1246+
.collect()
1247+
}
1248+
1249+
fn convert_file_sort_order(
1250+
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
1251+
) -> Vec<Vec<datafusion::logical_expr::SortExpr>> {
1252+
file_sort_order
1253+
.unwrap_or_default()
1254+
.into_iter()
1255+
.map(|e| e.into_iter().map(|f| f.into()).collect())
1256+
.collect()
1257+
}
1258+
1259+
fn build_parquet_options<'a>(
1260+
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1261+
parquet_pruning: bool,
1262+
file_extension: &'a str,
1263+
skip_metadata: bool,
1264+
schema: &'a Option<PyArrowType<Schema>>,
1265+
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
1266+
) -> ParquetReadOptions<'a> {
1267+
let mut options = ParquetReadOptions::default()
1268+
.table_partition_cols(convert_partition_cols(table_partition_cols))
1269+
.parquet_pruning(parquet_pruning)
1270+
.skip_metadata(skip_metadata);
1271+
options.file_extension = file_extension;
1272+
options.schema = schema.as_ref().map(|x| &x.0);
1273+
options.file_sort_order = convert_file_sort_order(file_sort_order);
1274+
options
1275+
}
1276+
1277+
fn build_json_options<'a>(
1278+
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1279+
file_compression_type: Option<String>,
1280+
schema_infer_max_records: usize,
1281+
file_extension: &'a str,
1282+
schema: &'a Option<PyArrowType<Schema>>,
1283+
) -> Result<JsonReadOptions<'a>, PyErr> {
1284+
let mut options = JsonReadOptions::default()
1285+
.table_partition_cols(convert_partition_cols(table_partition_cols))
1286+
.file_compression_type(parse_file_compression_type(file_compression_type)?);
1287+
options.schema_infer_max_records = schema_infer_max_records;
1288+
options.file_extension = file_extension;
1289+
options.schema = schema.as_ref().map(|x| &x.0);
1290+
Ok(options)
1291+
}
1292+
1293+
fn build_avro_options<'a>(
1294+
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1295+
file_extension: &'a str,
1296+
schema: &'a Option<PyArrowType<Schema>>,
1297+
) -> AvroReadOptions<'a> {
1298+
let mut options = AvroReadOptions::default()
1299+
.table_partition_cols(convert_partition_cols(table_partition_cols));
1300+
options.file_extension = file_extension;
1301+
options.schema = schema.as_ref().map(|x| &x.0);
1302+
options
1303+
}
1304+
13241305
impl From<PySessionContext> for SessionContext {
13251306
fn from(ctx: PySessionContext) -> SessionContext {
13261307
ctx.ctx.as_ref().clone()

0 commit comments

Comments
 (0)