Skip to content

Commit 8fc9436

Browse files
authored
feat: add CatalogProviderList support (#1363)
* Implement catalog provider list * Flesh out python side and add unit test * Add FFI test for catalog provider list * Update type hints * Update unit test to add a different type of catalog to the catalog list
1 parent b555df5 commit 8fc9436

File tree

8 files changed

+496
-20
lines changed

8 files changed

+496
-20
lines changed

examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import pytest
2323
from datafusion import SessionContext, Table
2424
from datafusion.catalog import Schema
25-
from datafusion_ffi_example import MyCatalogProvider
25+
from datafusion_ffi_example import MyCatalogProvider, MyCatalogProviderList
2626

2727

2828
def create_test_dataset() -> Table:
@@ -35,6 +35,30 @@ def create_test_dataset() -> Table:
3535
return Table(dataset)
3636

3737

38+
@pytest.mark.parametrize("inner_capsule", [True, False])
def test_ffi_catalog_provider_list(inner_capsule: bool) -> None:
    """Register an FFI catalog provider list and check the catalogs it exposes.

    When ``inner_capsule`` is true, the value returned by
    ``__datafusion_catalog_provider_list__`` is registered instead of the
    ``MyCatalogProviderList`` object itself.
    """
    ctx = SessionContext()

    # Choose which form of the provider list is handed to the context.
    provider_list = MyCatalogProviderList()
    if inner_capsule:
        provider_list = provider_list.__datafusion_catalog_provider_list__(ctx)
    ctx.register_catalog_provider_list(provider_list)

    # The catalog supplied by the provider list is reachable by name.
    assert "my_schema" in ctx.catalog("auto_ffi_catalog").names()

    # Registering one more catalog adds it alongside the existing one.
    ctx.register_catalog_provider("second", MyCatalogProvider())
    assert ctx.catalog_names() == {"auto_ffi_catalog", "second"}
60+
61+
3862
@pytest.mark.parametrize("inner_capsule", [True, False])
3963
def test_ffi_catalog_provider_basic(inner_capsule: bool) -> None:
4064
"""Test basic FFI CatalogProvider functionality."""

examples/datafusion-ffi-example/src/catalog_provider.rs

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ use std::sync::Arc;
2222
use arrow::datatypes::Schema;
2323
use async_trait::async_trait;
2424
use datafusion_catalog::{
25-
CatalogProvider, MemTable, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
26-
TableProvider,
25+
CatalogProvider, CatalogProviderList, MemTable, MemoryCatalogProvider,
26+
MemoryCatalogProviderList, MemorySchemaProvider, SchemaProvider, TableProvider,
2727
};
2828
use datafusion_common::error::{DataFusionError, Result};
2929
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
30+
use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
3031
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
3132
use pyo3::types::PyCapsule;
3233
use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python};
@@ -203,3 +204,67 @@ impl MyCatalogProvider {
203204
PyCapsule::new(py, provider, Some(name))
204205
}
205206
}
207+
208+
/// This catalog provider list is intended only for unit tests.
209+
/// It pre-populates with a single catalog.
210+
#[pyclass(
211+
name = "MyCatalogProviderList",
212+
module = "datafusion_ffi_example",
213+
subclass
214+
)]
215+
#[derive(Debug, Clone)]
216+
pub(crate) struct MyCatalogProviderList {
217+
inner: Arc<MemoryCatalogProviderList>,
218+
}
219+
220+
impl CatalogProviderList for MyCatalogProviderList {
221+
fn as_any(&self) -> &dyn Any {
222+
self
223+
}
224+
225+
fn catalog_names(&self) -> Vec<String> {
226+
self.inner.catalog_names()
227+
}
228+
229+
fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
230+
self.inner.catalog(name)
231+
}
232+
233+
fn register_catalog(
234+
&self,
235+
name: String,
236+
catalog: Arc<dyn CatalogProvider>,
237+
) -> Option<Arc<dyn CatalogProvider>> {
238+
self.inner.register_catalog(name, catalog)
239+
}
240+
}
241+
242+
#[pymethods]
243+
impl MyCatalogProviderList {
244+
#[new]
245+
pub fn new() -> PyResult<Self> {
246+
let inner = Arc::new(MemoryCatalogProviderList::new());
247+
248+
inner.register_catalog(
249+
"auto_ffi_catalog".to_owned(),
250+
Arc::new(MyCatalogProvider::new()?),
251+
);
252+
253+
Ok(Self { inner })
254+
}
255+
256+
pub fn __datafusion_catalog_provider_list__<'py>(
257+
&self,
258+
py: Python<'py>,
259+
session: Bound<PyAny>,
260+
) -> PyResult<Bound<'py, PyCapsule>> {
261+
let name = cr"datafusion_catalog_provider_list".into();
262+
263+
let provider = Arc::clone(&self.inner) as Arc<dyn CatalogProviderList + Send>;
264+
265+
let codec = ffi_logical_codec_from_pycapsule(session)?;
266+
let provider = FFI_CatalogProviderList::new_with_ffi_codec(provider, None, codec);
267+
268+
PyCapsule::new(py, provider, Some(name))
269+
}
270+
}

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use pyo3::prelude::*;
1919

