Skip to content

Commit a0fc731

Browse files
committed
feat: enhance Arrow interoperability by exposing DataFrame results via the Arrow C Stream interface and updating tests for RecordBatch compatibility
1 parent d4ddb3d commit a0fc731

File tree

5 files changed

+36
-12
lines changed

5 files changed

+36
-12
lines changed

python/datafusion/dataframe.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,10 +1136,16 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
11361136
def __iter__(self) -> Iterator[RecordBatch]:
11371137
"""Yield record batches from this DataFrame lazily.
11381138
1139-
This delegates to :py:meth:`to_stream` without eagerly materializing the
1140-
entire result set.
1139+
This implementation exposes DataFrame results via Arrow's C Stream
1140+
interface so that PyArrow consumers such as
1141+
:py:meth:`pyarrow.Table.from_batches` detect and use
1142+
:py:meth:`__arrow_c_stream__` instead of iterating batch by batch in
1143+
Python.
11411144
"""
1142-
return iter(self.to_stream())
1145+
import pyarrow as pa
1146+
1147+
reader = pa.RecordBatchReader._import_from_c_capsule(self.__arrow_c_stream__())
1148+
yield from reader
11431149

11441150
def __aiter__(self) -> AsyncIterator[RecordBatch]:
11451151
"""Asynchronously yield record batches from this DataFrame lazily."""

python/tests/test_arrow_interop.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pyarrow as pa
22
import pytest
33

4+
from .utils import range_table
5+
46

57
def test_table_from_batches_with_dataframe(ctx):
68
batch1 = pa.record_batch({"a": pa.array([1, 2]), "b": pa.array(["x", "y"])})
@@ -30,3 +32,19 @@ def test_table_from_batches_with_record_batch(ctx):
3032

3133
expected = pa.Table.from_batches([batch])
3234
assert table.equals(expected)
35+
36+
37+
def test_table_from_batches_with_range_table(ctx):
38+
df = range_table(ctx, 0, 5)
39+
40+
try:
41+
table = pa.Table.from_batches(df)
42+
except TypeError as err: # pragma: no cover - failure path
43+
pytest.fail(
44+
f"TypeError raised when converting range DataFrame to Arrow Table: {err}"
45+
)
46+
47+
# Create a schema with a non-nullable field to match the actual output
48+
schema = pa.schema([pa.field("value", pa.int64(), nullable=False)])
49+
expected = pa.table({"value": pa.array(range(5), type=pa.int64())}, schema=schema)
50+
assert table.equals(expected)

python/tests/test_dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,8 +1595,8 @@ def test_iter_batches_dataframe(fail_collect):
15951595

15961596
expected = [batch1, batch2]
15971597
for got, exp in zip(df, expected):
1598-
assert isinstance(got, RecordBatch)
1599-
assert got.to_pyarrow().equals(exp)
1598+
assert isinstance(got, pa.RecordBatch)
1599+
assert got.equals(exp)
16001600

16011601

16021602
def test_table_from_batches_dataframe(df, fail_collect):
@@ -1605,7 +1605,7 @@ def test_table_from_batches_dataframe(df, fail_collect):
16051605
assert set(table.column_names) == {"a", "b", "c"}
16061606

16071607
for batch in df:
1608-
assert isinstance(batch, RecordBatch)
1608+
assert isinstance(batch, pa.RecordBatch)
16091609

16101610

16111611
def test_arrow_c_stream_to_table_and_reader(fail_collect):

python/tests/test_dataframe_iter_stream.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
# under the License.
1717

1818

19-
from datafusion.record_batch import RecordBatch
19+
import pyarrow as pa
2020

2121

2222
def test_to_stream(ctx):
2323
df = ctx.from_pydict({"a": [1, 2]})
2424
stream = df.to_stream()
25-
batches = [rb.to_pyarrow() for rb in stream]
25+
batches = list(stream)
2626
assert len(batches) == 1
2727
assert batches[0].to_pydict() == {"a": [1, 2]}
2828

@@ -31,5 +31,5 @@ def test_dataframe_iter(ctx):
3131
df = ctx.from_pydict({"a": [1, 2]})
3232
batches = list(df)
3333
assert len(batches) == 1
34-
assert isinstance(batches[0], RecordBatch)
35-
assert batches[0].to_pyarrow().to_pydict() == {"a": [1, 2]}
34+
assert isinstance(batches[0], pa.RecordBatch)
35+
assert batches[0].to_pydict() == {"a": [1, 2]}

python/tests/test_io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import pyarrow as pa
2121
import pytest
22-
from datafusion import RecordBatch, column
22+
from datafusion import column
2323
from datafusion.io import read_avro, read_csv, read_json, read_parquet
2424

2525
from .utils import range_table
@@ -133,4 +133,4 @@ def test_table_from_batches_stream(ctx, fail_collect):
133133
assert table.column_names == ["value"]
134134

135135
for batch in df:
136-
assert isinstance(batch, RecordBatch)
136+
assert isinstance(batch, pa.RecordBatch)

0 commit comments

Comments
 (0)