Skip to content

Commit d4ddb3d

Browse files
committed
feat: enhance Arrow interoperability in RecordBatch and improve capsule naming
1 parent 3475086 commit d4ddb3d

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

python/datafusion/record_batch.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,8 @@ def to_pyarrow(self) -> pa.RecordBatch:
4646
"""Convert to :py:class:`pa.RecordBatch`."""
4747
return self.record_batch.to_pyarrow()
4848

49-
def __arrow__(self, *args: object, **kwargs: object) -> pa.RecordBatch:
50-
"""Return a :py:class:`pa.RecordBatch` for Arrow interoperability.
51-
52-
This enables ``datafusion.record_batch.RecordBatch`` instances to be
53-
automatically recognized by PyArrow when passed to its APIs.
54-
"""
55-
return self.to_pyarrow()
49+
def __arrow__(self, type: object | None = None) -> pa.RecordBatch: # noqa: D105
50+
return self.record_batch.to_pyarrow()
5651

5752
def __arrow_c_array__(
5853
self, requested_schema: object | None = None
@@ -65,6 +60,29 @@ def __arrow_c_array__(
6560
schema_capsule, array_capsule = self.record_batch.__arrow_c_array__(
6661
requested_schema
6762
)
63+
64+
# Ensure the returned capsules are named as expected by PyArrow. The
65+
# Rust implementation already produces properly named capsules, but some
66+
# Python consumers (including PyArrow itself) are strict about the
67+
# capsule names matching ``"arrow_schema"`` and ``"arrow_array"``.
68+
import ctypes
69+
70+
def _ensure_name(capsule: object, name: bytes) -> None:
    """Force *capsule*'s PyCapsule name to *name* when it differs.

    Some Python consumers (including PyArrow) require the Arrow C Data
    Interface capsules to be named exactly ``b"arrow_schema"`` /
    ``b"arrow_array"``; this renames the capsule in place if needed.
    """
    api = ctypes.pythonapi

    # Declare the two capsule C-API entry points we need. Because
    # ``pythonapi`` is a PyDLL, a pending Python exception raised by either
    # call is propagated automatically.
    reader = api.PyCapsule_GetName
    reader.argtypes = [ctypes.py_object]
    reader.restype = ctypes.c_char_p

    writer = api.PyCapsule_SetName
    writer.argtypes = [ctypes.py_object, ctypes.c_char_p]
    writer.restype = ctypes.c_int

    if reader(capsule) == name:
        return
    writer(capsule, name)
82+
83+
_ensure_name(schema_capsule, b"arrow_schema")
84+
_ensure_name(array_capsule, b"arrow_array")
85+
6886
return schema_capsule, array_capsule
6987

7088

python/tests/test_arrow_interop.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,19 @@ def test_table_from_batches_with_dataframe(ctx):
1414

1515
expected = pa.Table.from_batches([batch1, batch2])
1616
assert table.equals(expected)
17+
18+
19+
def test_table_from_batches_with_record_batch(ctx):
    """A DataFusion RecordBatch is accepted directly by pa.Table.from_batches.

    Exercises the Arrow interop protocol: the batch collected from a
    DataFrame must be recognized by PyArrow without an explicit conversion.
    """
    batch = pa.record_batch({"a": pa.array([1, 2]), "b": pa.array(["x", "y"])})
    df = ctx.create_dataframe([[batch]])
    rb = df.collect()[0]

    # No try/except needed: if PyArrow raises TypeError here, pytest
    # already fails the test and shows the full traceback — wrapping it
    # in pytest.fail() only obscured the failure site.
    table = pa.Table.from_batches([rb])

    expected = pa.Table.from_batches([batch])
    assert table.equals(expected)

0 commit comments

Comments
 (0)