2020
use crate::aggregate_udf::MySumUDF;
21-
use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider};
21+
use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogProviderList};
2222
use crate::scalar_udf::IsNullUDF;
2323
use crate::table_function::MyTableFunction;
2424
use crate::table_provider::MyTableProvider;
@@ -37,6 +37,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
3737
m.add_class::<MyTableProvider>()?;
3838
m.add_class::<MyTableFunction>()?;
3939
m.add_class::<MyCatalogProvider>()?;
40+
m.add_class::<MyCatalogProviderList>()?;
4041
m.add_class::<FixedSchemaProvider>()?;
4142
m.add_class::<IsNullUDF>()?;
4243
m.add_class::<MySumUDF>()?;

python/datafusion/catalog.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,61 @@
3838

3939
__all__ = [
4040
"Catalog",
41+
"CatalogList",
4142
"CatalogProvider",
43+
"CatalogProviderList",
4244
"Schema",
4345
"SchemaProvider",
4446
"Table",
4547
]
4648

4749

50+
class CatalogList:
    """DataFusion data catalog list.

    Wraps a raw catalog list object; every method delegates to it.
    """

    def __init__(self, catalog_list: df_internal.catalog.RawCatalogList) -> None:
        """This constructor is not typically called by the end user."""
        self.catalog_list = catalog_list

    def __repr__(self) -> str:
        """Print a string representation of the catalog list."""
        return repr(self.catalog_list)

    def names(self) -> set[str]:
        """This is an alias for `catalog_names`."""
        return self.catalog_names()

    def catalog_names(self) -> set[str]:
        """Returns the names of the catalogs in this catalog list."""
        return self.catalog_list.catalog_names()

    @staticmethod
    def memory_catalog(ctx: SessionContext | None = None) -> CatalogList:
        """Create an in-memory catalog provider list."""
        catalog_list = df_internal.catalog.RawCatalogList.memory_catalog(ctx)
        return CatalogList(catalog_list)

    def catalog(self, name: str = "datafusion") -> Catalog:
        """Returns the catalog with the given ``name`` from this catalog list.

        A raw catalog is wrapped in :class:`Catalog`; any other object returned
        by the underlying list is passed back unchanged.
        """
        catalog = self.catalog_list.catalog(name)

        return (
            Catalog(catalog)
            if isinstance(catalog, df_internal.catalog.RawCatalog)
            else catalog
        )

    def register_catalog(
        self,
        name: str,
        catalog: Catalog | CatalogProvider | CatalogProviderExportable,
    ) -> Catalog | None:
        """Register a catalog with this catalog list.

        A :class:`Catalog` wrapper is unwrapped to its underlying ``catalog``
        attribute before being passed on; any other object is passed through
        unchanged. The return value is whatever the underlying list returns.
        """
        if isinstance(catalog, Catalog):
            return self.catalog_list.register_catalog(name, catalog.catalog)
        return self.catalog_list.register_catalog(name, catalog)
94+
95+
4896
class Catalog:
4997
"""DataFusion data catalog."""
5098

@@ -195,6 +243,40 @@ def kind(self) -> str:
195243
return self._inner.kind
196244

197245

246+
class CatalogProviderList(ABC):
    """Base class for catalog provider lists implemented in Python.

    Subclasses must implement ``catalog_names`` and ``catalog``;
    ``register_catalog`` is optional.
    """

    @abstractmethod
    def catalog_names(self) -> set[str]:
        """Return the names of every catalog in this catalog list."""

    @abstractmethod
    def catalog(
        self, name: str
    ) -> CatalogProviderExportable | CatalogProvider | Catalog | None:
        """Return the catalog called ``name`` from this catalog list."""

    def register_catalog(  # noqa: B027
        self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog
    ) -> None:
        """Add a catalog to this catalog list.

        Implementing this is optional: the default does nothing, so a list that
        exposes a fixed set of catalogs can leave it as is.
        """
269+
270+
271+
class CatalogProviderListExportable(Protocol):
    """Structural type for objects with a ``__datafusion_catalog_provider_list__`` PyCapsule.

    The corresponding Rust trait is documented at
    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProviderList.html
    """

    def __datafusion_catalog_provider_list__(self, session: Any) -> object:
        """Export the catalog provider list for ``session``."""
        ...
278+
279+
198280
class CatalogProvider(ABC):
199281
"""Abstract class for defining a Python based Catalog Provider."""
200282

