Commit b1e67df

fix: parse_compression_type once only
1 parent 8b0e2e1 commit b1e67df

1 file changed: +97 −96 lines

src/context.rs

@@ -34,9 +34,7 @@ use pyo3::prelude::*;
 use crate::catalog::{PyCatalog, PyTable};
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
-use crate::errors::{
-    py_datafusion_err, py_err_to_datafusion_err, PyDataFusionError, PyDataFusionResult,
-};
+use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::expr::sort_expr::PySortExpr;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
@@ -756,10 +754,11 @@ impl PySessionContext {
         let delimiter_byte = delimiter_bytes[0];
 
         // Validate file_compression_type synchronously before async call
-        let fct = match parse_file_compression_type(file_compression_type.clone()) {
-            Ok(compression) => compression,
-            Err(err) => return Err(PyDataFusionError::PythonError(err)),
-        };
+        let parsed_file_compression_type =
+            match parse_file_compression_type(file_compression_type.clone()) {
+                Ok(compression) => compression,
+                Err(err) => return Err(PyDataFusionError::PythonError(err)),
+            };
 
         // Clone all string references to create owned values
         let file_extension_owned = file_extension.to_string();
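This rename-and-hoist is the heart of the commit: every code path now binds the result of parse_file_compression_type once, before any future is built, instead of validating up front and then re-parsing inside each `async move` block. A minimal standalone sketch of the pattern — the Compression enum and the function names below are illustrative stand-ins, not the real datafusion-python API:

// Illustrative stand-in for FileCompressionType (a small Copy-able enum).
#[derive(Clone, Copy, Debug)]
enum Compression {
    Gzip,
    Uncompressed,
}

// Stand-in for parse_file_compression_type: validate the user's string once.
fn parse_compression(s: Option<String>) -> Result<Compression, String> {
    match s.as_deref() {
        None => Ok(Compression::Uncompressed),
        Some("gzip") => Ok(Compression::Gzip),
        Some(other) => Err(format!("unknown compression type: {other}")),
    }
}

fn register(file_compression_type: Option<String>) -> Result<(), String> {
    // Parse exactly once, synchronously, so invalid input fails fast.
    let parsed = parse_compression(file_compression_type)?;

    // The already-parsed value simply moves into the async block; no owned
    // string clone has to travel along with it.
    let _future = async move {
        let _ = parsed; // stand-in for .file_compression_type(parsed)
    };
    Ok(())
}

fn main() {
    assert!(register(Some("gzip".to_string())).is_ok());
    assert!(register(Some("zip99".to_string())).is_err());
}

In the sketch the parsed value is Copy, so it moves freely into the future; either way, the `file_compression_type_owned` string clone that the old code carried into each future becomes unnecessary, which is why the hunks below delete it.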
@@ -773,19 +772,18 @@ impl PySessionContext {
         // Clone self to avoid borrowing
         let self_clone = self.clone();
 
-        // Create options with owned values inside the async block
+        // Create a future that uses our helper function
         let result_future = async move {
-            let mut options = CsvReadOptions::new()
-                .has_header(has_header)
-                .delimiter(delimiter_byte)
-                .schema_infer_max_records(schema_infer_max_records)
-                .file_extension(&file_extension_owned)
-                .file_compression_type(compression);
-
-            // Use owned schema if provided
-            if let Some(s) = &schema_owned {
-                options.schema = Some(s);
-            }
+            let schema_ref = schema_owned.as_ref();
+            let options = create_csv_read_options(
+                has_header,
+                delimiter_byte,
+                schema_infer_max_records,
+                &file_extension_owned,
+                parsed_file_compression_type,
+                schema_ref,
+                None, // No table partition columns here
+            );
 
             self_clone
                 .register_csv_from_multiple_paths(&name_owned, paths, options)
@@ -798,17 +796,16 @@ impl PySessionContext {
 
         // Create a future that moves owned values
         let result_future = async move {
-            let mut options = CsvReadOptions::new()
-                .has_header(has_header)
-                .delimiter(delimiter_byte)
-                .schema_infer_max_records(schema_infer_max_records)
-                .file_extension(&file_extension_owned)
-                .file_compression_type(compression);
-
-            // Use owned schema if provided
-            if let Some(s) = &schema_owned {
-                options.schema = Some(s);
-            }
+            let schema_ref = schema_owned.as_ref();
+            let options = create_csv_read_options(
+                has_header,
+                delimiter_byte,
+                schema_infer_max_records,
+                &file_extension_owned,
+                parsed_file_compression_type,
+                schema_ref,
+                None, // No table partition columns here
+            );
 
             ctx.register_csv(&name_owned, &path, options).await
         };
@@ -841,20 +838,17 @@ impl PySessionContext {
             .to_str()
             .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
 
-        // Validate file_compression_type synchronously before async call
-        if let Some(compression_type) = &file_compression_type {
-            // Return Python error directly instead of wrapping it in PyDataFusionError to match test expectations
-            if let Err(err) = parse_file_compression_type(Some(compression_type.clone())) {
-                return Err(PyDataFusionError::PythonError(err));
-            }
-        }
+        let parsed_file_compression_type =
+            match parse_file_compression_type(file_compression_type.clone()) {
+                Ok(compression) => compression,
+                Err(err) => return Err(PyDataFusionError::PythonError(err)),
+            };
 
         // Clone required values to avoid borrowing in the future
         let ctx = self.ctx.clone();
         let name_owned = name.to_string();
         let path_owned = path.to_string();
         let file_extension_owned = file_extension.to_string();
-        let file_compression_type_owned = file_compression_type.clone();
 
         // Extract schema data if available to avoid borrowing
         let schema_owned = schema.map(|s| s.0.clone());
@@ -865,10 +859,7 @@ impl PySessionContext {
         // Create a future that moves owned values
         let result_future = async move {
             let mut options = NdJsonReadOptions::default()
-                .file_compression_type(
-                    parse_file_compression_type(file_compression_type_owned)
-                        .map_err(py_err_to_datafusion_err)?,
-                )
+                .file_compression_type(parsed_file_compression_type)
                 .table_partition_cols(table_partition_cols.clone());
             options.schema_infer_max_records = schema_infer_max_records;
             options.file_extension = &file_extension_owned;
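The NDJSON paths get the same treatment: the future now receives the already-parsed compression type instead of re-running the parse (and re-mapping its error) inside the block. For orientation, the public DataFusion API these paths drive looks roughly like the following — a hedged sketch where the file path, option values, and the GZIP choice are assumptions, not taken from this diff:

use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::error::Result;
use datafusion::prelude::{NdJsonReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // The compression type is decided once, before the options are built,
    // mirroring parsed_file_compression_type in the hunk above.
    let mut options = NdJsonReadOptions::default()
        .file_compression_type(FileCompressionType::GZIP);
    options.schema_infer_max_records = 1000;
    options.file_extension = ".json.gz";

    // Hypothetical path; any gzip-compressed NDJSON file would do.
    let df = ctx.read_json("data/example.json.gz", options).await?;
    df.show().await?;
    Ok(())
}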
@@ -1052,18 +1043,16 @@ impl PySessionContext {
             .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
 
         // Validate file_compression_type synchronously before async call
-        if let Some(compression_type) = &file_compression_type {
-            // Return Python error directly instead of wrapping it in PyDataFusionError to match test expectations
-            if let Err(err) = parse_file_compression_type(Some(compression_type.clone())) {
-                return Err(PyDataFusionError::PythonError(err));
-            }
-        }
+        let parsed_file_compression_type =
+            match parse_file_compression_type(file_compression_type.clone()) {
+                Ok(compression) => compression,
+                Err(err) => return Err(PyDataFusionError::PythonError(err)),
+            };
 
         // Clone required values to avoid borrowing in the future
         let ctx = self.ctx.clone();
         let path_owned = path.to_string();
         let file_extension_owned = file_extension.to_string();
-        let file_compression_type_owned = file_compression_type.clone();
 
         // Convert table partition columns
         let table_partition_cols = convert_table_partition_cols(table_partition_cols)?;
@@ -1076,10 +1065,7 @@ impl PySessionContext {
         let result_future = async move {
             let mut options = NdJsonReadOptions::default()
                 .table_partition_cols(table_partition_cols.clone())
-                .file_compression_type(
-                    parse_file_compression_type(file_compression_type_owned)
-                        .map_err(py_err_to_datafusion_err)?,
-                );
+                .file_compression_type(parsed_file_compression_type);
             options.schema_infer_max_records = schema_infer_max_records;
             options.file_extension = &file_extension_owned;
 
@@ -1096,10 +1082,7 @@ impl PySessionContext {
         let result_future = async move {
             let mut options = NdJsonReadOptions::default()
                 .table_partition_cols(table_partition_cols.clone())
-                .file_compression_type(
-                    parse_file_compression_type(file_compression_type_owned)
-                        .map_err(py_err_to_datafusion_err)?,
-                );
+                .file_compression_type(parsed_file_compression_type);
             options.schema_infer_max_records = schema_infer_max_records;
             options.file_extension = &file_extension_owned;
 
@@ -1142,19 +1125,16 @@ impl PySessionContext {
         // Store just the delimiter byte to use in the future
         let delimiter_byte = delimiter_bytes[0];
 
-        // Validate file_compression_type synchronously before async call
-        if let Some(compression_type) = &file_compression_type {
-            // Return Python error directly instead of wrapping it in PyDataFusionError to match test expectations
-            if let Err(err) = parse_file_compression_type(Some(compression_type.clone())) {
-                return Err(PyDataFusionError::PythonError(err));
-            }
-        }
+        let parsed_file_compression_type =
+            match parse_file_compression_type(file_compression_type.clone()) {
+                Ok(compression) => compression,
+                Err(err) => return Err(PyDataFusionError::PythonError(err)),
+            };
 
         // Clone required values to avoid borrowing in the future
         let ctx = self.ctx.clone();
         let file_extension_owned = file_extension.to_string();
         let delimiter_owned = delimiter_byte; // Store just the delimiter byte
-        let file_compression_type_owned = file_compression_type.clone();
 
         // Extract schema data if available to avoid borrowing
         let schema_owned = schema.map(|s| s.0.clone());
@@ -1169,22 +1149,17 @@ impl PySessionContext {
         let paths_owned = paths.clone();
 
         let result_future = async move {
-            // Create options inside the future with owned values
-            let mut options = CsvReadOptions::new()
-                .has_header(has_header)
-                .delimiter(delimiter_owned)
-                .schema_infer_max_records(schema_infer_max_records)
-                .file_extension(&file_extension_owned)
-                .table_partition_cols(table_partition_cols.clone())
-                .file_compression_type(
-                    parse_file_compression_type(file_compression_type_owned)
-                        .map_err(py_err_to_datafusion_err)?,
-                );
-
-            // Use owned schema if provided
-            if let Some(s) = &schema_owned {
-                options.schema = Some(s);
-            }
+            // Create options using our helper function
+            let schema_ref = schema_owned.as_ref();
+            let options = create_csv_read_options(
+                has_header,
+                delimiter_owned,
+                schema_infer_max_records,
+                &file_extension_owned,
+                parsed_file_compression_type,
+                schema_ref,
+                Some(table_partition_cols.clone()),
+            );
 
             ctx.read_csv(paths_owned, options).await
         };
@@ -1199,21 +1174,16 @@ impl PySessionContext {
 
         let result_future = async move {
             // Create options inside the future with owned values
-            let mut options = CsvReadOptions::new()
-                .has_header(has_header)
-                .delimiter(delimiter_owned)
-                .schema_infer_max_records(schema_infer_max_records)
-                .file_extension(&file_extension_owned)
-                .table_partition_cols(table_partition_cols.clone())
-                .file_compression_type(
-                    parse_file_compression_type(file_compression_type_owned)
-                        .map_err(py_err_to_datafusion_err)?,
-                );
-
-            // Use owned schema if provided
-            if let Some(s) = &schema_owned {
-                options.schema = Some(s);
-            }
+            let schema_ref = schema_owned.as_ref();
+            let options = create_csv_read_options(
+                has_header,
+                delimiter_owned,
+                schema_infer_max_records,
+                &file_extension_owned,
+                parsed_file_compression_type,
+                schema_ref,
+                Some(table_partition_cols.clone()),
+            );
 
             ctx.read_csv(path_owned, options).await
         };
@@ -1423,6 +1393,37 @@ pub fn convert_table_partition_cols(
         .collect::<Result<Vec<_>, _>>()
 }
 
+/// Create CsvReadOptions with the provided parameters
+fn create_csv_read_options<'a>(
+    has_header: bool,
+    delimiter_byte: u8,
+    schema_infer_max_records: usize,
+    file_extension: &'a str,
+    file_compression_type: FileCompressionType,
+    schema: Option<&'a Schema>,
+    table_partition_cols: Option<Vec<(String, DataType)>>,
+) -> CsvReadOptions<'a> {
+    let mut options = CsvReadOptions::new()
+        .has_header(has_header)
+        .delimiter(delimiter_byte)
+        .schema_infer_max_records(schema_infer_max_records)
+        .file_extension(file_extension);
+
+    // Set compression type
+    options = options.file_compression_type(file_compression_type);
+
+    // Set table partition columns if provided
+    if let Some(cols) = table_partition_cols {
+        options = options.table_partition_cols(cols);
+    }
+
+    // Set schema if provided
+    if let Some(s) = schema {
+        options.schema = Some(s);
+    }
+    options
+}
+
 pub fn parse_file_compression_type(
     file_compression_type: Option<String>,
 ) -> Result<FileCompressionType, PyErr> {
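create_csv_read_options is a private, module-level helper, so callers outside this crate keep using the CsvReadOptions builder directly; what the helper assembles corresponds roughly to the sketch below, written against the public DataFusion API (the path, option values, and GZIP choice are assumptions for illustration):

use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Roughly what create_csv_read_options(true, b',', 1000, ".csv.gz",
    // GZIP, None, None) builds, spelled out with the public builder.
    let options = CsvReadOptions::new()
        .has_header(true)
        .delimiter(b',')
        .schema_infer_max_records(1000)
        .file_extension(".csv.gz")
        .file_compression_type(FileCompressionType::GZIP);

    // Hypothetical path; any gzip-compressed CSV would do.
    let df = ctx.read_csv("data/example.csv.gz", options).await?;
    df.show().await?;
    Ok(())
}

The 'a lifetime on the helper comes from CsvReadOptions<'a> itself, which borrows the file extension and the optional Schema rather than owning them; that is also why the callers above hold schema_owned outside the builder and pass schema_owned.as_ref() in.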
