Skip to content

Commit a771702

Browse files
committed
feat: enhance table registration to support TableProvider in catalog and context
1 parent 8c413f5 commit a771702

File tree

5 files changed

+70
-44
lines changed

5 files changed

+70
-44
lines changed

python/datafusion/catalog.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
if TYPE_CHECKING:
2828
import pyarrow as pa
2929

30+
from datafusion import TableProvider
31+
3032
try:
3133
from warnings import deprecated # Python 3.13+
3234
except ImportError:
@@ -122,8 +124,8 @@ def table(self, name: str) -> Table:
122124
"""Return the table with the given ``name`` from this schema."""
123125
return Table(self._raw_schema.table(name))
124126

125-
def register_table(self, name, table) -> None:
126-
"""Register a table provider in this schema."""
127+
def register_table(self, name, table: Table | TableProvider) -> None:
128+
"""Register a table or table provider in this schema."""
127129
if isinstance(table, Table):
128130
return self._raw_schema.register_table(name, table.table)
129131
return self._raw_schema.register_table(name, table)
@@ -219,8 +221,8 @@ def table(self, name: str) -> Table | None:
219221
"""Retrieve a specific table from this schema."""
220222
...
221223

222-
def register_table(self, name: str, table: Table) -> None: # noqa: B027
223-
"""Add a table from this schema.
224+
def register_table(self, name: str, table: Table | TableProvider) -> None: # noqa: B027
225+
"""Add a table to this schema.
224226
225227
This method is optional. If your schema provides a fixed list of tables, you do
226228
not need to implement this method.

python/datafusion/context.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import pandas as pd
4747
import polars as pl
4848

49+
from datafusion import TableProvider
4950
from datafusion.plan import ExecutionPlan, LogicalPlan
5051

5152

@@ -743,16 +744,21 @@ def register_view(self, name: str, df: DataFrame) -> None:
743744
view = df.into_view()
744745
self.ctx.register_table(name, view)
745746