@@ -229,6 +311,15 @@ def deregister_schema(self, name: str, cascade: bool) -> None: # noqa: B027
229311
"""
230312

231313

314+
class CatalogProviderExportable(Protocol):
    """Structural type for objects with a ``__datafusion_catalog_provider__`` PyCapsule.

    The corresponding Rust trait is documented at
    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
    """

    def __datafusion_catalog_provider__(self, session: Any) -> object:
        """Export the catalog provider for ``session``."""
        ...
321+
322+
232323
class SchemaProvider(ABC):
233324
"""Abstract class for defining a Python based Schema Provider."""
234325

python/datafusion/context.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@
3131

3232
import pyarrow as pa
3333

34-
from datafusion.catalog import Catalog
34+
from datafusion.catalog import (
35+
Catalog,
36+
CatalogList,
37+
CatalogProviderExportable,
38+
CatalogProviderList,
39+
CatalogProviderListExportable,
40+
)
3541
from datafusion.dataframe import DataFrame
3642
from datafusion.expr import sort_list_to_raw_sort_list
3743
from datafusion.options import (
@@ -96,15 +102,6 @@ class TableProviderExportable(Protocol):
96102
def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105
97103

98104

99-
class CatalogProviderExportable(Protocol):
100-
"""Type hint for object that has __datafusion_catalog_provider__ PyCapsule.
101-
102-
https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
103-
"""
104-
105-
def __datafusion_catalog_provider__(self, session: Any) -> object: ... # noqa: D105
106-
107-
108105
class SessionConfig:
109106
"""Session configuration options."""
110107

@@ -837,6 +834,16 @@ def catalog_names(self) -> set[str]:
837834
"""Returns the list of catalogs in this context."""
838835
return self.ctx.catalog_names()
839836

837+
def register_catalog_provider_list(
    self,
    provider: CatalogProviderListExportable | CatalogProviderList | CatalogList,
) -> None:
    """Register a catalog provider list.

    Args:
        provider: The catalog provider list to register. A ``CatalogList``
            wrapper is unwrapped to its underlying raw catalog list; any other
            object is passed to the context unchanged.
    """
    if isinstance(provider, CatalogList):
        # ``CatalogList`` stores its raw object in ``catalog_list``;
        # ``CatalogList.catalog`` is a lookup method, not the raw list.
        provider = provider.catalog_list
    self.ctx.register_catalog_provider_list(provider)
846+
840847
def register_catalog_provider(
841848
self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog
842849
) -> None:

python/tests/test_catalog.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from typing import TYPE_CHECKING
20+
1921
import datafusion as dfn
2022
import pyarrow as pa
2123
import pyarrow.dataset as ds
2224
import pytest
23-
from datafusion import SessionContext, Table, udtf
25+
from datafusion import Catalog, SessionContext, Table, udtf
26+
27+
if TYPE_CHECKING:
28+
from datafusion.catalog import CatalogProvider, CatalogProviderExportable
2429

2530

2631
# Note we take in `database` as a variable even though we don't use
@@ -93,6 +98,34 @@ def deregister_schema(self, name, cascade: bool):
9398
del self.schemas[name]
9499

95100

101+
class CustomCatalogProviderList(dfn.catalog.CatalogProviderList):
    """Dict-backed catalog provider list used by the tests in this module."""

    def __init__(self):
        """Start with one ``CustomCatalogProvider`` registered as ``my_catalog``."""
        self.catalogs = {"my_catalog": CustomCatalogProvider()}

    def catalog_names(self) -> set[str]:
        """Return the names of the stored catalogs."""
        return set(self.catalogs)

    def catalog(self, name: str) -> Catalog | None:
        """Return the catalog stored under ``name``, or ``None`` if there is none."""
        # ``.get`` honours the ``| None`` in the annotation; indexing would raise
        # ``KeyError`` for an unknown name.
        return self.catalogs.get(name)

    def register_catalog(
        self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog
    ) -> None:
        """Store ``catalog`` under ``name``, replacing any previous entry."""
        self.catalogs[name] = catalog
115+
116+
117+
def test_python_catalog_provider_list(ctx: SessionContext):
    """A Python catalog provider list replaces the context's catalog list."""
    ctx.register_catalog_provider_list(CustomCatalogProviderList())

    # Only the provider list's own catalog is visible: the ``datafusion``
    # catalog is gone because the whole catalog list was replaced.
    names_after_replace = ctx.catalog_names()
    assert names_after_replace == {"my_catalog"}

    # A catalog registered through the context shows up next to it.
    ctx.register_catalog_provider("second_catalog", Catalog.memory_catalog())
    names_after_register = ctx.catalog_names()
    assert names_after_register == {"my_catalog", "second_catalog"}
127+
128+
96129
def test_python_catalog_provider(ctx: SessionContext):
97130
ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())
98131

0 commit comments

Comments
 (0)