Skip to content

Commit f464beb

Browse files
committed
Add test for collecting multiple record batches to PyArrow
1 parent af04660 commit f464beb

File tree

2 files changed

+45
-10
lines changed

2 files changed

+45
-10
lines changed

python/tests/test_dataframe.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,27 @@ def test_collect_partitioned():
13841384
assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned()
13851385

13861386

1387+
def test_collect_multiple_batches_to_pyarrow():
1388+
ctx = SessionContext()
1389+
1390+
batch1 = pa.RecordBatch.from_arrays(
1391+
[pa.array([1, 2])],
1392+
names=["a"],
1393+
)
1394+
batch2 = pa.RecordBatch.from_arrays(
1395+
[pa.array([3, 4])],
1396+
names=["a"],
1397+
)
1398+
1399+
df = ctx.create_dataframe([[batch1], [batch2]])
1400+
1401+
batches = df.collect()
1402+
1403+
assert len(batches) == 2
1404+
table = pa.Table.from_batches(batches)
1405+
assert table.column("a").to_pylist() == [1, 2, 3, 4]
1406+
1407+
13871408
def test_union(ctx):
13881409
batch = pa.RecordBatch.from_arrays(
13891410
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],

src/dataframe.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
use std::collections::HashMap;
1919
use std::ffi::CString;
20-
use std::ptr::addr_of;
2120
use std::sync::Arc;
2221

2322
use arrow::array::{new_null_array, Array, RecordBatch, RecordBatchReader, StructArray};
@@ -380,15 +379,30 @@ fn record_batches_to_pyarrow(
380379
ffi_batches
381380
.into_iter()
382381
.map(|(array, schema)| {
383-
record_batch_class
384-
.call_method1(
385-
"_import_from_c",
386-
(
387-
addr_of!(array) as Py_uintptr_t,
388-
addr_of!(schema) as Py_uintptr_t,
389-
),
390-
)
391-
.map(Into::into)
382+
// Allocate the FFI structures on the heap so their addresses stay valid while PyArrow
383+
// imports them. On success PyArrow moves the struct contents out and releases the
384+
// underlying Arrow data when the resulting `RecordBatch` is dropped on the Python
385+
// side; the small heap boxes themselves are intentionally leaked.
386+
let array = Box::new(array);
387+
let schema = Box::new(schema);
388+
let array_ptr = Box::into_raw(array);
389+
let schema_ptr = Box::into_raw(schema);
390+
391+
let result = record_batch_class.call_method1(
392+
"_import_from_c",
393+
(array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
394+
);
395+
396+
if result.is_err() {
397+
// If the import fails, reconstruct the boxes so they are
398+
// properly dropped to avoid leaking memory.
399+
unsafe {
400+
let _ = Box::from_raw(array_ptr);
401+
let _ = Box::from_raw(schema_ptr);
402+
}
403+
}
404+
405+
result.map(Into::into)
392406
})
393407
.collect()
394408
}

0 commit comments

Comments
 (0)