Skip to content

Commit f459c60

Browse files
committed
test: add test for Arrow C stream capsule ownership
1 parent 9008bd7 commit f459c60

File tree

2 files changed

+23
-23
lines changed

2 files changed

+23
-23
lines changed

python/tests/test_dataframe.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
import ctypes
1818
import datetime
19+
import gc
1920
import os
2021
import re
2122
import threading
@@ -1620,6 +1621,22 @@ def test_arrow_c_stream_to_table_and_reader(fail_collect):
16201621
assert reader_table.equals(expected)
16211622

16221623

1624+
def test_arrow_c_stream_capsule_ownership(fail_collect):
    """The Arrow C stream capsule must keep its stream alive after the
    capsule object itself has been deleted and garbage collected.

    PyArrow takes ownership of the stream when it imports the capsule, so
    destroying the capsule afterwards must not free the stream a second time.
    """
    ctx = SessionContext()

    batch = pa.record_batch([pa.array([1])], names=["a"])
    df = ctx.create_dataframe([[batch]])

    # Export the DataFrame as an Arrow C stream capsule and hand the
    # stream off to PyArrow; importing transfers ownership to the reader.
    capsule = df.__arrow_c_stream__()
    reader = pa.RecordBatchReader._import_from_c_capsule(capsule)

    # Drop the capsule and force a collection cycle BEFORE consuming the
    # reader: if the capsule destructor wrongly freed the stream, the read
    # below would crash or return corrupt data.
    del capsule
    gc.collect()

    table = pa.Table.from_batches(reader)
    expected = pa.Table.from_batches([batch])
    assert table.equals(expected)
1638+
1639+
16231640
def test_arrow_c_stream_order():
16241641
ctx = SessionContext()
16251642

src/dataframe.rs

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use std::collections::HashMap;
19-
use std::ffi::{c_void, CStr, CString};
19+
use std::ffi::{CStr, CString};
2020
use std::sync::Arc;
2121

2222
use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
@@ -39,7 +39,6 @@ use datafusion::prelude::*;
3939
use datafusion_ffi::table_provider::FFI_TableProvider;
4040
use futures::{StreamExt, TryStreamExt};
4141
use pyo3::exceptions::PyValueError;
42-
use pyo3::ffi;
4342
use pyo3::prelude::*;
4443
use pyo3::pybacked::PyBackedStr;
4544
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
@@ -967,28 +966,12 @@ impl PyDataFrame {
967966
};
968967
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
969968

970-
let stream = Box::new(FFI_ArrowArrayStream::new(reader));
971-
let stream_ptr = Box::into_raw(stream);
972-
debug_assert!(
973-
!stream_ptr.is_null(),
974-
"ArrowArrayStream pointer should never be null",
975-
);
969+
let stream = FFI_ArrowArrayStream::new(reader);
976970
// The returned capsule allows zero-copy hand-off to PyArrow. When
977-
// PyArrow imports the capsule it assumes ownership of the stream.
978-
let capsule = unsafe {
979-
ffi::PyCapsule_New(
980-
stream_ptr as *mut c_void,
981-
ARROW_STREAM_NAME.as_ptr(),
982-
None,
983-
)
984-
};
985-
if capsule.is_null() {
986-
unsafe { drop(Box::from_raw(stream_ptr)) };
987-
Err(PyErr::fetch(py).into())
988-
} else {
989-
let any = unsafe { Bound::from_owned_ptr(py, capsule) };
990-
Ok(any.downcast_into::<PyCapsule>().unwrap())
991-
}
971+
// PyArrow imports the capsule it assumes ownership of the stream and
972+
// nulls out the capsule's internal pointer so the destructor does not
973+
// free it twice.
974+
PyCapsule::new(py, stream, Some(ARROW_STREAM_NAME)).map_err(Into::into)
992975
}
993976

994977
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {

0 commit comments

Comments
 (0)