@@ -33,6 +33,7 @@ use pyo3::prelude::*;
3333
3434use crate :: catalog:: { PyCatalog , PyTable , RustWrappedPyCatalogProvider } ;
3535use crate :: dataframe:: PyDataFrame ;
36+ use crate :: dataframe:: PyTableProvider ;
3637use crate :: dataset:: Dataset ;
3738use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionResult } ;
3839use crate :: expr:: sort_expr:: PySortExpr ;
@@ -417,12 +418,7 @@ impl PySessionContext {
417418 . with_listing_options ( options)
418419 . with_schema ( resolved_schema) ;
419420 let table = ListingTable :: try_new ( config) ?;
420- self . register_table (
421- name,
422- & PyTable {
423- table : Arc :: new ( table) ,
424- } ,
425- ) ?;
421+ self . ctx . register_table ( name, Arc :: new ( table) ) ?;
426422 Ok ( ( ) )
427423 }
428424
@@ -607,8 +603,32 @@ impl PySessionContext {
607603 Ok ( df)
608604 }
609605
610- pub fn register_table ( & mut self , name : & str , table : & PyTable ) -> PyDataFusionResult < ( ) > {
611- self . ctx . register_table ( name, table. table ( ) ) ?;
606+ pub fn register_table (
607+ & mut self ,
608+ name : & str ,
609+ table_provider : Bound < ' _ , PyAny > ,
610+ ) -> PyDataFusionResult < ( ) > {
611+ let provider = if table_provider. hasattr ( "__datafusion_table_provider__" ) ? {
612+ let capsule = table_provider
613+ . getattr ( "__datafusion_table_provider__" ) ?
614+ . call0 ( ) ?;
615+ let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
616+ validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
617+
618+ let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
619+ let provider: ForeignTableProvider = provider. into ( ) ;
620+ Arc :: new ( provider) as Arc < dyn TableProvider >
621+ } else if let Ok ( py_table) = table_provider. extract :: < PyTable > ( ) {
622+ py_table. table ( )
623+ } else if let Ok ( py_provider) = table_provider. extract :: < PyTableProvider > ( ) {
624+ py_provider. as_table ( ) . table ( )
625+ } else {
626+ return Err ( crate :: errors:: PyDataFusionError :: Common (
627+ "Expected a Table or TableProvider." . to_string ( ) ,
628+ ) ) ;
629+ } ;
630+
631+ self . ctx . register_table ( name, provider) ?;
612632 Ok ( ( ) )
613633 }
614634
@@ -651,23 +671,8 @@ impl PySessionContext {
651671 name : & str ,
652672 provider : Bound < ' _ , PyAny > ,
653673 ) -> PyDataFusionResult < ( ) > {
654- if provider. hasattr ( "__datafusion_table_provider__" ) ? {
655- let capsule = provider. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
656- let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
657- validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
658-
659- let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
660- let provider: ForeignTableProvider = provider. into ( ) ;
661-
662- let _ = self . ctx . register_table ( name, Arc :: new ( provider) ) ?;
663-
664- Ok ( ( ) )
665- } else {
666- Err ( crate :: errors:: PyDataFusionError :: Common (
667- "__datafusion_table_provider__ does not exist on Table Provider object."
668- . to_string ( ) ,
669- ) )
670- }
674+ // Deprecated: use `register_table` instead
675+ self . register_table ( name, provider)
671676 }
672677
673678 pub fn register_record_batches (
0 commit comments