|
18 | 18 |
|
19 | 19 | from __future__ import annotations |
20 | 20 |
|
| 21 | +from importlib import import_module, util |
21 | 22 | from typing import TYPE_CHECKING, Any |
22 | 23 |
|
23 | 24 | from datafusion._internal import EXPECTED_PROVIDER_MSG |
24 | 25 |
|
def _discover_dataset_types(
    module_name: str = "pyarrow.dataset",
) -> tuple[type[Any], ...]:
    """Return the ``Dataset`` class exported by *module_name* and its subclasses.

    The module is treated as an optional dependency: every failure to locate
    or import it yields an empty tuple instead of an exception, so callers can
    guard with a plain truthiness check before using the result in
    ``isinstance``.

    Args:
        module_name: Dotted name of the module to inspect.

    Returns:
        A tuple holding the module's ``Dataset`` class plus every class in the
        module namespace that subclasses it, or ``()`` when the module is
        unavailable or exposes no ``Dataset`` class. The order of the entries
        is unspecified.
    """
    try:
        # For a dotted name, find_spec imports the parent package first and
        # raises ModuleNotFoundError (an ImportError) when that parent is
        # missing, rather than returning None.
        spec = util.find_spec(module_name)
    except ImportError:  # pragma: no cover - optional dependency at runtime
        return ()
    if spec is None:
        return ()

    try:
        module = import_module(module_name)
    except ImportError:  # pragma: no cover - spec found but module unusable
        return ()

    base = getattr(module, "Dataset", None)
    if not isinstance(base, type):
        # issubclass() requires a class as its second argument, so there is
        # nothing to collect when the module has no ``Dataset`` class.
        return ()

    found: set[type[Any]] = {base}
    found.update(
        member
        for member in vars(module).values()
        if isinstance(member, type) and issubclass(member, base)
    )
    return tuple(found)


# Classes accepted as-is by the table-provider normalization; empty when
# pyarrow's dataset module cannot be imported.
_PYARROW_DATASET_TYPES: tuple[type[Any], ...] = _discover_dataset_types()
| 40 | + |
25 | 41 | if TYPE_CHECKING: # pragma: no cover - imported for typing only |
26 | 42 | from datafusion import TableProvider |
27 | 43 | from datafusion.catalog import Table |
@@ -54,6 +70,9 @@ def _normalize_table_provider( |
54 | 70 | if isinstance(table, _TableProvider): |
55 | 71 | return table._table_provider |
56 | 72 |
|
| 73 | + if _PYARROW_DATASET_TYPES and isinstance(table, _PYARROW_DATASET_TYPES): |
| 74 | + return table |
| 75 | + |
57 | 76 | provider_factory = getattr(table, "__datafusion_table_provider__", None) |
58 | 77 | if callable(provider_factory): |
59 | 78 | return table |
|
0 commit comments