
Commit ca13f48

test: ensure RecordBatch instances in DataFrame stream tests

1 parent: 54c2f59
File tree: 2 files changed (+24 −11 lines)
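For context: these tests previously only checked that streamed batches were not `None`; the commit tightens them to assert each item is a `datafusion.RecordBatch` wrapper. A minimal sketch of the behavior being pinned down (assuming `SessionContext.from_pydict` for setup; the data is illustrative):

import pyarrow as pa

from datafusion import RecordBatch, SessionContext

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})  # illustrative data

for rb in df.execute_stream():
    # Each item is datafusion's RecordBatch wrapper, not a pyarrow batch...
    assert isinstance(rb, RecordBatch)
    # ...and converts explicitly when pyarrow objects are needed.
    assert isinstance(rb.to_pyarrow(), pa.RecordBatch)

The explicit `.to_pyarrow()` call mirrors the conversion the updated `test_iter_batches_dataframe` now performs before comparing batches.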

python/tests/test_dataframe.py

Lines changed: 20 additions & 10 deletions
@@ -29,6 +29,7 @@
     DataFrame,
     ParquetColumnOptions,
     ParquetWriterOptions,
+    RecordBatch,
     SessionContext,
     WindowFrame,
     column,
@@ -1504,7 +1505,8 @@ def test_to_arrow_table(df):
 
 def test_execute_stream(df):
     stream = df.execute_stream()
-    assert all(batch is not None for batch in stream)
+    batches = list(stream)
+    assert all(isinstance(batch, RecordBatch) for batch in batches)
     assert not list(stream)  # after one iteration the generator must be exhausted
 
 
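The trailing assertion works because `execute_stream()` hands back a one-shot iterator: a second pass yields nothing. A small sketch of that semantic (query is illustrative):

from datafusion import SessionContext

ctx = SessionContext()
df = ctx.sql("SELECT 1 AS a")  # any query works the same way

stream = df.execute_stream()
first_pass = list(stream)   # drains the one-shot stream
second_pass = list(stream)  # nothing left on a second iteration
assert first_pass and not second_pass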
@@ -1513,7 +1515,7 @@ async def test_execute_stream_async(df):
     stream = df.execute_stream()
     batches = [batch async for batch in stream]
 
-    assert all(batch is not None for batch in batches)
+    assert all(isinstance(batch, RecordBatch) for batch in batches)
 
     # After consuming all batches, the stream should be exhausted
     remaining_batches = [batch async for batch in stream]
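The same stream object is also async-iterable, which is what the async variants exercise. A sketch of driving it outside pytest (assuming `asyncio.run` is a valid way to do so, as the `pytest.mark.asyncio` tests suggest):

import asyncio

from datafusion import SessionContext

async def consume(df):
    # The stream from execute_stream() supports `async for` consumption.
    stream = df.execute_stream()
    return [batch.to_pyarrow() async for batch in stream]

ctx = SessionContext()
batches = asyncio.run(consume(ctx.sql("SELECT 2 AS b")))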
@@ -1557,10 +1559,10 @@ async def test_execute_stream_to_arrow_table_async(df, schema):
 
 def test_execute_stream_partitioned(df):
     streams = df.execute_stream_partitioned()
-    assert all(batch is not None for stream in streams for batch in stream)
-    assert all(
-        not list(stream) for stream in streams
-    )  # after one iteration all generators must be exhausted
+    for stream in streams:
+        batches = list(stream)
+        assert all(isinstance(batch, RecordBatch) for batch in batches)
+        assert not list(stream)
 
 
 @pytest.mark.asyncio
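Rewriting the flattened `all(...)` checks as an explicit loop keeps the isinstance and exhaustion assertions paired per partition. A sketch of the per-partition semantics (partition count depends on the plan and session configuration):

from datafusion import RecordBatch, SessionContext

ctx = SessionContext()
df = ctx.sql("SELECT 3 AS c")

# One independent, one-shot stream per output partition.
for stream in df.execute_stream_partitioned():
    for batch in stream:
        assert isinstance(batch, RecordBatch)
    assert not list(stream)  # each partition's stream exhausts on its own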
@@ -1569,7 +1571,7 @@ async def test_execute_stream_partitioned_async(df):
 
     for stream in streams:
         batches = [batch async for batch in stream]
-        assert all(batch is not None for batch in batches)
+        assert all(isinstance(batch, RecordBatch) for batch in batches)
 
         # Ensure the stream is exhausted after iteration
         remaining_batches = [batch async for batch in stream]
@@ -1593,7 +1595,17 @@ def test_iter_batches_dataframe(fail_collect):
 
     expected = [batch1, batch2]
     for got, exp in zip(df, expected):
-        assert got.equals(exp)
+        assert isinstance(got, RecordBatch)
+        assert got.to_pyarrow().equals(exp)
+
+
+def test_table_from_batches_dataframe(df, fail_collect):
+    table = pa.Table.from_batches(df)
+    assert table.shape == (3, 3)
+    assert set(table.column_names) == {"a", "b", "c"}
+
+    for batch in df:
+        assert isinstance(batch, RecordBatch)
 
 
 def test_arrow_c_stream_to_table_and_reader(fail_collect):
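The new `test_table_from_batches_dataframe` relies on the DataFrame being iterable as a batch source that `pa.Table.from_batches` accepts directly (presumably via the Arrow PyCapsule interface the wrapper exposes); iterating the DataFrame again re-executes it, as the follow-up loop shows. A sketch of both routes (query is illustrative):

import pyarrow as pa

from datafusion import SessionContext

ctx = SessionContext()
df = ctx.sql("SELECT 1 AS a, 2 AS b, 3 AS c")

# Route 1: hand the DataFrame to pyarrow as a batch source.
table = pa.Table.from_batches(df)

# Route 2: convert each wrapper explicitly; iterating again re-executes.
table2 = pa.Table.from_batches(batch.to_pyarrow() for batch in df)
assert table.equals(table2)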
@@ -1855,8 +1867,6 @@ def test_write_parquet_with_options_default_compression(df, tmp_path):
     ["gzip(6)", "brotli(7)", "zstd(15)", "snappy", "uncompressed"],
 )
 def test_write_parquet_with_options_compression(df, tmp_path, compression):
-    import re
-
     path = tmp_path
     df.write_parquet_with_options(
         str(path), ParquetWriterOptions(compression=compression)
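Aside from the drive-by removal of the unused `import re`, the parametrize list documents that a compression level rides along inside the option string. A minimal usage sketch (output path is illustrative):

from pathlib import Path

from datafusion import ParquetWriterOptions, SessionContext

ctx = SessionContext()
df = ctx.sql("SELECT 1 AS a")
out = Path("/tmp/example_parquet")  # illustrative output directory

# e.g. "zstd(15)", "gzip(6)", "brotli(7)", "snappy", or "uncompressed"
df.write_parquet_with_options(str(out), ParquetWriterOptions(compression="zstd(15)"))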

python/tests/test_io.py

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,7 @@
 
 import pyarrow as pa
 import pytest
-from datafusion import column
+from datafusion import RecordBatch, column
 from datafusion.io import read_avro, read_csv, read_json, read_parquet
 
 from .utils import range_table
@@ -131,3 +131,6 @@ def test_table_from_batches_stream(ctx, fail_collect):
     table = pa.Table.from_batches(df)
     assert table.shape == (10, 1)
     assert table.column_names == ["value"]
+
+    for batch in df:
+        assert isinstance(batch, RecordBatch)
