27 changes: 26 additions & 1 deletion pyiceberg/io/pyarrow.py
@@ -1581,6 +1581,7 @@ def _task_to_record_batches(
partition_spec: PartitionSpec | None = None,
format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
downcast_ns_timestamp_to_us: bool | None = None,
batch_size: int | None = None,
) -> Iterator[pa.RecordBatch]:
arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with io.new_input(task.file.file_path).open() as fin:
@@ -1619,6 +1620,7 @@ def _task_to_record_batches(
# But in case there are positional deletes, we have to apply them first
filter=pyarrow_filter if not positional_deletes else None,
columns=[col.name for col in file_project_schema.columns],
batch_size=batch_size,
)

next_index = 0
@@ -1802,8 +1804,30 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
# This break will also cancel all running tasks in the executor
break

def to_record_batch_stream(self, tasks: Iterable[FileScanTask], batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
"""Scan the Iceberg table and return an Iterator[pa.RecordBatch] in a streaming fashion.

Files are read sequentially and batches are yielded one at a time
without materializing all batches in memory. Use this when memory
efficiency is more important than throughput.

Args:
tasks: FileScanTasks representing the data files and delete files to read from.
batch_size: Maximum number of rows per RecordBatch. If None,
uses PyArrow's default (131,072 rows).

Yields:
pa.RecordBatch: Record batches from the scan, one at a time.
"""
tasks = list(tasks) if not isinstance(tasks, list) else tasks
deletes_per_file = _read_all_delete_files(self._io, tasks)
yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file, batch_size)

def _record_batches_from_scan_tasks_and_deletes(
self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]]
self,
tasks: Iterable[FileScanTask],
deletes_per_file: dict[str, list[ChunkedArray]],
batch_size: int | None = None,
) -> Iterator[pa.RecordBatch]:
total_row_count = 0
for task in tasks:
Expand All @@ -1822,6 +1846,7 @@ def _record_batches_from_scan_tasks_and_deletes(
self._table_metadata.specs().get(task.file.spec_id),
self._table_metadata.format_version,
self._downcast_ns_timestamp_to_us,
batch_size,
)
for batch in batches:
if self._limit is not None:
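The `batch_size` argument added above is forwarded to PyArrow's batch reader, which caps each yielded `RecordBatch` at that many rows and falls back to its 131,072-row default when the value is `None`. A minimal standalone sketch of that underlying PyArrow behaviour, independent of the Iceberg plumbing (the file path and row count are made up for illustration):

```python
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

# Write a small throwaway Parquet file (path and size are illustrative only).
pq.write_table(pa.table({"id": list(range(1_000))}), "/tmp/batch_size_demo.parquet")

# Passing batch_size caps every yielded RecordBatch at that many rows;
# omitting it leaves PyArrow's 131,072-row default in place.
dataset = ds.dataset("/tmp/batch_size_demo.parquet", format="parquet")
for batch in dataset.to_batches(batch_size=100):
    assert len(batch) <= 100  # no batch exceeds the requested cap
```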
20 changes: 20 additions & 0 deletions pyiceberg/table/__init__.py
@@ -2182,6 +2182,26 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
batches,
).cast(target_schema)

def to_record_batches(self, batch_size: int | None = None) -> Iterator[pa.RecordBatch]:
"""Read record batches in a streaming fashion from this DataScan.

Files are read sequentially and batches are yielded one at a time
without materializing all batches in memory. Use this when memory
efficiency is more important than throughput.

Args:
batch_size: Maximum number of rows per RecordBatch. If None,
uses PyArrow's default (131,072 rows).

Yields:
pa.RecordBatch: Record batches from the scan, one at a time.
"""
from pyiceberg.io.pyarrow import ArrowScan

yield from ArrowScan(
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
).to_record_batch_stream(self.plan_files(), batch_size)

def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
"""Read a Pandas DataFrame eagerly from this Iceberg table.

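For context on how this new public entry point is meant to be consumed, here is a usage sketch of `DataScan.to_record_batches`; the catalog name and table identifier are placeholders, and the memory behaviour assumed is what the docstring above describes (roughly one batch held at a time):

```python
from pyiceberg.catalog import load_catalog

# Placeholder catalog/table names; any existing Iceberg table would work the same way.
catalog = load_catalog("default")
table = catalog.load_table("examples.events")

# Stream the scan instead of materializing it with to_arrow(): each RecordBatch
# is processed and then dropped, keeping peak memory at roughly one batch.
total_rows = 0
for batch in table.scan().to_record_batches(batch_size=10_000):
    total_rows += batch.num_rows  # replace with real per-batch processing

print(f"scanned {total_rows} rows without loading the full table into memory")
```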
48 changes: 48 additions & 0 deletions tests/integration/test_reads.py
@@ -1272,3 +1272,51 @@ def test_scan_source_field_missing_in_spec(catalog: Catalog, spark: SparkSession

table = catalog.load_table(identifier)
assert len(list(table.scan().plan_files())) == 3


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_datascan_to_record_batches(catalog: Catalog) -> None:
table = create_table(catalog)

arrow_table = pa.Table.from_pydict(
{
"str": ["a", "b", "c"],
"int": [1, 2, 3],
},
schema=pa.schema([pa.field("str", pa.large_string()), pa.field("int", pa.int32())]),
)
table.append(arrow_table)

scan = table.scan()
streaming_batches = list(scan.to_record_batches())
streaming_result = pa.concat_tables([pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive")

eager_result = scan.to_arrow()

assert streaming_result.num_rows == eager_result.num_rows
assert streaming_result.column_names == eager_result.column_names
assert streaming_result.sort_by("int").equals(eager_result.sort_by("int"))


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")])
def test_datascan_to_record_batches_with_batch_size(catalog: Catalog) -> None:
table = create_table(catalog)

arrow_table = pa.Table.from_pydict(
{
"str": [f"val_{i}" for i in range(100)],
"int": list(range(100)),
},
schema=pa.schema([pa.field("str", pa.large_string()), pa.field("int", pa.int32())]),
)
table.append(arrow_table)

scan = table.scan()
batches = list(scan.to_record_batches(batch_size=10))

total_rows = sum(len(b) for b in batches)
assert total_rows == 100
for batch in batches:
assert len(batch) <= 10
207 changes: 207 additions & 0 deletions tests/io/test_pyarrow.py
@@ -4884,3 +4884,210 @@ def test_partition_column_projection_with_schema_evolution(catalog: InMemoryCata
result_sorted = result.sort_by("name")
assert result_sorted["name"].to_pylist() == ["Alice", "Bob", "Charlie", "David"]
assert result_sorted["new_column"].to_pylist() == [None, None, "new1", "new2"]


def test_task_to_record_batches_with_batch_size(tmpdir: str) -> None:
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})

# Create a parquet file with 1000 rows
table = pa.Table.from_arrays([pa.array(list(range(1000)))], schema=pyarrow_schema)
data_file = _write_table_to_data_file(f"{tmpdir}/batch_size_test.parquet", pyarrow_schema, table)
data_file.spec_id = 0

task = FileScanTask(data_file=data_file)

batches = list(
_task_to_record_batches(
PyArrowFileIO(),
task,
bound_row_filter=AlwaysTrue(),
projected_schema=schema,
table_schema=schema,
projected_field_ids={1},
positional_deletes=None,
case_sensitive=True,
batch_size=100,
)
)

total_rows = sum(len(b) for b in batches)
assert total_rows == 1000
for batch in batches:
assert len(batch) <= 100


def test_to_record_batch_stream_basic(tmpdir: str) -> None:
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})

table = pa.Table.from_arrays([pa.array(list(range(100)))], schema=pyarrow_schema)
data_file = _write_table_to_data_file(f"{tmpdir}/streaming_basic.parquet", pyarrow_schema, table)
data_file.spec_id = 0

task = FileScanTask(data_file=data_file)

scan = ArrowScan(
table_metadata=TableMetadataV2(
location="file://a/b/",
last_column_id=1,
format_version=2,
schemas=[schema],
partition_specs=[PartitionSpec()],
),
io=PyArrowFileIO(),
projected_schema=schema,
row_filter=AlwaysTrue(),
case_sensitive=True,
)

result = scan.to_record_batch_stream([task])
import types

# The stream should be a lazy generator, not a fully materialized list of batches
assert isinstance(result, types.GeneratorType)

batches = list(result)
total_rows = sum(len(b) for b in batches)
assert total_rows == 100


def test_to_record_batch_stream_with_batch_size(tmpdir: str) -> None:
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})

table = pa.Table.from_arrays([pa.array(list(range(500)))], schema=pyarrow_schema)
data_file = _write_table_to_data_file(f"{tmpdir}/streaming_batch_size.parquet", pyarrow_schema, table)
data_file.spec_id = 0

task = FileScanTask(data_file=data_file)

scan = ArrowScan(
table_metadata=TableMetadataV2(
location="file://a/b/",
last_column_id=1,
format_version=2,
schemas=[schema],
partition_specs=[PartitionSpec()],
),
io=PyArrowFileIO(),
projected_schema=schema,
row_filter=AlwaysTrue(),
case_sensitive=True,
)

batches = list(scan.to_record_batch_stream([task], batch_size=50))

total_rows = sum(len(b) for b in batches)
assert total_rows == 500
for batch in batches:
assert len(batch) <= 50


def test_to_record_batch_stream_with_limit(tmpdir: str) -> None:
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})

table = pa.Table.from_arrays([pa.array(list(range(500)))], schema=pyarrow_schema)
data_file = _write_table_to_data_file(f"{tmpdir}/streaming_limit.parquet", pyarrow_schema, table)
data_file.spec_id = 0

task = FileScanTask(data_file=data_file)

scan = ArrowScan(
table_metadata=TableMetadataV2(
location="file://a/b/",
last_column_id=1,
format_version=2,
schemas=[schema],
partition_specs=[PartitionSpec()],
),
io=PyArrowFileIO(),
projected_schema=schema,
row_filter=AlwaysTrue(),
case_sensitive=True,
limit=100,
)

batches = list(scan.to_record_batch_stream([task]))

total_rows = sum(len(b) for b in batches)
assert total_rows == 100


def test_to_record_batch_stream_with_deletes(
deletes_file: str, request: pytest.FixtureRequest, table_schema_simple: Schema
) -> None:
file_format = FileFormat.PARQUET if deletes_file.endswith(".parquet") else FileFormat.ORC

if file_format == FileFormat.PARQUET:
example_task = request.getfixturevalue("example_task")
else:
example_task = request.getfixturevalue("example_task_orc")

example_task_with_delete = FileScanTask(
data_file=example_task.file,
delete_files={
DataFile.from_args(
content=DataFileContent.POSITION_DELETES,
file_path=deletes_file,
file_format=file_format,
)
},
)

metadata_location = "file://a/b/c.json"
scan = ArrowScan(
table_metadata=TableMetadataV2(
location=metadata_location,
last_column_id=1,
format_version=2,
current_schema_id=1,
schemas=[table_schema_simple],
partition_specs=[PartitionSpec()],
),
io=load_file_io(),
projected_schema=table_schema_simple,
row_filter=AlwaysTrue(),
)

# Compare streaming path to table path
streaming_batches = list(scan.to_record_batch_stream([example_task_with_delete]))
streaming_table = pa.concat_tables([pa.Table.from_batches([b]) for b in streaming_batches], promote_options="permissive")
eager_table = scan.to_table(tasks=[example_task_with_delete])

assert streaming_table.num_rows == eager_table.num_rows
assert streaming_table.column_names == eager_table.column_names


def test_to_record_batch_stream_multiple_files(tmpdir: str) -> None:
schema = Schema(NestedField(1, "id", IntegerType(), required=False))
pyarrow_schema = schema_to_pyarrow(schema, metadata={ICEBERG_SCHEMA: bytes(schema.model_dump_json(), UTF8)})

tasks = []
total_expected = 0
for i in range(3):
num_rows = (i + 1) * 100 # 100, 200, 300
total_expected += num_rows
table = pa.Table.from_arrays([pa.array(list(range(num_rows)))], schema=pyarrow_schema)
data_file = _write_table_to_data_file(f"{tmpdir}/multi_{i}.parquet", pyarrow_schema, table)
data_file.spec_id = 0
tasks.append(FileScanTask(data_file=data_file))

scan = ArrowScan(
table_metadata=TableMetadataV2(
location="file://a/b/",
last_column_id=1,
format_version=2,
schemas=[schema],
partition_specs=[PartitionSpec()],
),
io=PyArrowFileIO(),
projected_schema=schema,
row_filter=AlwaysTrue(),
case_sensitive=True,
)

batches = list(scan.to_record_batch_stream(tasks))
total_rows = sum(len(b) for b in batches)
assert total_rows == total_expected # 600 rows total