
Commit b01d166

Revert "UNPICK"
This reverts commit 4380bcf.
1 parent 4380bcf commit b01d166

File tree

11 files changed (+547, -55 lines)


docs/source/user-guide/dataframe/index.rst

Lines changed: 30 additions & 1 deletion
@@ -145,10 +145,39 @@ To materialize the results of your DataFrame operations:
 
     # Display results
     df.show()  # Print tabular format to console
-
+
     # Count rows
     count = df.count()
 
+PyArrow Streaming
+-----------------
+
+DataFusion DataFrames implement the ``__arrow_c_stream__`` protocol, enabling
+zero-copy streaming into libraries like `PyArrow <https://arrow.apache.org/>`_.
+Earlier versions eagerly converted the entire DataFrame when exporting to
+PyArrow, which could exhaust memory on large datasets. With streaming, batches
+are produced lazily so you can process arbitrarily large results without
+out-of-memory errors.
+
+.. code-block:: python
+
+    import pyarrow as pa
+
+    # Create a PyArrow RecordBatchReader without materializing all batches
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+    for batch in reader:
+        ...  # process each batch as it is produced
+
+DataFrames are also iterable, yielding :class:`pyarrow.RecordBatch` objects
+lazily so you can loop over results directly:
+
+.. code-block:: python
+
+    for batch in df:
+        ...  # process each batch as it is produced
+
+See :doc:`../io/arrow` for additional details on the Arrow interface.
+
 HTML Rendering
 --------------
 

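As a quick illustration of the streaming behaviour this hunk documents, here is a minimal sketch (not part of the commit) that folds batches into a running sum as they arrive, so only one batch is held in memory at a time. It assumes only APIs shown elsewhere in this diff (``SessionContext.from_pydict``, DataFrame iteration) plus ``pyarrow.compute``:

    import pyarrow.compute as pc
    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 2, 3, 4]})  # any DataFrame works

    total = 0
    for batch in df:  # yields pyarrow.RecordBatch lazily via __arrow_c_stream__
        total += pc.sum(batch.column(0)).as_py()

    assert total == 10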
python/datafusion/dataframe.py

Lines changed: 31 additions & 6 deletions
@@ -26,6 +26,7 @@
     TYPE_CHECKING,
     Any,
     Iterable,
+    Iterator,
     Literal,
     Optional,
     Union,
@@ -289,6 +290,9 @@ def __init__(
 class DataFrame:
     """Two dimensional table representation of data.
 
+    DataFrame objects are iterable; iterating over a DataFrame yields
+    :class:`pyarrow.RecordBatch` instances lazily.
+
     See :ref:`user_guide_concepts` in the online documentation for more information.
     """
 

@@ -1098,21 +1102,42 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram
         return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
 
     def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
-        """Export an Arrow PyCapsule Stream.
+        """Export the DataFrame as an Arrow C Stream.
 
-        This will execute and collect the DataFrame. We will attempt to respect the
-        requested schema, but only trivial transformations will be applied such as only
-        returning the fields listed in the requested schema if their data types match
-        those in the DataFrame.
+        The DataFrame is executed using DataFusion's streaming APIs and exposed via
+        Arrow's C Stream interface. Record batches are produced incrementally, so the
+        full result set is never materialized in memory. When ``requested_schema`` is
+        provided, only straightforward projections such as column selection or
+        reordering are applied.
 
         Args:
             requested_schema: Attempt to provide the DataFrame using this schema.
 
         Returns:
-            Arrow PyCapsule object.
+            Arrow PyCapsule object representing an ``ArrowArrayStream``.
         """
+        # ``DataFrame.__arrow_c_stream__`` in the Rust extension leverages
+        # ``execute_stream_partitioned`` under the hood to stream batches while
+        # preserving the original partition order.
         return self.df.__arrow_c_stream__(requested_schema)
 
+    def __iter__(self) -> Iterator[pa.RecordBatch]:
+        """Yield record batches from the DataFrame without materializing results.
+
+        This implementation streams record batches via the Arrow C Stream
+        interface, allowing callers such as :func:`pyarrow.Table.from_batches` to
+        consume results lazily. The DataFrame is executed using DataFusion's
+        partitioned streaming APIs so ``collect`` is never invoked and batch
+        order across partitions is preserved.
+        """
+        from contextlib import closing
+
+        import pyarrow as pa
+
+        reader = pa.RecordBatchReader._import_from_c_capsule(self.__arrow_c_stream__())
+        with closing(reader):
+            yield from reader
+
     def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:
         """Apply a function to the current DataFrame which returns another DataFrame.

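To make the intent of the new ``__iter__`` concrete, here is a short sketch (not part of the commit; it assumes a ``SessionContext`` built from an in-memory dict, as in the tests below) showing both full and early-terminated consumption. Breaking out of the loop relies on the ``contextlib.closing`` wrapper above to release the underlying Arrow C stream:

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.from_pydict({"a": list(range(10))})

    # Full consumption: the Table is assembled batch by batch; collect() is never used.
    table = pa.Table.from_batches(df)

    # Early termination: abandoning the loop closes the reader via closing(reader).
    for batch in df:
        print(batch.num_rows)
        break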
python/tests/conftest.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 import pyarrow as pa
 import pytest
-from datafusion import SessionContext
+from datafusion import DataFrame, SessionContext
 from pyarrow.csv import write_csv
 
 
@@ -49,3 +49,12 @@ def database(ctx, tmp_path):
         delimiter=",",
         schema_infer_max_records=10,
     )
+
+
+@pytest.fixture
+def fail_collect(monkeypatch):
+    def _fail_collect(self, *args, **kwargs):  # pragma: no cover - failure path
+        msg = "collect should not be called"
+        raise AssertionError(msg)
+
+    monkeypatch.setattr(DataFrame, "collect", _fail_collect)

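A sketch of how a test can lean on the ``fail_collect`` fixture (the tests in the next file do exactly this); the test name ``test_stream_never_collects`` is illustrative and not part of the commit:

    import pyarrow as pa
    from datafusion import SessionContext

    def test_stream_never_collects(fail_collect):
        # If the streaming path fell back to DataFrame.collect, the monkeypatched
        # method would raise AssertionError and this test would fail.
        ctx = SessionContext()
        df = ctx.from_pydict({"a": [1, 2, 3]})

        table = pa.Table.from_batches(df)
        assert table.num_rows == 3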
python/tests/test_dataframe.py

Lines changed: 159 additions & 0 deletions
@@ -1582,6 +1582,61 @@ def test_empty_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+def test_iter_batches_dataframe(fail_collect):
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
+    expected = [batch1, batch2]
+    for got, exp in zip(df, expected):
+        assert got.equals(exp)
+
+
+def test_arrow_c_stream_to_table(fail_collect):
+    ctx = SessionContext()
+
+    # Create a DataFrame with two separate record batches
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
+    table = pa.Table.from_batches(df)
+    batches = table.to_batches()
+
+    assert len(batches) == 2
+    assert batches[0].equals(batch1)
+    assert batches[1].equals(batch2)
+    assert table.schema == df.schema()
+    assert table.column("a").num_chunks == 2
+
+
+def test_arrow_c_stream_order():
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+
+    df = ctx.create_dataframe([[batch1, batch2]])
+
+    table = pa.Table.from_batches(df)
+    expected = pa.Table.from_batches([batch1, batch2])
+
+    assert table.equals(expected)
+    col = table.column("a")
+    assert col.chunk(0)[0].as_py() == 1
+    assert col.chunk(1)[0].as_py() == 2
+
+
+def test_arrow_c_stream_reader(df):
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+    assert isinstance(reader, pa.RecordBatchReader)
+    table = pa.Table.from_batches(reader)
+    expected = pa.Table.from_batches(df.collect())
+    assert table.equals(expected)
+
+
 def test_to_pylist(df):
     # Convert datafusion dataframe to Python list
     pylist = df.to_pylist()
@@ -2666,6 +2721,110 @@ def trigger_interrupt():
     interrupt_thread.join(timeout=1.0)
 
 
+def test_arrow_c_stream_interrupted():
+    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
+
+    Similar to ``test_collect_interrupted`` this test issues a long running
+    query, but consumes the results via ``__arrow_c_stream__``. It then raises
+    ``KeyboardInterrupt`` in the main thread and verifies that the stream
+    iteration stops promptly with the appropriate exception.
+    """
+
+    ctx = SessionContext()
+
+    batches = []
+    for i in range(10):
+        batch = pa.RecordBatch.from_arrays(
+            [
+                pa.array(list(range(i * 1000, (i + 1) * 1000))),
+                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+            ],
+            names=["a", "b"],
+        )
+        batches.append(batch)
+
+    ctx.register_record_batches("t1", [batches])
+    ctx.register_record_batches("t2", [batches])
+
+    df = ctx.sql(
+        """
+        WITH t1_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) / 1.5 AS c,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+            FROM t1
+            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+        ),
+        t2_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) * 2.5 AS e,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+            FROM t2
+            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+        )
+        SELECT
+            t1.a, t1.b, t1.c, t1.d,
+            t2.a AS a2, t2.b AS b2, t2.e, t2.f
+        FROM t1_expanded t1
+        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+        WHERE t1.a > 100 AND t2.a > 100
+        """
+    )
+
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+
+    interrupted = False
+    interrupt_error = None
+    query_started = threading.Event()
+    max_wait_time = 5.0
+
+    def trigger_interrupt():
+        start_time = time.time()
+        while not query_started.is_set():
+            time.sleep(0.1)
+            if time.time() - start_time > max_wait_time:
+                msg = f"Query did not start within {max_wait_time} seconds"
+                raise RuntimeError(msg)
+
+        thread_id = threading.main_thread().ident
+        if thread_id is None:
+            msg = "Cannot get main thread ID"
+            raise RuntimeError(msg)
+
+        exception = ctypes.py_object(KeyboardInterrupt)
+        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+            ctypes.c_long(thread_id), exception
+        )
+        if res != 1:
+            ctypes.pythonapi.PyThreadState_SetAsyncExc(
+                ctypes.c_long(thread_id), ctypes.py_object(0)
+            )
+            msg = "Failed to raise KeyboardInterrupt in main thread"
+            raise RuntimeError(msg)
+
+    interrupt_thread = threading.Thread(target=trigger_interrupt)
+    interrupt_thread.daemon = True
+    interrupt_thread.start()
+
+    try:
+        query_started.set()
+        # consume the reader which should block and be interrupted
+        reader.read_all()
+    except KeyboardInterrupt:
+        interrupted = True
+    except Exception as e:  # pragma: no cover - unexpected errors
+        interrupt_error = e
+
+    if not interrupted:
+        pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
+
+    interrupt_thread.join(timeout=1.0)
+
+
 def test_show_select_where_no_rows(capsys) -> None:
     ctx = SessionContext()
     df = ctx.sql("SELECT 1 WHERE 1=0")
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyarrow as pa
+
+
+def test_iter_releases_reader(monkeypatch, ctx):
+    batches = [
+        pa.RecordBatch.from_pydict({"a": [1]}),
+        pa.RecordBatch.from_pydict({"a": [2]}),
+    ]
+
+    class DummyReader:
+        def __init__(self, batches):
+            self._iter = iter(batches)
+            self.closed = False
+
+        def __iter__(self):
+            return self
+
+        def __next__(self):
+            return next(self._iter)
+
+        def close(self):
+            self.closed = True
+
+    dummy_reader = DummyReader(batches)
+
+    class FakeRecordBatchReader:
+        @staticmethod
+        def _import_from_c_capsule(*_args, **_kwargs):
+            return dummy_reader
+
+    monkeypatch.setattr(pa, "RecordBatchReader", FakeRecordBatchReader)
+
+    df = ctx.from_pydict({"a": [1, 2]})
+
+    for _ in df:
+        break
+
+    assert dummy_reader.closed
