Skip to content

Commit 51e1267

Handle Python interrupts

1 parent: 02cc9ae · commit: 51e1267

File tree

5 files changed: 60 additions (+60) and 75 deletions (-75).

src/catalog.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ impl PyDatabase {
9797
}
9898

9999
fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
100-
if let Some(table) = wait_for_future(py, self.database.table(name))? {
100+
if let Some(table) = wait_for_future(py, self.database.table(name))?? {
101101
Ok(PyTable::new(table))
102102
} else {
103103
Err(PyDataFusionError::Common(format!(

src/context.rs

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ impl PySessionContext {
375375
None => {
376376
let state = self.ctx.state();
377377
let schema = options.infer_schema(&state, &table_path);
378-
wait_for_future(py, schema)?
378+
wait_for_future(py, schema)??
379379
}
380380
};
381381
let config = ListingTableConfig::new(table_path)
@@ -400,7 +400,7 @@ impl PySessionContext {
400400
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
401401
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
402402
let result = self.ctx.sql(query);
403-
let df = wait_for_future(py, result)?;
403+
let df = wait_for_future(py, result)??;
404404
Ok(PyDataFrame::new(df))
405405
}
406406

@@ -417,7 +417,7 @@ impl PySessionContext {
417417
SQLOptions::new()
418418
};
419419
let result = self.ctx.sql_with_options(query, options);
420-
let df = wait_for_future(py, result)?;
420+
let df = wait_for_future(py, result)??;
421421
Ok(PyDataFrame::new(df))
422422
}
423423

@@ -451,7 +451,7 @@ impl PySessionContext {
451451

452452
self.ctx.register_table(&*table_name, Arc::new(table))?;
453453

454-
let table = wait_for_future(py, self._table(&table_name))?;
454+
let table = wait_for_future(py, self._table(&table_name))??;
455455

456456
let df = PyDataFrame::new(table);
457457
Ok(df)
@@ -650,7 +650,7 @@ impl PySessionContext {
650650
.collect();
651651

652652
let result = self.ctx.register_parquet(name, path, options);
653-
wait_for_future(py, result)?;
653+
wait_for_future(py, result)??;
654654
Ok(())
655655
}
656656

@@ -693,11 +693,11 @@ impl PySessionContext {
693693
if path.is_instance_of::<PyList>() {
694694
let paths = path.extract::<Vec<String>>()?;
695695
let result = self.register_csv_from_multiple_paths(name, paths, options);
696-
wait_for_future(py, result)?;
696+
wait_for_future(py, result)??;
697697
} else {
698698
let path = path.extract::<String>()?;
699699
let result = self.ctx.register_csv(name, &path, options);
700-
wait_for_future(py, result)?;
700+
wait_for_future(py, result)??;
701701
}
702702

703703
Ok(())
@@ -734,7 +734,7 @@ impl PySessionContext {
734734
options.schema = schema.as_ref().map(|x| &x.0);
735735

736736
let result = self.ctx.register_json(name, path, options);
737-
wait_for_future(py, result)?;
737+
wait_for_future(py, result)??;
738738

739739
Ok(())
740740
}
@@ -764,7 +764,7 @@ impl PySessionContext {
764764
options.schema = schema.as_ref().map(|x| &x.0);
765765

766766
let result = self.ctx.register_avro(name, path, options);
767-
wait_for_future(py, result)?;
767+
wait_for_future(py, result)??;
768768

769769
Ok(())
770770
}
@@ -826,7 +826,8 @@ impl PySessionContext {
826826

827827
pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
828828
let x = wait_for_future(py, self.ctx.table(name))
829-
.map_err(|e| PyKeyError::new_err(e.to_string()))?;
829+
.map_err(|e| PyKeyError::new_err(e.to_string()))?
830+
.map_err(py_datafusion_err)?;
830831
Ok(PyDataFrame::new(x))
831832
}
832833

@@ -865,10 +866,10 @@ impl PySessionContext {
865866
let df = if let Some(schema) = schema {
866867
options.schema = Some(&schema.0);
867868
let result = self.ctx.read_json(path, options);
868-
wait_for_future(py, result)?
869+
wait_for_future(py, result)??
869870
} else {
870871
let result = self.ctx.read_json(path, options);
871-
wait_for_future(py, result)?
872+
wait_for_future(py, result)??
872873
};
873874
Ok(PyDataFrame::new(df))
874875
}
@@ -915,12 +916,12 @@ impl PySessionContext {
915916
let paths = path.extract::<Vec<String>>()?;
916917
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
917918
let result = self.ctx.read_csv(paths, options);
918-
let df = PyDataFrame::new(wait_for_future(py, result)?);
919+
let df = PyDataFrame::new(wait_for_future(py, result)??);
919920
Ok(df)
920921
} else {
921922
let path = path.extract::<String>()?;
922923
let result = self.ctx.read_csv(path, options);
923-
let df = PyDataFrame::new(wait_for_future(py, result)?);
924+
let df = PyDataFrame::new(wait_for_future(py, result)??);
924925
Ok(df)
925926
}
926927
}
@@ -958,7 +959,7 @@ impl PySessionContext {
958959
.collect();
959960

960961
let result = self.ctx.read_parquet(path, options);
961-
let df = PyDataFrame::new(wait_for_future(py, result)?);
962+
let df = PyDataFrame::new(wait_for_future(py, result)??);
962963
Ok(df)
963964
}
964965

@@ -978,10 +979,10 @@ impl PySessionContext {
978979
let df = if let Some(schema) = schema {
979980
options.schema = Some(&schema.0);
980981
let read_future = self.ctx.read_avro(path, options);
981-
wait_for_future(py, read_future)?
982+
wait_for_future(py, read_future)??
982983
} else {
983984
let read_future = self.ctx.read_avro(path, options);
984-
wait_for_future(py, read_future)?
985+
wait_for_future(py, read_future)??
985986
};
986987
Ok(PyDataFrame::new(df))
987988
}
@@ -1021,8 +1022,8 @@ impl PySessionContext {
10211022
let plan = plan.plan.clone();
10221023
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
10231024
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1024-
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
1025-
Ok(PyRecordBatchStream::new(stream?))
1025+
let stream = wait_for_future(py, async { fut.await.expect("Tokio task panicked") })??;
1026+
Ok(PyRecordBatchStream::new(stream))
10261027
}
10271028
}
10281029

src/dataframe.rs

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ impl PyDataFrame {
233233
let (batches, has_more) = wait_for_future(
234234
py,
235235
collect_record_batches_to_display(self.df.as_ref().clone(), config),
236-
)?;
236+
)??;
237237
if batches.is_empty() {
238238
// This should not be reached, but do it for safety since we index into the vector below
239239
return Ok("No data to display".to_string());
@@ -256,7 +256,7 @@ impl PyDataFrame {
256256
let (batches, has_more) = wait_for_future(
257257
py,
258258
collect_record_batches_to_display(self.df.as_ref().clone(), config),
259-
)?;
259+
)??;
260260
if batches.is_empty() {
261261
// This should not be reached, but do it for safety since we index into the vector below
262262
return Ok("No data to display".to_string());
@@ -288,7 +288,7 @@ impl PyDataFrame {
288288
/// Calculate summary statistics for a DataFrame
289289
fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
290290
let df = self.df.as_ref().clone();
291-
let stat_df = wait_for_future(py, df.describe())?;
291+
let stat_df = wait_for_future(py, df.describe())??;
292292
Ok(Self::new(stat_df))
293293
}
294294

@@ -400,16 +400,15 @@ impl PyDataFrame {
400400

401401
/// Cache DataFrame.
402402
fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
403-
let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
403+
let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
404404
Ok(Self::new(df))
405405
}
406406

407407
/// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
408408
/// maintaining the input partitioning.
409409
fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
410-
let batches =
411-
wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
412-
.map_err(PyDataFusionError::from)?;
410+
let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
411+
.map_err(PyDataFusionError::from)?;
413412

414413
batches
415414
.into_iter()
@@ -512,7 +511,7 @@ impl PyDataFrame {
512511

513512
/// Get the execution plan for this `DataFrame`
514513
fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
515-
let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
514+
let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
516515
Ok(plan.into())
517516
}
518517

@@ -625,7 +624,7 @@ impl PyDataFrame {
625624
DataFrameWriteOptions::new(),
626625
Some(csv_options),
627626
),
628-
)?;
627+
)??;
629628
Ok(())
630629
}
631630

@@ -686,7 +685,7 @@ impl PyDataFrame {
686685
DataFrameWriteOptions::new(),
687686
Option::from(options),
688687
),
689-
)?;
688+
)??;
690689
Ok(())
691690
}
692691

