
Commit 032fab2

Revert "UNPICK"
This reverts commit ae3c67e.
1 parent ae3c67e commit 032fab2

File tree

10 files changed: +511 -56 lines changed

docs/source/user-guide/dataframe/index.rst

Lines changed: 30 additions & 1 deletion
@@ -145,10 +145,39 @@ To materialize the results of your DataFrame operations:
 
     # Display results
     df.show() # Print tabular format to console
-
+
     # Count rows
     count = df.count()
 
+PyArrow Streaming
+-----------------
+
+DataFusion DataFrames implement the ``__arrow_c_stream__`` protocol, enabling
+zero-copy streaming into libraries like `PyArrow <https://arrow.apache.org/>`_.
+Earlier versions eagerly converted the entire DataFrame when exporting to
+PyArrow, which could exhaust memory on large datasets. With streaming, batches
+are produced lazily so you can process arbitrarily large results without
+out-of-memory errors.
+
+.. code-block:: python
+
+    import pyarrow as pa
+
+    # Create a PyArrow RecordBatchReader without materializing all batches
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+    for batch in reader:
+        ...  # process each batch as it is produced
+
+DataFrames are also iterable, yielding :class:`pyarrow.RecordBatch` objects
+lazily so you can loop over results directly:
+
+.. code-block:: python
+
+    for batch in df:
+        ...  # process each batch as it is produced
+
+See :doc:`../io/arrow` for additional details on the Arrow interface.
+
 HTML Rendering
 --------------
 
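Editorial sketch (not part of this commit): because batches arrive lazily, the same stream can be drained directly into a Parquet file with pyarrow.parquet.ParquetWriter, keeping memory roughly flat regardless of result size. The file name "large_result.parquet" and the DataFrame variable ``df`` are illustrative assumptions.

import pyarrow as pa
import pyarrow.parquet as pq

# Assumes an existing DataFusion DataFrame ``df`` (illustrative only).
reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
with pq.ParquetWriter("large_result.parquet", reader.schema) as writer:
    for batch in reader:
        writer.write_batch(batch)  # each batch is written, then released
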
python/datafusion/_testing.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+"""Testing-only helpers for datafusion-python.
+
+This module contains utilities used by the test-suite that should not be
+exposed as part of the public API. Keep the implementation minimal and
+documented so reviewers can easily see it's test-only.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .context import SessionContext
+
+if TYPE_CHECKING:
+    from datafusion import DataFrame
+
+
+def range_table(
+    ctx: SessionContext,
+    start: int,
+    stop: int | None = None,
+    step: int = 1,
+    partitions: int | None = None,
+) -> DataFrame:
+    """Create a DataFrame containing a sequence of numbers using SQL RANGE.
+
+    This mirrors the previous ``SessionContext.range`` convenience method but
+    lives in a testing-only module so it doesn't expand the public surface.
+
+    Args:
+        ctx: SessionContext instance to run the SQL against.
+        start: Starting value for the sequence or exclusive stop when ``stop``
+            is ``None``.
+        stop: Exclusive upper bound of the sequence.
+        step: Increment between successive values.
+        partitions: Optional number of partitions for the generated data.
+
+    Returns:
+        DataFrame produced by the range table function.
+    """
+    if stop is None:
+        start, stop = 0, start
+
+    parts = f", {int(partitions)}" if partitions is not None else ""
+    sql = f"SELECT * FROM range({int(start)}, {int(stop)}, {int(step)}{parts})"  # noqa: S608
+    return ctx.sql(sql)
python/datafusion/dataframe.py

Lines changed: 28 additions & 6 deletions
@@ -26,6 +26,7 @@
     TYPE_CHECKING,
     Any,
     Iterable,
+    Iterator,
     Literal,
     Optional,
     Union,
@@ -289,6 +290,9 @@ def __init__(
 class DataFrame:
     """Two dimensional table representation of data.
 
+    DataFrame objects are iterable; iterating over a DataFrame yields
+    :class:`pyarrow.RecordBatch` instances lazily.
+
     See :ref:`user_guide_concepts` in the online documentation for more information.
     """
 
@@ -1098,21 +1102,39 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram
         return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
 
     def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
-        """Export an Arrow PyCapsule Stream.
+        """Export the DataFrame as an Arrow C Stream.
 
-        This will execute and collect the DataFrame. We will attempt to respect the
-        requested schema, but only trivial transformations will be applied such as only
-        returning the fields listed in the requested schema if their data types match
-        those in the DataFrame.
+        The DataFrame is executed using DataFusion's streaming APIs and exposed via
+        Arrow's C Stream interface. Record batches are produced incrementally, so the
+        full result set is never materialized in memory. When ``requested_schema`` is
+        provided, only straightforward projections such as column selection or
+        reordering are applied.
 
         Args:
            requested_schema: Attempt to provide the DataFrame using this schema.
 
         Returns:
-            Arrow PyCapsule object.
+            Arrow PyCapsule object representing an ``ArrowArrayStream``.
         """
+        # ``DataFrame.__arrow_c_stream__`` in the Rust extension leverages
+        # ``execute_stream_partitioned`` under the hood to stream batches while
+        # preserving the original partition order.
         return self.df.__arrow_c_stream__(requested_schema)
 
+    def __iter__(self) -> Iterator[pa.RecordBatch]:
+        """Yield record batches from the DataFrame without materializing results.
+
+        This implementation streams record batches via the Arrow C Stream
+        interface, allowing callers such as :func:`pyarrow.Table.from_batches` to
+        consume results lazily. The DataFrame is executed using DataFusion's
+        partitioned streaming APIs so ``collect`` is never invoked and batch
+        order across partitions is preserved.
+        """
+        import pyarrow as pa
+
+        reader = pa.RecordBatchReader._import_from_c_capsule(self.__arrow_c_stream__())
+        yield from reader
+
     def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:
         """Apply a function to the current DataFrame which returns another DataFrame.
 
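Illustrative sketch (not in the diff): because ``__iter__`` streams batches lazily, a running aggregate can be computed without ever holding the full result set. The DataFrame ``df`` and its integer column "a" are assumed for the example.

import pyarrow.compute as pc

total = 0
for batch in df:                    # batches arrive via __arrow_c_stream__
    partial = pc.sum(batch["a"]).as_py()  # per-batch partial sum
    total += partial or 0           # pc.sum of an empty batch returns null
print(total)
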
python/tests/conftest.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 import pyarrow as pa
 import pytest
-from datafusion import SessionContext
+from datafusion import DataFrame, SessionContext
 from pyarrow.csv import write_csv
 
 
@@ -49,3 +49,12 @@ def database(ctx, tmp_path):
         delimiter=",",
         schema_infer_max_records=10,
     )
+
+
+@pytest.fixture
+def fail_collect(monkeypatch):
+    def _fail_collect(self, *args, **kwargs):  # pragma: no cover - failure path
+        msg = "collect should not be called"
+        raise AssertionError(msg)
+
+    monkeypatch.setattr(DataFrame, "collect", _fail_collect)

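Sketch of how the fixture is intended to be used (it mirrors the tests added in this commit): requesting ``fail_collect`` makes any fallback to ``DataFrame.collect()`` raise, so a passing test proves the streaming path was taken. The test name below is hypothetical.

import pyarrow as pa
from datafusion import SessionContext


def test_streaming_never_collects(fail_collect):
    ctx = SessionContext()
    batch = pa.record_batch([pa.array([1, 2, 3])], names=["a"])
    df = ctx.create_dataframe([[batch]])

    # Consumes df via __arrow_c_stream__; collect() would raise AssertionError.
    table = pa.Table.from_batches(df)
    assert table.num_rows == 3
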
python/tests/test_dataframe.py

Lines changed: 159 additions & 0 deletions
@@ -1582,6 +1582,61 @@ def test_empty_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+def test_iter_batches_dataframe(fail_collect):
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
+    expected = [batch1, batch2]
+    for got, exp in zip(df, expected):
+        assert got.equals(exp)
+
+
+def test_arrow_c_stream_to_table(fail_collect):
+    ctx = SessionContext()
+
+    # Create a DataFrame with two separate record batches
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+    df = ctx.create_dataframe([[batch1], [batch2]])
+
+    table = pa.Table.from_batches(df)
+    batches = table.to_batches()
+
+    assert len(batches) == 2
+    assert batches[0].equals(batch1)
+    assert batches[1].equals(batch2)
+    assert table.schema == df.schema()
+    assert table.column("a").num_chunks == 2
+
+
+def test_arrow_c_stream_order():
+    ctx = SessionContext()
+
+    batch1 = pa.record_batch([pa.array([1])], names=["a"])
+    batch2 = pa.record_batch([pa.array([2])], names=["a"])
+
+    df = ctx.create_dataframe([[batch1, batch2]])
+
+    table = pa.Table.from_batches(df)
+    expected = pa.Table.from_batches([batch1, batch2])
+
+    assert table.equals(expected)
+    col = table.column("a")
+    assert col.chunk(0)[0].as_py() == 1
+    assert col.chunk(1)[0].as_py() == 2
+
+
+def test_arrow_c_stream_reader(df):
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+    assert isinstance(reader, pa.RecordBatchReader)
+    table = pa.Table.from_batches(reader)
+    expected = pa.Table.from_batches(df.collect())
+    assert table.equals(expected)
+
+
 def test_to_pylist(df):
     # Convert datafusion dataframe to Python list
     pylist = df.to_pylist()
@@ -2666,6 +2721,110 @@ def trigger_interrupt():
     interrupt_thread.join(timeout=1.0)
 
 
+def test_arrow_c_stream_interrupted():
+    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
+
+    Similar to ``test_collect_interrupted`` this test issues a long running
+    query, but consumes the results via ``__arrow_c_stream__``. It then raises
+    ``KeyboardInterrupt`` in the main thread and verifies that the stream
+    iteration stops promptly with the appropriate exception.
+    """
+
+    ctx = SessionContext()
+
+    batches = []
+    for i in range(10):
+        batch = pa.RecordBatch.from_arrays(
+            [
+                pa.array(list(range(i * 1000, (i + 1) * 1000))),
+                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+            ],
+            names=["a", "b"],
+        )
+        batches.append(batch)
+
+    ctx.register_record_batches("t1", [batches])
+    ctx.register_record_batches("t2", [batches])
+
+    df = ctx.sql(
+        """
+        WITH t1_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) / 1.5 AS c,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+            FROM t1
+            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+        ),
+        t2_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) * 2.5 AS e,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+            FROM t2
+            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+        )
+        SELECT
+            t1.a, t1.b, t1.c, t1.d,
+            t2.a AS a2, t2.b AS b2, t2.e, t2.f
+        FROM t1_expanded t1
+        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+        WHERE t1.a > 100 AND t2.a > 100
+        """
+    )
+
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+
+    interrupted = False
+    interrupt_error = None
+    query_started = threading.Event()
+    max_wait_time = 5.0
+
+    def trigger_interrupt():
+        start_time = time.time()
+        while not query_started.is_set():
+            time.sleep(0.1)
+            if time.time() - start_time > max_wait_time:
+                msg = f"Query did not start within {max_wait_time} seconds"
+                raise RuntimeError(msg)
+
+        thread_id = threading.main_thread().ident
+        if thread_id is None:
+            msg = "Cannot get main thread ID"
+            raise RuntimeError(msg)
+
+        exception = ctypes.py_object(KeyboardInterrupt)
+        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+            ctypes.c_long(thread_id), exception
+        )
+        if res != 1:
+            ctypes.pythonapi.PyThreadState_SetAsyncExc(
+                ctypes.c_long(thread_id), ctypes.py_object(0)
+            )
+            msg = "Failed to raise KeyboardInterrupt in main thread"
+            raise RuntimeError(msg)
+
+    interrupt_thread = threading.Thread(target=trigger_interrupt)
+    interrupt_thread.daemon = True
+    interrupt_thread.start()
+
+    try:
+        query_started.set()
+        # consume the reader which should block and be interrupted
+        reader.read_all()
+    except KeyboardInterrupt:
+        interrupted = True
+    except Exception as e:  # pragma: no cover - unexpected errors
+        interrupt_error = e
+
+    if not interrupted:
+        pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
+
+    interrupt_thread.join(timeout=1.0)
+
+
 def test_show_select_where_no_rows(capsys) -> None:
     ctx = SessionContext()
     df = ctx.sql("SELECT 1 WHERE 1=0")

python/tests/test_io.py

Lines changed: 37 additions & 0 deletions
@@ -17,7 +17,9 @@
 from pathlib import Path
 
 import pyarrow as pa
+import pytest
 from datafusion import column
+from datafusion._testing import range_table
 from datafusion.io import read_avro, read_csv, read_json, read_parquet
 
 
@@ -92,3 +94,38 @@ def test_read_avro():
     path = Path.cwd() / "testing/data/avro/alltypes_plain.avro"
     avro_df = read_avro(path=path)
     assert avro_df is not None
+
+
+def test_arrow_c_stream_large_dataset(ctx):
+    """DataFrame.__arrow_c_stream__ yields batches incrementally.
+
+    This test constructs a DataFrame that would be far larger than available
+    memory if materialized. The ``__arrow_c_stream__`` method should expose a
+    stream of record batches without collecting the full dataset, so reading a
+    handful of batches should not exhaust process memory.
+    """
+    # Create a very large DataFrame using range; this would be terabytes if collected
+    df = range_table(ctx, 0, 1 << 40)
+
+    reader = pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())
+
+    # Track RSS before consuming batches
+    psutil = pytest.importorskip("psutil")
+    process = psutil.Process()
+    start_rss = process.memory_info().rss
+
+    for _ in range(5):
+        batch = reader.read_next_batch()
+        assert batch is not None
+        assert len(batch) > 0
+        current_rss = process.memory_info().rss
+        # Ensure memory usage hasn't grown substantially (>50MB)
+        assert current_rss - start_rss < 50 * 1024 * 1024
+
+
+def test_table_from_batches_stream(ctx, fail_collect):
+    df = range_table(ctx, 0, 10)
+
+    table = pa.Table.from_batches(df)
+    assert table.shape == (10, 1)
+    assert table.column_names == ["value"]
