Skip to content

Commit 5748ff0

Browse files
committed
Improve exception handling in wait_for_future
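Updated wait_for_future to surface pending Python exceptions: each periodic signal check now first returns any exception already set on the thread, then runs a trivial `pass` statement so that asynchronously injected exceptions (e.g. via PyThreadState_SetAsyncExc) are raised promptly, and finally calls check_signals. Enhanced PartitionedDataFrameStreamReader to drop the remaining partition streams when a projection error, a stream error, or a Python interrupt occurs, so that iteration stops cleanly. Added a regression test verifying that RecordBatchReader.read_all exits shortly after being interrupted, and relaxed the existing interrupted Arrow C stream test to also accept a KeyboardInterrupt reported through the error message.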
1 parent b991f77 commit 5748ff0

File tree

3 files changed

+103
-6
lines changed

3 files changed

+103
-6
lines changed

python/tests/test_dataframe.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3349,17 +3349,92 @@ def read_stream():
33493349
pytest.fail("Stream read operation timed out after 10 seconds")
33503350

33513351
# Check if we got the expected KeyboardInterrupt
3352-
if read_exception and isinstance(read_exception[0], type) and read_exception[0] == KeyboardInterrupt:
3353-
interrupted = True
3354-
elif read_exception:
3355-
interrupt_error = read_exception[0]
3352+
if read_exception:
3353+
if isinstance(read_exception[0], type) and read_exception[0] == KeyboardInterrupt:
3354+
interrupted = True
3355+
elif "KeyboardInterrupt" in str(read_exception[0]):
3356+
interrupted = True
3357+
else:
3358+
interrupt_error = read_exception[0]
33563359

33573360
if not interrupted:
33583361
pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
33593362

33603363
interrupt_thread.join(timeout=1.0)
33613364

33623365

3366+
def test_record_batch_reader_interrupt_exits_quickly(ctx):
3367+
df = ctx.sql(
3368+
"""
3369+
SELECT t1.value AS a, t2.value AS a2
3370+
FROM range(0, 1000000, 1) AS t1
3371+
JOIN range(0, 1000000, 1) AS t2 ON t1.value = t2.value
3372+
"""
3373+
)
3374+
3375+
reader = pa.RecordBatchReader.from_stream(df)
3376+
3377+
query_started = threading.Event()
3378+
read_thread_id = None
3379+
interrupt_time = None
3380+
completion_time = None
3381+
read_exception = []
3382+
3383+
def trigger_interrupt():
3384+
nonlocal interrupt_time
3385+
if not query_started.wait(timeout=5.0):
3386+
pytest.fail("Query did not start in time")
3387+
3388+
time.sleep(0.1)
3389+
interrupt_time = time.time()
3390+
3391+
if read_thread_id is None:
3392+
pytest.fail("Read thread did not record an identifier")
3393+
3394+
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
3395+
ctypes.c_long(read_thread_id), ctypes.py_object(KeyboardInterrupt)
3396+
)
3397+
if res != 1:
3398+
ctypes.pythonapi.PyThreadState_SetAsyncExc(
3399+
ctypes.c_long(read_thread_id), ctypes.py_object(0)
3400+
)
3401+
pytest.fail("Failed to raise KeyboardInterrupt in read thread")
3402+
3403+
def read_stream():
3404+
nonlocal read_thread_id, completion_time
3405+
read_thread_id = threading.get_ident()
3406+
try:
3407+
query_started.set()
3408+
reader.read_all()
3409+
except KeyboardInterrupt:
3410+
completion_time = time.time()
3411+
except Exception as exc: # pragma: no cover - unexpected failure path
3412+
completion_time = time.time()
3413+
read_exception.append(exc)
3414+
3415+
read_thread = threading.Thread(target=read_stream, daemon=True)
3416+
interrupt_thread = threading.Thread(target=trigger_interrupt, daemon=True)
3417+
3418+
read_thread.start()
3419+
interrupt_thread.start()
3420+
3421+
read_thread.join(timeout=10.0)
3422+
if read_thread.is_alive():
3423+
pytest.fail("Stream read operation timed out after 10 seconds")
3424+
3425+
interrupt_thread.join(timeout=1.0)
3426+
3427+
if read_exception and "KeyboardInterrupt" not in str(read_exception[0]):
3428+
pytest.fail(f"Read thread raised unexpected exception: {read_exception[0]}")
3429+
3430+
assert completion_time is not None, "Read thread did not finish"
3431+
assert interrupt_time is not None, "Interrupt was not sent"
3432+
3433+
elapsed = completion_time - interrupt_time
3434+
assert elapsed >= 0, "Completion recorded before interrupt was sent"
3435+
assert elapsed < 1.5, f"Cancellation took too long: {elapsed}s"
3436+
3437+
33633438
def test_show_select_where_no_rows(capsys) -> None:
33643439
ctx = SessionContext()
33653440
df = ctx.sql("SELECT 1 WHERE 1=0")

src/dataframe.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,13 @@ struct PartitionedDataFrameStreamReader {
375375
current: usize,
376376
}
377377

378+
impl PartitionedDataFrameStreamReader {
379+
fn cancel_streams(&mut self) {
380+
self.streams.drain(self.current..);
381+
self.current = self.streams.len();
382+
}
383+
}
384+
378385
impl Iterator for PartitionedDataFrameStreamReader {
379386
type Item = Result<RecordBatch, ArrowError>;
380387

@@ -389,7 +396,10 @@ impl Iterator for PartitionedDataFrameStreamReader {
389396
let batch = if let Some(ref schema) = self.projection {
390397
match record_batch_into_schema(batch, schema.as_ref()) {
391398
Ok(b) => b,
392-
Err(e) => return Some(Err(e)),
399+
Err(e) => {
400+
self.cancel_streams();
401+
return Some(Err(e));
402+
}
393403
}
394404
} else {
395405
batch
@@ -401,9 +411,11 @@ impl Iterator for PartitionedDataFrameStreamReader {
401411
continue;
402412
}
403413
Ok(Err(e)) => {
414+
self.cancel_streams();
404415
return Some(Err(ArrowError::ExternalError(Box::new(e))));
405416
}
406417
Err(e) => {
418+
self.cancel_streams();
407419
return Some(Err(ArrowError::ExternalError(Box::new(e))));
408420
}
409421
}

src/utils.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::ffi::CString;
1819
use std::future::Future;
1920
use std::sync::{Arc, OnceLock};
2021
use std::time::Duration;
@@ -27,6 +28,7 @@ use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2728
use pyo3::exceptions::PyValueError;
2829
use pyo3::prelude::*;
2930
use pyo3::types::PyCapsule;
31+
use pyo3::PyErr;
3032
use tokio::runtime::Runtime;
3133
use tokio::task::JoinHandle;
3234
use tokio::time::sleep;
@@ -84,7 +86,15 @@ where
8486
tokio::select! {
8587
res = &mut fut => break Ok(res),
8688
_ = sleep(INTERVAL_CHECK_SIGNALS) => {
87-
Python::attach(|py| py.check_signals())?;
89+
Python::attach(|py| {
90+
if let Some(err) = PyErr::take(py) {
91+
Err(err)
92+
} else {
93+
let code = CString::new("pass").unwrap();
94+
py.run(code.as_c_str(), None, None)?;
95+
py.check_signals()
96+
}
97+
})?;
8898
}
8999
}
90100
}

0 commit comments

Comments
 (0)