@@ -21,10 +21,12 @@ use std::sync::Arc;
2121
2222use async_trait:: async_trait;
2323use datafusion:: catalog:: {
24- CatalogProvider , MemoryCatalogProvider , MemorySchemaProvider , SchemaProvider ,
24+ CatalogProvider , CatalogProviderList , MemoryCatalogProvider , MemoryCatalogProviderList ,
25+ MemorySchemaProvider , SchemaProvider ,
2526} ;
2627use datafusion:: common:: DataFusionError ;
2728use datafusion:: datasource:: TableProvider ;
29+ use datafusion_ffi:: catalog_provider:: FFI_CatalogProvider ;
2830use datafusion_ffi:: proto:: logical_extension_codec:: FFI_LogicalExtensionCodec ;
2931use datafusion_ffi:: schema_provider:: FFI_SchemaProvider ;
3032use pyo3:: exceptions:: PyKeyError ;
@@ -40,6 +42,18 @@ use crate::utils::{
4042 wait_for_future,
4143} ;
4244
45+ #[ pyclass(
46+ frozen,
47+ name = "RawCatalogList" ,
48+ module = "datafusion.catalog" ,
49+ subclass
50+ ) ]
51+ #[ derive( Clone ) ]
52+ pub struct PyCatalogList {
53+ pub catalog_list : Arc < dyn CatalogProviderList > ,
54+ codec : Arc < FFI_LogicalExtensionCodec > ,
55+ }
56+
4357#[ pyclass( frozen, name = "RawCatalog" , module = "datafusion.catalog" , subclass) ]
4458#[ derive( Clone ) ]
4559pub struct PyCatalog {
@@ -72,6 +86,77 @@ impl PySchema {
7286 }
7387}
7488
89+ #[ pymethods]
90+ impl PyCatalogList {
91+ #[ new]
92+ pub fn new (
93+ py : Python ,
94+ catalog_list : Py < PyAny > ,
95+ session : Option < Bound < PyAny > > ,
96+ ) -> PyResult < Self > {
97+ let codec = extract_logical_extension_codec ( py, session) ?;
98+ let catalog_list = Arc :: new ( RustWrappedPyCatalogProviderList :: new (
99+ catalog_list,
100+ codec. clone ( ) ,
101+ ) ) as Arc < dyn CatalogProviderList > ;
102+ Ok ( Self {
103+ catalog_list,
104+ codec,
105+ } )
106+ }
107+
108+ #[ staticmethod]
109+ pub fn memory_catalog_list ( py : Python , session : Option < Bound < PyAny > > ) -> PyResult < Self > {
110+ let codec = extract_logical_extension_codec ( py, session) ?;
111+ let catalog_list =
112+ Arc :: new ( MemoryCatalogProviderList :: default ( ) ) as Arc < dyn CatalogProviderList > ;
113+ Ok ( Self {
114+ catalog_list,
115+ codec,
116+ } )
117+ }
118+
119+ pub fn catalog_names ( & self ) -> HashSet < String > {
120+ self . catalog_list . catalog_names ( ) . into_iter ( ) . collect ( )
121+ }
122+
123+ #[ pyo3( signature = ( name="public" ) ) ]
124+ pub fn catalog ( & self , name : & str ) -> PyResult < Py < PyAny > > {
125+ let catalog = self
126+ . catalog_list
127+ . catalog ( name)
128+ . ok_or ( PyKeyError :: new_err ( format ! (
129+ "Schema with name {name} doesn't exist."
130+ ) ) ) ?;
131+
132+ Python :: attach ( |py| {
133+ match catalog
134+ . as_any ( )
135+ . downcast_ref :: < RustWrappedPyCatalogProvider > ( )
136+ {
137+ Some ( wrapped_catalog) => Ok ( wrapped_catalog. catalog_provider . clone_ref ( py) ) ,
138+ None => PyCatalog :: new_from_parts ( catalog, self . codec . clone ( ) ) . into_py_any ( py) ,
139+ }
140+ } )
141+ }
142+
143+ pub fn register_catalog ( & self , name : & str , catalog_provider : Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
144+ let provider = extract_catalog_provider_from_pyobj ( catalog_provider, self . codec . as_ref ( ) ) ?;
145+
146+ let _ = self
147+ . catalog_list
148+ . register_catalog ( name. to_owned ( ) , provider) ;
149+
150+ Ok ( ( ) )
151+ }
152+
153+ pub fn __repr__ ( & self ) -> PyResult < String > {
154+ let mut names: Vec < String > = self . catalog_names ( ) . into_iter ( ) . collect ( ) ;
155+ names. sort ( ) ;
156+ Ok ( format ! ( "CatalogList(catalog_names=[{}])" , names. join( ", " ) ) )
157+ }
158+ }
159+
75160#[ pymethods]
76161impl PyCatalog {
77162 #[ new]
@@ -442,6 +527,137 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
442527 }
443528}
444529
530+ #[ derive( Debug ) ]
531+ pub ( crate ) struct RustWrappedPyCatalogProviderList {
532+ pub ( crate ) catalog_provider_list : Py < PyAny > ,
533+ codec : Arc < FFI_LogicalExtensionCodec > ,
534+ }
535+
536+ impl RustWrappedPyCatalogProviderList {
537+ pub fn new ( catalog_provider_list : Py < PyAny > , codec : Arc < FFI_LogicalExtensionCodec > ) -> Self {
538+ Self {
539+ catalog_provider_list,
540+ codec,
541+ }
542+ }
543+
544+ fn catalog_inner ( & self , name : & str ) -> PyResult < Option < Arc < dyn CatalogProvider > > > {
545+ Python :: attach ( |py| {
546+ let provider = self . catalog_provider_list . bind ( py) ;
547+
548+ let py_schema = provider. call_method1 ( "catalog" , ( name, ) ) ?;
549+ if py_schema. is_none ( ) {
550+ return Ok ( None ) ;
551+ }
552+
553+ extract_catalog_provider_from_pyobj ( py_schema, self . codec . as_ref ( ) ) . map ( Some )
554+ } )
555+ }
556+ }
557+
558+ #[ async_trait]
559+ impl CatalogProviderList for RustWrappedPyCatalogProviderList {
560+ fn as_any ( & self ) -> & dyn Any {
561+ self
562+ }
563+
564+ fn catalog_names ( & self ) -> Vec < String > {
565+ Python :: attach ( |py| {
566+ let provider = self . catalog_provider_list . bind ( py) ;
567+ provider
568+ . getattr ( "catalog_names" )
569+ . and_then ( |names| names. extract :: < Vec < String > > ( ) )
570+ . unwrap_or_else ( |err| {
571+ log:: error!( "Unable to get catalog_names: {err}" ) ;
572+ Vec :: default ( )
573+ } )
574+ } )
575+ }
576+
577+ fn catalog ( & self , name : & str ) -> Option < Arc < dyn CatalogProvider > > {
578+ self . catalog_inner ( name) . unwrap_or_else ( |err| {
579+ log:: error!( "CatalogProvider catalog returned error: {err}" ) ;
580+ None
581+ } )
582+ }
583+
584+ fn register_catalog (
585+ & self ,
586+ name : String ,
587+ catalog : Arc < dyn CatalogProvider > ,
588+ ) -> Option < Arc < dyn CatalogProvider > > {
589+ Python :: attach ( |py| {
590+ let py_catalog = match catalog
591+ . as_any ( )
592+ . downcast_ref :: < RustWrappedPyCatalogProvider > ( )
593+ {
594+ Some ( wrapped_schema) => wrapped_schema. catalog_provider . as_any ( ) . clone_ref ( py) ,
595+ None => {
596+ match PyCatalog :: new_from_parts ( catalog, self . codec . clone ( ) ) . into_py_any ( py) {
597+ Ok ( c) => c,
598+ Err ( err) => {
599+ log:: error!(
600+ "register_catalog returned error during conversion to PyAny: {err}"
601+ ) ;
602+ return None ;
603+ }
604+ }
605+ }
606+ } ;
607+
608+ let provider = self . catalog_provider_list . bind ( py) ;
609+ let catalog = match provider. call_method1 ( "register_catalog" , ( name, py_catalog) ) {
610+ Ok ( c) => c,
611+ Err ( err) => {
612+ log:: error!( "register_catalog returned error: {err}" ) ;
613+ return None ;
614+ }
615+ } ;
616+ if catalog. is_none ( ) {
617+ return None ;
618+ }
619+
620+ let catalog = Arc :: new ( RustWrappedPyCatalogProvider :: new (
621+ catalog. into ( ) ,
622+ self . codec . clone ( ) ,
623+ ) ) as Arc < dyn CatalogProvider > ;
624+
625+ Some ( catalog)
626+ } )
627+ }
628+ }
629+
630+ fn extract_catalog_provider_from_pyobj (
631+ mut catalog_provider : Bound < PyAny > ,
632+ codec : & FFI_LogicalExtensionCodec ,
633+ ) -> PyResult < Arc < dyn CatalogProvider > > {
634+ if catalog_provider. hasattr ( "__datafusion_catalog_provider__" ) ? {
635+ let py = catalog_provider. py ( ) ;
636+ let codec_capsule = create_logical_extension_capsule ( py, codec) ?;
637+ catalog_provider = catalog_provider
638+ . getattr ( "__datafusion_catalog_provider__" ) ?
639+ . call1 ( ( codec_capsule, ) ) ?;
640+ }
641+
642+ let provider = if let Ok ( capsule) = catalog_provider. downcast :: < PyCapsule > ( ) {
643+ validate_pycapsule ( capsule, "datafusion_catalog_provider" ) ?;
644+
645+ let provider = unsafe { capsule. reference :: < FFI_CatalogProvider > ( ) } ;
646+ let provider: Arc < dyn CatalogProvider + Send > = provider. into ( ) ;
647+ provider as Arc < dyn CatalogProvider >
648+ } else {
649+ match catalog_provider. extract :: < PyCatalog > ( ) {
650+ Ok ( py_catalog) => py_catalog. catalog ,
651+ Err ( _) => Arc :: new ( RustWrappedPyCatalogProvider :: new (
652+ catalog_provider. into ( ) ,
653+ Arc :: new ( codec. clone ( ) ) ,
654+ ) ) as Arc < dyn CatalogProvider > ,
655+ }
656+ } ;
657+
658+ Ok ( provider)
659+ }
660+
445661fn extract_schema_provider_from_pyobj (
446662 mut schema_provider : Bound < PyAny > ,
447663 codec : & FFI_LogicalExtensionCodec ,
0 commit comments