Skip to content

Commit 02cc9ae

Browse files
committed
Enable interrupt handling
1 parent eb2eb63 commit 02cc9ae

File tree

3 files changed

+41
-9
lines changed

3 files changed

+41
-9
lines changed

src/dataframe.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ impl PyDataFrame {
391391
/// Unless some order is specified in the plan, there is no
392392
/// guarantee of the order of the result.
393393
fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
394-
let batches = wait_for_future(py, self.df.as_ref().clone().collect())
394+
let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
395395
.map_err(PyDataFusionError::from)?;
396396
// cannot use PyResult<Vec<RecordBatch>> return type due to
397397
// https://github.com/PyO3/pyo3/issues/1813
@@ -407,8 +407,9 @@ impl PyDataFrame {
407407
/// Executes this DataFrame and collects all results into a vector of vectors of RecordBatch,
408408
/// maintaining the input partitioning.
409409
fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
410-
let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
411-
.map_err(PyDataFusionError::from)?;
410+
let batches =
411+
wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
412+
.map_err(PyDataFusionError::from)?;
412413

413414
batches
414415
.into_iter()

src/record_batch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl PyRecordBatchStream {
6363
impl PyRecordBatchStream {
6464
fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
6565
let stream = self.stream.clone();
66-
wait_for_future(py, next_stream(stream, true))
66+
wait_for_future(py, next_stream(stream, true))?
6767
}
6868

6969
fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {

src/utils.rs

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use datafusion::logical_expr::Volatility;
2424
use pyo3::exceptions::PyValueError;
2525
use pyo3::prelude::*;
2626
use pyo3::types::PyCapsule;
27+
use pyo3::PyErr;
2728
use std::future::Future;
2829
use std::sync::OnceLock;
2930
use tokio::runtime::Runtime;
@@ -47,14 +48,44 @@ pub(crate) fn get_global_ctx() -> &'static SessionContext {
4748
CTX.get_or_init(SessionContext::new)
4849
}
4950

50-
/// Utility to collect rust futures with GIL released
51-
pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
51+
/// Utility to collect rust futures with GIL released and respond to
52+
/// Python interrupts such as ``KeyboardInterrupt``. If a signal is
53+
/// received while the future is running, the future is aborted and the
54+
/// corresponding Python exception is raised.
55+
pub fn wait_for_future<F>(py: Python, f: F) -> PyResult<F::Output>
5256
where
53-
F: Future + Send,
54-
F::Output: Send,
57+
F: Future + Send + 'static,
58+
F::Output: Send + 'static,
5559
{
60+
use std::{thread, time::Duration};
61+
use tokio::task::JoinHandle;
62+
5663
let runtime: &Runtime = &get_tokio_runtime().0;
57-
py.allow_threads(|| runtime.block_on(f))
64+
65+
// Spawn the future so it can be aborted if a signal is received
66+
let handle: JoinHandle<F::Output> = runtime.spawn(f);
67+
68+
let mut interrupt: Option<PyErr> = None;
69+
py.allow_threads(|| {
70+
while !handle.is_finished() {
71+
thread::sleep(Duration::from_millis(10));
72+
Python::with_gil(|py| {
73+
if let Err(err) = py.check_signals() {
74+
handle.abort();
75+
interrupt = Some(err);
76+
}
77+
});
78+
if interrupt.is_some() {
79+
break;
80+
}
81+
}
82+
});
83+
84+
if let Some(err) = interrupt {
85+
return Err(err);
86+
}
87+
88+
Ok(runtime.block_on(handle).expect("Tokio task panicked"))
5889
}
5990

6091
pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {

0 commit comments

Comments
 (0)