
Commit 8d76799

Revert "UNPICK"
This reverts commit a029ce2.
1 parent a029ce2 · commit 8d76799

File tree

11 files changed: +587 / -55 lines


docs/source/user-guide/dataframe/index.rst

Lines changed: 30 additions & 1 deletion
@@ -145,10 +145,39 @@ To materialize the results of your DataFrame operations:
 
     # Display results
     df.show()  # Print tabular format to console
-
+
     # Count rows
     count = df.count()
 
+PyArrow Streaming
+-----------------
+
+DataFusion DataFrames implement the ``__arrow_c_stream__`` protocol, enabling
+zero-copy streaming into libraries like `PyArrow <https://arrow.apache.org/>`_.
+Earlier versions eagerly converted the entire DataFrame when exporting to
+PyArrow, which could exhaust memory on large datasets. With streaming, batches
+are produced lazily so you can process arbitrarily large results without
+out-of-memory errors.
+
+.. code-block:: python
+
+    import pyarrow as pa
+
+    # Create a PyArrow RecordBatchReader without materializing all batches
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+    for batch in reader:
+        ...  # process each batch as it is produced
+
+DataFrames are also iterable, yielding :class:`pyarrow.RecordBatch` objects
+lazily so you can loop over results directly:
+
+.. code-block:: python
+
+    for batch in df:
+        ...  # process each batch as it is produced
+
+See :doc:`../io/arrow` for additional details on the Arrow interface.
+
 HTML Rendering
 --------------
 

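As a quick illustration of the streaming behaviour the new docs describe, a DataFrame can also be handed directly to pyarrow.Table.from_batches, which pulls record batches through the stream instead of collecting them first (this mirrors test_arrow_c_stream_to_table further down in this commit). A minimal sketch with made-up input data:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
# Illustrative input; any DataFrame streams the same way.
batch = pa.record_batch([pa.array([1, 2, 3])], names=["a"])
df = ctx.create_dataframe([[batch]])

# from_batches iterates the DataFrame, pulling batches lazily through
# __iter__ / __arrow_c_stream__ rather than calling collect().
table = pa.Table.from_batches(df)
assert table.num_rows == 3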
python/datafusion/dataframe.py

Lines changed: 31 additions & 6 deletions
@@ -26,6 +26,7 @@
     TYPE_CHECKING,
     Any,
     Iterable,
+    Iterator,
     Literal,
     Optional,
     Union,
@@ -289,6 +290,9 @@ def __init__(
 class DataFrame:
     """Two dimensional table representation of data.
 
+    DataFrame objects are iterable; iterating over a DataFrame yields
+    :class:`pyarrow.RecordBatch` instances lazily.
+
     See :ref:`user_guide_concepts` in the online documentation for more information.
     """
 
@@ -1098,21 +1102,42 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram
         return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
 
     def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
-        """Export an Arrow PyCapsule Stream.
+        """Export the DataFrame as an Arrow C Stream.
 
-        This will execute and collect the DataFrame. We will attempt to respect the
-        requested schema, but only trivial transformations will be applied such as only
-        returning the fields listed in the requested schema if their data types match
-        those in the DataFrame.
+        The DataFrame is executed using DataFusion's streaming APIs and exposed via
+        Arrow's C Stream interface. Record batches are produced incrementally, so the
+        full result set is never materialized in memory. When ``requested_schema`` is
+        provided, only straightforward projections such as column selection or
+        reordering are applied.
 
         Args:
            requested_schema: Attempt to provide the DataFrame using this schema.
 
        Returns:
-            Arrow PyCapsule object.
+            Arrow PyCapsule object representing an ``ArrowArrayStream``.
        """
+        # ``DataFrame.__arrow_c_stream__`` in the Rust extension leverages
+        # ``execute_stream_partitioned`` under the hood to stream batches while
+        # preserving the original partition order.
        return self.df.__arrow_c_stream__(requested_schema)
 
+    def __iter__(self) -> Iterator[pa.RecordBatch]:
+        """Yield record batches from the DataFrame without materializing results.
+
+        This implementation streams record batches via the Arrow C Stream
+        interface, allowing callers such as :func:`pyarrow.Table.from_batches` to
+        consume results lazily. The DataFrame is executed using DataFusion's
+        partitioned streaming APIs so ``collect`` is never invoked and batch
+        order across partitions is preserved.
+        """
+        from contextlib import closing
+
+        import pyarrow as pa
+
+        reader = pa.RecordBatchReader._import_from_c_capsule(self.__arrow_c_stream__())
+        with closing(reader):
+            yield from reader
+
     def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:
         """Apply a function to the current DataFrame which returns another DataFrame.

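To make the effect of the new __iter__ concrete: batches can be reduced one at a time so memory stays bounded by a single batch. A minimal sketch, assuming an illustrative column "a" and using pyarrow.compute for the per-batch aggregation:

import pyarrow as pa
import pyarrow.compute as pc
from datafusion import SessionContext

ctx = SessionContext()
batch = pa.record_batch([pa.array([1, 2, 3, 4])], names=["a"])
df = ctx.create_dataframe([[batch]])

# Each iteration pulls one RecordBatch through the Arrow C Stream, so only
# one batch is resident in memory at a time; collect() is never invoked.
running_total = 0
for record_batch in df:
    running_total += pc.sum(record_batch.column("a")).as_py()
assert running_total == 10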
python/tests/conftest.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 import pyarrow as pa
 import pytest
-from datafusion import SessionContext
+from datafusion import DataFrame, SessionContext
 from pyarrow.csv import write_csv
 
 
@@ -49,3 +49,12 @@ def database(ctx, tmp_path):
         delimiter=",",
         schema_infer_max_records=10,
     )
+
+
+@pytest.fixture
+def fail_collect(monkeypatch):
+    def _fail_collect(self, *args, **kwargs):  # pragma: no cover - failure path
+        msg = "collect should not be called"
+        raise AssertionError(msg)
+
+    monkeypatch.setattr(DataFrame, "collect", _fail_collect)

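For reference, the streaming tests in python/tests/test_dataframe.py opt into this fixture simply by naming it as a parameter; a hypothetical test (the name and data are illustrative, mirroring the tests below) would look like:

import pyarrow as pa
from datafusion import SessionContext


def test_streaming_never_collects(fail_collect):
    ctx = SessionContext()
    batch = pa.record_batch([pa.array([1, 2])], names=["a"])
    df = ctx.create_dataframe([[batch]])

    # Any fallback to DataFrame.collect() would hit the patched method and
    # raise the AssertionError installed by the fail_collect fixture.
    assert sum(b.num_rows for b in df) == 2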
python/tests/test_dataframe.py

Lines changed: 199 additions & 0 deletions
@@ -46,6 +46,8 @@
 from datafusion.expr import Window
 from pyarrow.csv import write_csv
 
+pa_cffi = pytest.importorskip("pyarrow.cffi")
+
 MB = 1024 * 1024
 
 
@@ -1582,6 +1584,99 @@ def test_empty_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+def test_iter_batches_dataframe(fail_collect):
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
+    expected = [batch1, batch2]
+    for got, exp in zip(df, expected):
+        assert got.equals(exp)
+
+
+def test_arrow_c_stream_to_table(fail_collect):
+    ctx = SessionContext()
+
+    # Create a DataFrame with two separate record batches
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
+    table = pa.Table.from_batches(df)
+    batches = table.to_batches()
+
+    assert len(batches) == 2
+    assert batches[0].equals(batch1)
+    assert batches[1].equals(batch2)
+    assert table.schema == df.schema()
+    assert table.column("a").num_chunks == 2
+
+
+def test_arrow_c_stream_order():
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+
+    df = ctx.create_dataframe([[batch1, batch2]])
+
+    table = pa.Table.from_batches(df)
+    expected = pa.Table.from_batches([batch1, batch2])
+
+    assert table.equals(expected)
+    col = table.column("a")
+    assert col.chunk(0)[0].as_py() == 1
+    assert col.chunk(1)[0].as_py() == 2
+
+
+def test_arrow_c_stream_reader(df):
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+    assert isinstance(reader, pa.RecordBatchReader)
+    table = pa.Table.from_batches(reader)
+    expected = pa.Table.from_batches(df.collect())
+    assert table.equals(expected)
+
+
+def test_arrow_c_stream_schema_selection(fail_collect):
+    ctx = SessionContext()
+
+    batch = pa.RecordBatch.from_arrays(
+        [
+            pa.array([1, 2]),
+            pa.array([3, 4]),
+            pa.array([5, 6]),
+        ],
+        names=["a", "b", "c"],
+    )
+    df = ctx.create_dataframe([[batch]])
+
+    requested_schema = pa.schema([("c", pa.int64()), ("a", pa.int64())])
+
+    c_schema = pa_cffi.ffi.new("struct ArrowSchema*")
+    address = int(pa_cffi.ffi.cast("uintptr_t", c_schema))
+    requested_schema._export_to_c(address)
+    capsule_new = ctypes.pythonapi.PyCapsule_New
+    capsule_new.restype = ctypes.py_object
+    capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
+    schema_capsule = capsule_new(ctypes.c_void_p(address), b"arrow_schema", None)
+
+    reader = pa.RecordBatchReader._import_from_c_capsule(
+        df.__arrow_c_stream__(schema_capsule)
+    )
+
+    assert reader.schema == requested_schema
+
+    batches = list(reader)
+
+    assert len(batches) == 1
+    expected_batch = pa.record_batch(
+        [pa.array([5, 6]), pa.array([1, 2])], names=["c", "a"]
+    )
+    assert batches[0].equals(expected_batch)
+
+
 def test_to_pylist(df):
     # Convert datafusion dataframe to Python list
     pylist = df.to_pylist()
@@ -2666,6 +2761,110 @@ def trigger_interrupt():
     interrupt_thread.join(timeout=1.0)
 
 
+def test_arrow_c_stream_interrupted():
+    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
+
+    Similar to ``test_collect_interrupted`` this test issues a long running
+    query, but consumes the results via ``__arrow_c_stream__``. It then raises
+    ``KeyboardInterrupt`` in the main thread and verifies that the stream
+    iteration stops promptly with the appropriate exception.
+    """
+
+    ctx = SessionContext()
+
+    batches = []
+    for i in range(10):
+        batch = pa.RecordBatch.from_arrays(
+            [
+                pa.array(list(range(i * 1000, (i + 1) * 1000))),
+                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+            ],
+            names=["a", "b"],
+        )
+        batches.append(batch)
+
+    ctx.register_record_batches("t1", [batches])
+    ctx.register_record_batches("t2", [batches])
+
+    df = ctx.sql(
+        """
+        WITH t1_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) / 1.5 AS c,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+            FROM t1
+            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+        ),
+        t2_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) * 2.5 AS e,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+            FROM t2
+            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+        )
+        SELECT
+            t1.a, t1.b, t1.c, t1.d,
+            t2.a AS a2, t2.b AS b2, t2.e, t2.f
+        FROM t1_expanded t1
+        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+        WHERE t1.a > 100 AND t2.a > 100
+        """
+    )
+
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+
+    interrupted = False
+    interrupt_error = None
+    query_started = threading.Event()
+    max_wait_time = 5.0
+
+    def trigger_interrupt():
+        start_time = time.time()
+        while not query_started.is_set():
+            time.sleep(0.1)
+            if time.time() - start_time > max_wait_time:
+                msg = f"Query did not start within {max_wait_time} seconds"
+                raise RuntimeError(msg)
+
+        thread_id = threading.main_thread().ident
+        if thread_id is None:
+            msg = "Cannot get main thread ID"
+            raise RuntimeError(msg)
+
+        exception = ctypes.py_object(KeyboardInterrupt)
+        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+            ctypes.c_long(thread_id), exception
+        )
+        if res != 1:
+            ctypes.pythonapi.PyThreadState_SetAsyncExc(
+                ctypes.c_long(thread_id), ctypes.py_object(0)
+            )
+            msg = "Failed to raise KeyboardInterrupt in main thread"
+            raise RuntimeError(msg)
+
+    interrupt_thread = threading.Thread(target=trigger_interrupt)
+    interrupt_thread.daemon = True
+    interrupt_thread.start()
+
+    try:
+        query_started.set()
+        # consume the reader which should block and be interrupted
+        reader.read_all()
+    except KeyboardInterrupt:
+        interrupted = True
+    except Exception as e:  # pragma: no cover - unexpected errors
+        interrupt_error = e
+
+    if not interrupted:
+        pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
+
+    interrupt_thread.join(timeout=1.0)
+
+
 def test_show_select_where_no_rows(capsys) -> None:
     ctx = SessionContext()
     df = ctx.sql("SELECT 1 WHERE 1=0")
