Skip to content

Commit 5300398

Browse files
committed
test: add test for schema registration with pyarrow dataset
1 parent 76fcc08 commit 5300398

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

python/datafusion/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,26 @@
1818

1919
from __future__ import annotations
2020

21+
from importlib import import_module, util
2122
from typing import TYPE_CHECKING, Any
2223

2324
from datafusion._internal import EXPECTED_PROVIDER_MSG
2425

26+
_PYARROW_DATASET_TYPES: tuple[type[Any], ...]
27+
_dataset_spec = util.find_spec("pyarrow.dataset")
28+
if _dataset_spec is None: # pragma: no cover - optional dependency at runtime
29+
_PYARROW_DATASET_TYPES = ()
30+
else: # pragma: no cover - exercised in environments with pyarrow installed
31+
_dataset_module = import_module("pyarrow.dataset")
32+
dataset_base = getattr(_dataset_module, "Dataset", None)
33+
dataset_types: set[type[Any]] = set()
34+
if isinstance(dataset_base, type):
35+
dataset_types.add(dataset_base)
36+
for value in vars(_dataset_module).values():
37+
if isinstance(value, type) and issubclass(value, dataset_base):
38+
dataset_types.add(value)
39+
_PYARROW_DATASET_TYPES = tuple(dataset_types)
40+
2541
if TYPE_CHECKING: # pragma: no cover - imported for typing only
2642
from datafusion import TableProvider
2743
from datafusion.catalog import Table
@@ -54,6 +70,9 @@ def _normalize_table_provider(
5470
if isinstance(table, _TableProvider):
5571
return table._table_provider
5672

73+
if _PYARROW_DATASET_TYPES and isinstance(table, _PYARROW_DATASET_TYPES):
74+
return table
75+
5776
provider_factory = getattr(table, "__datafusion_table_provider__", None)
5877
if callable(provider_factory):
5978
return table

python/tests/test_catalog.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,28 @@ def test_python_table_provider(ctx: SessionContext):
164164
assert schema.table_names() == {"table4"}
165165

166166

167+
def test_schema_register_table_with_pyarrow_dataset(ctx: SessionContext):
168+
schema = ctx.catalog().schema()
169+
batch = pa.RecordBatch.from_arrays(
170+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
171+
names=["a", "b"],
172+
)
173+
dataset = ds.dataset([batch])
174+
table_name = "pa_dataset"
175+
176+
try:
177+
schema.register_table(table_name, dataset)
178+
assert table_name in schema.table_names()
179+
180+
result = ctx.sql(f"SELECT a, b FROM {table_name}").collect()
181+
182+
assert len(result) == 1
183+
assert result[0].column(0) == pa.array([1, 2, 3])
184+
assert result[0].column(1) == pa.array([4, 5, 6])
185+
finally:
186+
schema.deregister_table(table_name)
187+
188+
167189
def test_schema_register_table_with_dataframe_errors(ctx: SessionContext):
168190
schema = ctx.catalog().schema()
169191
df = ctx.from_pydict({"a": [1]})

0 commit comments

Comments
 (0)