@@ -698,7 +697,7 @@ impl PyDataFrame {
698697
.as_ref()
699698
.clone()
700699
.write_json(path, DataFrameWriteOptions::new(), None),
701-
)?;
700+
)??;
702701
Ok(())
703702
}
704703

@@ -721,7 +720,7 @@ impl PyDataFrame {
721720
py: Python<'py>,
722721
requested_schema: Option<Bound<'py, PyCapsule>>,
723722
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
724-
let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
723+
let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
725724
let mut schema: Schema = self.df.schema().to_owned().into();
726725

727726
if let Some(schema_capsule) = requested_schema {
@@ -754,8 +753,8 @@ impl PyDataFrame {
754753
let df = self.df.as_ref().clone();
755754
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
756755
rt.spawn(async move { df.execute_stream().await });
757-
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
758-
Ok(PyRecordBatchStream::new(stream?))
756+
let stream = wait_for_future(py, async { fut.await.expect("Tokio task panicked") })??;
757+
Ok(PyRecordBatchStream::new(stream))
759758
}
760759

761760
fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
@@ -764,14 +763,10 @@ impl PyDataFrame {
764763
let df = self.df.as_ref().clone();
765764
let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
766765
rt.spawn(async move { df.execute_stream_partitioned().await });
767-
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
766+
let stream = wait_for_future(py, async { fut.await.expect("Tokio task panicked") })?
767+
.map_err(py_datafusion_err)?;
768768

769-
match stream {
770-
Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
771-
_ => Err(PyValueError::new_err(
772-
"Unable to execute stream partitioned",
773-
)),
774-
}
769+
Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
775770
}
776771

777772
/// Convert to pandas dataframe with pyarrow
@@ -816,7 +811,7 @@ impl PyDataFrame {
816811

817812
// Executes this DataFrame to get the total number of rows.
818813
fn count(&self, py: Python) -> PyDataFusionResult<usize> {
819-
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
814+
Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
820815
}
821816

822817
/// Fill null values with a specified value for specific columns
@@ -842,7 +837,7 @@ impl PyDataFrame {
842837
/// Print DataFrame
843838
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
844839
// Get string representation of record batches
845-
let batches = wait_for_future(py, df.collect())?;
840+
let batches = wait_for_future(py, df.collect())??;
846841
let batches_as_string = pretty::pretty_format_batches(&batches);
847842
let result = match batches_as_string {
848843
Ok(batch) => format!("DataFrame()\n{batch}"),

src/substrait.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ impl PySubstraitSerializer {
7272
path: &str,
7373
py: Python,
7474
) -> PyDataFusionResult<()> {
75-
wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))?;
75+
wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))??;
7676
Ok(())
7777
}
7878

@@ -94,19 +94,20 @@ impl PySubstraitSerializer {
9494
ctx: PySessionContext,
9595
py: Python,
9696
) -> PyDataFusionResult<PyObject> {
97-
let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))?;
97+
let proto_bytes: Vec<u8> =
98+
wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))??;
9899
Ok(PyBytes::new(py, &proto_bytes).into())
99100
}
100101

101102
#[staticmethod]
102103
pub fn deserialize(path: &str, py: Python) -> PyDataFusionResult<PyPlan> {
103-
let plan = wait_for_future(py, serializer::deserialize(path))?;
104+
let plan = wait_for_future(py, serializer::deserialize(path))??;
104105
Ok(PyPlan { plan: *plan })
105106
}
106107

107108
#[staticmethod]
108109
pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyDataFusionResult<PyPlan> {
109-
let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))?;
110+
let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))??;
110111
Ok(PyPlan { plan: *plan })
111112
}
112113
}
@@ -143,7 +144,7 @@ impl PySubstraitConsumer {
143144
) -> PyDataFusionResult<PyLogicalPlan> {
144145
let session_state = ctx.ctx.state();
145146
let result = consumer::from_substrait_plan(&session_state, &plan.plan);
146-
let logical_plan = wait_for_future(py, result)?;
147+
let logical_plan = wait_for_future(py, result)??;
147148
Ok(PyLogicalPlan::new(logical_plan))
148149
}
149150
}

0 commit comments

Comments
 (0)