746-
def register_table(self, name: str, table: Table) -> None:
747-
"""Register a :py:class: `~datafusion.catalog.Table` as a table.
747+
def register_table(self, name: str, table: Table | TableProvider) -> None:
748+
"""Register a :py:class:`~datafusion.catalog.Table` or ``TableProvider``.
748749
749-
The registered table can be referenced from SQL statement executed against.
750+
The registered table can be referenced from SQL statements executed against
751+
this context.
750752
751753
Args:
752754
name: Name of the resultant table.
753-
table: DataFusion table to add to the session context.
755+
table: DataFusion :class:`Table` or :class:`TableProvider` to add to the
756+
session context.
754757
"""
755-
self.ctx.register_table(name, table.table)
758+
if isinstance(table, Table):
759+
self.ctx.register_table(name, table.table)
760+
else:
761+
self.ctx.register_table(name, table)
756762

757763
def deregister_table(self, name: str) -> None:
758764
"""Remove a table from the session."""
@@ -772,14 +778,18 @@ def register_catalog_provider(
772778
self.ctx.register_catalog_provider(name, provider)
773779

774780
def register_table_provider(
775-
self, name: str, provider: TableProviderExportable
781+
self, name: str, provider: TableProviderExportable | TableProvider
776782
) -> None:
777783
"""Register a table provider.
778784
779-
This table provider must have a method called ``__datafusion_table_provider__``
780-
which returns a PyCapsule that exposes a ``FFI_TableProvider``.
785+
Deprecated: use :meth:`register_table` instead.
781786
"""
782-
self.ctx.register_table_provider(name, provider)
787+
warnings.warn(
788+
"register_table_provider is deprecated; use register_table",
789+
DeprecationWarning,
790+
stacklevel=2,
791+
)
792+
self.register_table(name, provider)
783793

784794
def register_udtf(self, func: TableFunction) -> None:
785795
"""Register a user defined table function."""

src/catalog.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::dataframe::PyTableProvider;
1819
use crate::dataset::Dataset;
1920
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2021
use crate::utils::{validate_pycapsule, wait_for_future};
@@ -209,11 +210,14 @@ impl PySchema {
209210
} else {
210211
match table_provider.extract::<PyTable>() {
211212
Ok(py_table) => py_table.table,
212-
Err(_) => {
213-
let py = table_provider.py();
214-
let provider = Dataset::new(&table_provider, py)?;
215-
Arc::new(provider) as Arc<dyn TableProvider>
216-
}
213+
Err(_) => match table_provider.extract::<PyTableProvider>() {
214+
Ok(py_provider) => py_provider.as_table().table(),
215+
Err(_) => {
216+
let py = table_provider.py();
217+
let provider = Dataset::new(&table_provider, py)?;
218+
Arc::new(provider) as Arc<dyn TableProvider>
219+
}
220+
},
217221
}
218222
};
219223

@@ -305,7 +309,7 @@ impl RustWrappedPySchemaProvider {
305309
}
306310

307311
if py_table.hasattr("__datafusion_table_provider__")? {
308-
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
312+
let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?;
309313
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
310314
validate_pycapsule(capsule, "datafusion_table_provider")?;
311315

@@ -320,6 +324,10 @@ impl RustWrappedPySchemaProvider {
320324
}
321325
}
322326

327+
if let Ok(py_provider) = py_table.extract::<PyTableProvider>() {
328+
return Ok(Some(py_provider.as_table().table()));
329+
}
330+
323331
match py_table.extract::<PyTable>() {
324332
Ok(py_table) => Ok(Some(py_table.table)),
325333
Err(_) => {

src/context.rs

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use pyo3::prelude::*;
3333

3434
use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
3535
use crate::dataframe::PyDataFrame;
36+
use crate::dataframe::PyTableProvider;
3637
use crate::dataset::Dataset;
3738
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
3839
use crate::expr::sort_expr::PySortExpr;
@@ -417,12 +418,7 @@ impl PySessionContext {
417418
.with_listing_options(options)
418419
.with_schema(resolved_schema);
419420
let table = ListingTable::try_new(config)?;
420-
self.register_table(
421-
name,
422-
&PyTable {
423-
table: Arc::new(table),
424-
},
425-
)?;
421+
self.ctx.register_table(name, Arc::new(table))?;
426422
Ok(())
427423
}
428424

@@ -607,8 +603,32 @@ impl PySessionContext {
607603
Ok(df)
608604
}
609605

610-
pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> {
611-
self.ctx.register_table(name, table.table())?;
606+
pub fn register_table(
607+
&mut self,
608+
name: &str,
609+
table_provider: Bound<'_, PyAny>,
610+
) -> PyDataFusionResult<()> {
611+
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
612+
let capsule = table_provider
613+
.getattr("__datafusion_table_provider__")?
614+
.call0()?;
615+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
616+
validate_pycapsule(capsule, "datafusion_table_provider")?;
617+
618+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
619+
let provider: ForeignTableProvider = provider.into();
620+
Arc::new(provider) as Arc<dyn TableProvider>
621+
} else if let Ok(py_table) = table_provider.extract::<PyTable>() {
622+
py_table.table()
623+
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
624+
py_provider.as_table().table()
625+
} else {
626+
return Err(crate::errors::PyDataFusionError::Common(
627+
"Expected a Table or TableProvider.".to_string(),
628+
));
629+
};
630+
631+
self.ctx.register_table(name, provider)?;
612632
Ok(())
613633
}
614634

@@ -651,23 +671,8 @@ impl PySessionContext {
651671
name: &str,
652672
provider: Bound<'_, PyAny>,
653673
) -> PyDataFusionResult<()> {
654-
if provider.hasattr("__datafusion_table_provider__")? {
655-
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
656-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
657-
validate_pycapsule(capsule, "datafusion_table_provider")?;
658-
659-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
660-
let provider: ForeignTableProvider = provider.into();
661-
662-
let _ = self.ctx.register_table(name, Arc::new(provider))?;
663-
664-
Ok(())
665-
} else {
666-
Err(crate::errors::PyDataFusionError::Common(
667-
"__datafusion_table_provider__ does not exist on Table Provider object."
668-
.to_string(),
669-
))
670-
}
674+
// Deprecated: use `register_table` instead
675+
self.register_table(name, provider)
671676
}
672677

673678
pub fn register_record_batches(

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
8989
m.add_class::<dataframe::PyDataFrame>()?;
9090
m.add_class::<dataframe::PyParquetColumnOptions>()?;
9191
m.add_class::<dataframe::PyParquetWriterOptions>()?;
92+
m.add_class::<dataframe::PyTableProvider>()?;
9293
m.add_class::<udf::PyScalarUDF>()?;
9394
m.add_class::<udaf::PyAggregateUDF>()?;
9495
m.add_class::<udwf::PyWindowUDF>()?;

0 commit comments

Comments
 (0)