Skip to content

Commit 50974af

Browse files
committed
Implement catalog provider list
1 parent e7f5867 commit 50974af

File tree

2 files changed

+256
-3
lines changed

2 files changed

+256
-3
lines changed

src/catalog.rs

Lines changed: 217 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ use std::sync::Arc;
2121

2222
use async_trait::async_trait;
2323
use datafusion::catalog::{
24-
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
24+
CatalogProvider, CatalogProviderList, MemoryCatalogProvider, MemoryCatalogProviderList,
25+
MemorySchemaProvider, SchemaProvider,
2526
};
2627
use datafusion::common::DataFusionError;
2728
use datafusion::datasource::TableProvider;
29+
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
2830
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
2931
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
3032
use pyo3::exceptions::PyKeyError;
@@ -40,6 +42,18 @@ use crate::utils::{
4042
wait_for_future,
4143
};
4244

45+
#[pyclass(
46+
frozen,
47+
name = "RawCatalogList",
48+
module = "datafusion.catalog",
49+
subclass
50+
)]
51+
#[derive(Clone)]
52+
pub struct PyCatalogList {
53+
pub catalog_list: Arc<dyn CatalogProviderList>,
54+
codec: Arc<FFI_LogicalExtensionCodec>,
55+
}
56+
4357
#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
4458
#[derive(Clone)]
4559
pub struct PyCatalog {
@@ -72,6 +86,77 @@ impl PySchema {
7286
}
7387
}
7488

89+
#[pymethods]
90+
impl PyCatalogList {
91+
#[new]
92+
pub fn new(
93+
py: Python,
94+
catalog_list: Py<PyAny>,
95+
session: Option<Bound<PyAny>>,
96+
) -> PyResult<Self> {
97+
let codec = extract_logical_extension_codec(py, session)?;
98+
let catalog_list = Arc::new(RustWrappedPyCatalogProviderList::new(
99+
catalog_list,
100+
codec.clone(),
101+
)) as Arc<dyn CatalogProviderList>;
102+
Ok(Self {
103+
catalog_list,
104+
codec,
105+
})
106+
}
107+
108+
#[staticmethod]
109+
pub fn memory_catalog_list(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
110+
let codec = extract_logical_extension_codec(py, session)?;
111+
let catalog_list =
112+
Arc::new(MemoryCatalogProviderList::default()) as Arc<dyn CatalogProviderList>;
113+
Ok(Self {
114+
catalog_list,
115+
codec,
116+
})
117+
}
118+
119+
pub fn catalog_names(&self) -> HashSet<String> {
120+
self.catalog_list.catalog_names().into_iter().collect()
121+
}
122+
123+
#[pyo3(signature = (name="public"))]
124+
pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
125+
let catalog = self
126+
.catalog_list
127+
.catalog(name)
128+
.ok_or(PyKeyError::new_err(format!(
129+
"Schema with name {name} doesn't exist."
130+
)))?;
131+
132+
Python::attach(|py| {
133+
match catalog
134+
.as_any()
135+
.downcast_ref::<RustWrappedPyCatalogProvider>()
136+
{
137+
Some(wrapped_catalog) => Ok(wrapped_catalog.catalog_provider.clone_ref(py)),
138+
None => PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py),
139+
}
140+
})
141+
}
142+
143+
pub fn register_catalog(&self, name: &str, catalog_provider: Bound<'_, PyAny>) -> PyResult<()> {
144+
let provider = extract_catalog_provider_from_pyobj(catalog_provider, self.codec.as_ref())?;
145+
146+
let _ = self
147+
.catalog_list
148+
.register_catalog(name.to_owned(), provider);
149+
150+
Ok(())
151+
}
152+
153+
pub fn __repr__(&self) -> PyResult<String> {
154+
let mut names: Vec<String> = self.catalog_names().into_iter().collect();
155+
names.sort();
156+
Ok(format!("CatalogList(catalog_names=[{}])", names.join(", ")))
157+
}
158+
}
159+
75160
#[pymethods]
76161
impl PyCatalog {
77162
#[new]
@@ -442,6 +527,137 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
442527
}
443528
}
444529

530+
#[derive(Debug)]
531+
pub(crate) struct RustWrappedPyCatalogProviderList {
532+
pub(crate) catalog_provider_list: Py<PyAny>,
533+
codec: Arc<FFI_LogicalExtensionCodec>,
534+
}
535+
536+
impl RustWrappedPyCatalogProviderList {
537+
pub fn new(catalog_provider_list: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
538+
Self {
539+
catalog_provider_list,
540+
codec,
541+
}
542+
}
543+
544+
fn catalog_inner(&self, name: &str) -> PyResult<Option<Arc<dyn CatalogProvider>>> {
545+
Python::attach(|py| {
546+
let provider = self.catalog_provider_list.bind(py);
547+
548+
let py_schema = provider.call_method1("catalog", (name,))?;
549+
if py_schema.is_none() {
550+
return Ok(None);
551+
}
552+
553+
extract_catalog_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
554+
})
555+
}
556+
}
557+
558+
#[async_trait]
559+
impl CatalogProviderList for RustWrappedPyCatalogProviderList {
560+
fn as_any(&self) -> &dyn Any {
561+
self
562+
}
563+
564+
fn catalog_names(&self) -> Vec<String> {
565+
Python::attach(|py| {
566+
let provider = self.catalog_provider_list.bind(py);
567+
provider
568+
.getattr("catalog_names")
569+
.and_then(|names| names.extract::<Vec<String>>())
570+
.unwrap_or_else(|err| {
571+
log::error!("Unable to get catalog_names: {err}");
572+
Vec::default()
573+
})
574+
})
575+
}
576+
577+
fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
578+
self.catalog_inner(name).unwrap_or_else(|err| {
579+
log::error!("CatalogProvider catalog returned error: {err}");
580+
None
581+
})
582+
}
583+
584+
fn register_catalog(
585+
&self,
586+
name: String,
587+
catalog: Arc<dyn CatalogProvider>,
588+
) -> Option<Arc<dyn CatalogProvider>> {
589+
Python::attach(|py| {
590+
let py_catalog = match catalog
591+
.as_any()
592+
.downcast_ref::<RustWrappedPyCatalogProvider>()
593+
{
594+
Some(wrapped_schema) => wrapped_schema.catalog_provider.as_any().clone_ref(py),
595+
None => {
596+
match PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py) {
597+
Ok(c) => c,
598+
Err(err) => {
599+
log::error!(
600+
"register_catalog returned error during conversion to PyAny: {err}"
601+
);
602+
return None;
603+
}
604+
}
605+
}
606+
};
607+
608+
let provider = self.catalog_provider_list.bind(py);
609+
let catalog = match provider.call_method1("register_catalog", (name, py_catalog)) {
610+
Ok(c) => c,
611+
Err(err) => {
612+
log::error!("register_catalog returned error: {err}");
613+
return None;
614+
}
615+
};
616+
if catalog.is_none() {
617+
return None;
618+
}
619+
620+
let catalog = Arc::new(RustWrappedPyCatalogProvider::new(
621+
catalog.into(),
622+
self.codec.clone(),
623+
)) as Arc<dyn CatalogProvider>;
624+
625+
Some(catalog)
626+
})
627+
}
628+
}
629+
630+
fn extract_catalog_provider_from_pyobj(
631+
mut catalog_provider: Bound<PyAny>,
632+
codec: &FFI_LogicalExtensionCodec,
633+
) -> PyResult<Arc<dyn CatalogProvider>> {
634+
if catalog_provider.hasattr("__datafusion_catalog_provider__")? {
635+
let py = catalog_provider.py();
636+
let codec_capsule = create_logical_extension_capsule(py, codec)?;
637+
catalog_provider = catalog_provider
638+
.getattr("__datafusion_catalog_provider__")?
639+
.call1((codec_capsule,))?;
640+
}
641+
642+
let provider = if let Ok(capsule) = catalog_provider.downcast::<PyCapsule>() {
643+
validate_pycapsule(capsule, "datafusion_catalog_provider")?;
644+
645+
let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
646+
let provider: Arc<dyn CatalogProvider + Send> = provider.into();
647+
provider as Arc<dyn CatalogProvider>
648+
} else {
649+
match catalog_provider.extract::<PyCatalog>() {
650+
Ok(py_catalog) => py_catalog.catalog,
651+
Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
652+
catalog_provider.into(),
653+
Arc::new(codec.clone()),
654+
)) as Arc<dyn CatalogProvider>,
655+
}
656+
};
657+
658+
Ok(provider)
659+
}
660+
445661
fn extract_schema_provider_from_pyobj(
446662
mut schema_provider: Bound<PyAny>,
447663
codec: &FFI_LogicalExtensionCodec,

src/context.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use arrow::pyarrow::FromPyArrow;
2626
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
2727
use datafusion::arrow::pyarrow::PyArrowType;
2828
use datafusion::arrow::record_batch::RecordBatch;
29-
use datafusion::catalog::CatalogProvider;
29+
use datafusion::catalog::{CatalogProvider, CatalogProviderList};
3030
use datafusion::common::{exec_err, ScalarValue, TableReference};
3131
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
3232
use datafusion::datasource::file_format::parquet::ParquetFormat;
@@ -47,6 +47,7 @@ use datafusion::prelude::{
4747
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
4848
};
4949
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
50+
use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
5051
use datafusion_ffi::execution::FFI_TaskContextProvider;
5152
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
5253
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
@@ -58,7 +59,9 @@ use pyo3::IntoPyObjectExt;
5859
use url::Url;
5960
use uuid::Uuid;
6061

61-
use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
62+
use crate::catalog::{
63+
PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList,
64+
};
6265
use crate::common::data_type::PyScalarValue;
6366
use crate::dataframe::PyDataFrame;
6467
use crate::dataset::Dataset;
@@ -627,6 +630,40 @@ impl PySessionContext {
627630
Ok(())
628631
}
629632

633+
pub fn register_catalog_provider_list(
634+
&self,
635+
mut provider: Bound<PyAny>,
636+
) -> PyDataFusionResult<()> {
637+
if provider.hasattr("__datafusion_catalog_provider_list__")? {
638+
let py = provider.py();
639+
let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
640+
provider = provider
641+
.getattr("__datafusion_catalog_provider_list__")?
642+
.call1((codec_capsule,))?;
643+
}
644+
645+
let provider =
646+
if let Ok(capsule) = provider.downcast::<PyCapsule>().map_err(py_datafusion_err) {
647+
validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;
648+
649+
let provider = unsafe { capsule.reference::<FFI_CatalogProviderList>() };
650+
let provider: Arc<dyn CatalogProviderList + Send> = provider.into();
651+
provider as Arc<dyn CatalogProviderList>
652+
} else {
653+
match provider.extract::<PyCatalogList>() {
654+
Ok(py_catalog_list) => py_catalog_list.catalog_list,
655+
Err(_) => Arc::new(RustWrappedPyCatalogProviderList::new(
656+
provider.into(),
657+
Arc::clone(&self.logical_codec),
658+
)) as Arc<dyn CatalogProviderList>,
659+
}
660+
};
661+
662+
self.ctx.register_catalog_list(provider);
663+
664+
Ok(())
665+
}
666+
630667
pub fn register_catalog_provider(
631668
&self,
632669
name: &str,

0 commit comments

Comments
 (0)