Skip to content

Commit 0d7bb36

Browse files
committed
Pass codec capsule to table providers
1 parent 9a7244e commit 0d7bb36

File tree

11 files changed

+197
-135
lines changed

11 files changed

+197
-135
lines changed

examples/datafusion-ffi-example/python/tests/_test_table_function.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
def test_ffi_table_function_register():
3131
ctx = SessionContext()
32-
table_func = MyTableFunction(ctx)
32+
table_func = MyTableFunction()
3333
table_udtf = udtf(table_func, "my_table_func")
3434
ctx.register_udtf(table_udtf)
3535
result = ctx.sql("select * from my_table_func()").collect()
@@ -49,7 +49,7 @@ def test_ffi_table_function_register():
4949

5050
def test_ffi_table_function_call_directly():
5151
ctx = SessionContext()
52-
table_func = MyTableFunction(ctx)
52+
table_func = MyTableFunction()
5353
table_udtf = udtf(table_func, "my_table_func")
5454

5555
my_table = table_udtf()
@@ -77,9 +77,6 @@ class PythonTableFunction:
7777
provider, and this function takes no arguments
7878
"""
7979

80-
def __init__(self, ctx: SessionContext) -> None:
81-
self._ctx = ctx
82-
8380
def __call__(
8481
self, num_cols: Expr, num_rows: Expr, num_batches: Expr
8582
) -> TableProviderExportable:
@@ -88,7 +85,7 @@ def __call__(
8885
num_rows.to_variant().value_i64(),
8986
num_batches.to_variant().value_i64(),
9087
]
91-
return MyTableProvider(self._ctx, *args)
88+
return MyTableProvider(*args)
9289

9390

9491
def common_table_function_test(test_ctx: SessionContext) -> None:
@@ -111,7 +108,7 @@ def common_table_function_test(test_ctx: SessionContext) -> None:
111108

112109
def test_python_table_function():
113110
ctx = SessionContext()
114-
table_func = PythonTableFunction(ctx)
111+
table_func = PythonTableFunction()
115112
table_udtf = udtf(table_func, "my_table_func")
116113
ctx.register_udtf(table_udtf)
117114

@@ -130,7 +127,7 @@ def my_udtf(
130127
num_rows.to_variant().value_i64(),
131128
num_batches.to_variant().value_i64(),
132129
]
133-
return MyTableProvider(ctx, *args)
130+
return MyTableProvider(*args)
134131

135132
ctx.register_udtf(my_udtf)
136133

examples/datafusion-ffi-example/python/tests/_test_table_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
def test_table_loading():
2626
ctx = SessionContext()
27-
table = MyTableProvider(ctx, 3, 2, 4)
27+
table = MyTableProvider(3, 2, 4)
2828
ctx.register_table("t", table)
2929
result = ctx.table("t").collect()
3030

examples/datafusion-ffi-example/src/table_function.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ use std::sync::Arc;
2020
use datafusion_catalog::{TableFunctionImpl, TableProvider};
2121
use datafusion_common::error::Result as DataFusionResult;
2222
use datafusion_expr::Expr;
23-
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
2423
use datafusion_ffi::udtf::FFI_TableFunction;
2524
use pyo3::types::PyCapsule;
2625
use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python};
@@ -30,37 +29,33 @@ use crate::utils::ffi_logical_codec_from_pycapsule;
3029

3130
#[pyclass(name = "MyTableFunction", module = "datafusion_ffi_example", subclass)]
3231
#[derive(Debug, Clone)]
33-
pub(crate) struct MyTableFunction {
34-
logical_codec: FFI_LogicalExtensionCodec,
35-
}
32+
pub(crate) struct MyTableFunction {}
3633

3734
#[pymethods]
3835
impl MyTableFunction {
3936
#[new]
40-
fn new(session: &Bound<PyAny>) -> PyResult<Self> {
41-
let logical_codec = ffi_logical_codec_from_pycapsule(session)?;
42-
43-
Ok(Self { logical_codec })
37+
fn new() -> Self {
38+
Self {}
4439
}
4540

4641
fn __datafusion_table_function__<'py>(
4742
&self,
4843
py: Python<'py>,
44+
session: &Bound<PyAny>,
4945
) -> PyResult<Bound<'py, PyCapsule>> {
5046
let name = cr"datafusion_table_function".into();
5147

5248
let func = self.clone();
53-
let provider =
54-
FFI_TableFunction::new_with_ffi_codec(Arc::new(func), None, self.logical_codec.clone());
49+
let codec = ffi_logical_codec_from_pycapsule(session)?;
50+
let provider = FFI_TableFunction::new_with_ffi_codec(Arc::new(func), None, codec);
5551

5652
PyCapsule::new(py, provider, Some(name))
5753
}
5854
}
5955

6056
impl TableFunctionImpl for MyTableFunction {
6157
fn call(&self, _args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
62-
let provider = MyTableProvider::new_from_ffi_session(self.logical_codec.clone(), 4, 3, 2)
63-
.create_table()?;
58+
let provider = MyTableProvider::new(4, 3, 2).create_table()?;
6459
Ok(Arc::new(provider))
6560
}
6661
}

examples/datafusion-ffi-example/src/table_provider.rs

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ use arrow_array::{ArrayRef, RecordBatch};
2121
use arrow_schema::{DataType, Field, Schema};
2222
use datafusion_catalog::MemTable;
2323
use datafusion_common::error::{DataFusionError, Result as DataFusionResult};
24-
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
2524
use datafusion_ffi::table_provider::FFI_TableProvider;
2625
use pyo3::exceptions::PyRuntimeError;
2726
use pyo3::types::PyCapsule;
@@ -34,7 +33,6 @@ use crate::utils::ffi_logical_codec_from_pycapsule;
3433
#[pyclass(name = "MyTableProvider", module = "datafusion_ffi_example", subclass)]
3534
#[derive(Clone)]
3635
pub(crate) struct MyTableProvider {
37-
logical_codec: FFI_LogicalExtensionCodec,
3836
num_cols: usize,
3937
num_rows: usize,
4038
num_batches: usize,
@@ -81,55 +79,31 @@ impl MyTableProvider {
8179
}
8280
}
8381

84-
impl MyTableProvider {
85-
pub fn new_from_ffi_session(
86-
logical_codec: FFI_LogicalExtensionCodec,
87-
num_cols: usize,
88-
num_rows: usize,
89-
num_batches: usize,
90-
) -> Self {
91-
Self {
92-
logical_codec,
93-
num_cols,
94-
num_rows,
95-
num_batches,
96-
}
97-
}
98-
}
99-
10082
#[pymethods]
10183
impl MyTableProvider {
10284
#[new]
103-
pub fn new(
104-
session: &Bound<PyAny>,
105-
num_cols: usize,
106-
num_rows: usize,
107-
num_batches: usize,
108-
) -> PyResult<Self> {
109-
let logical_codec = ffi_logical_codec_from_pycapsule(session)?;
110-
Ok(Self {
111-
logical_codec,
85+
pub fn new(num_cols: usize, num_rows: usize, num_batches: usize) -> Self {
86+
Self {
11287
num_cols,
11388
num_rows,
11489
num_batches,
115-
})
90+
}
11691
}
11792

11893
pub fn __datafusion_table_provider__<'py>(
11994
&self,
12095
py: Python<'py>,
96+
session: &Bound<PyAny>,
12197
) -> PyResult<Bound<'py, PyCapsule>> {
12298
let name = cr"datafusion_table_provider".into();
12399

124100
let provider = self
125101
.create_table()
126102
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
127-
let provider = FFI_TableProvider::new_with_ffi_codec(
128-
Arc::new(provider),
129-
false,
130-
None,
131-
self.logical_codec.clone(),
132-
);
103+
104+
let codec = ffi_logical_codec_from_pycapsule(session)?;
105+
let provider =
106+
FFI_TableProvider::new_with_ffi_codec(Arc::new(provider), false, None, codec);
133107

134108
PyCapsule::new(py, provider, Some(name))
135109
}

python/datafusion/catalog.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
if TYPE_CHECKING:
2828
import pyarrow as pa
2929

30-
from datafusion import DataFrame
30+
from datafusion import DataFrame, SessionContext
3131
from datafusion.context import TableProviderExportable
3232

3333
try:
@@ -65,9 +65,9 @@ def schema_names(self) -> set[str]:
6565
return self.catalog.schema_names()
6666

6767
@staticmethod
68-
def memory_catalog() -> Catalog:
68+
def memory_catalog(ctx: SessionContext | None = None) -> Catalog:
6969
"""Create an in-memory catalog provider."""
70-
catalog = df_internal.catalog.RawCatalog.memory_catalog()
70+
catalog = df_internal.catalog.RawCatalog.memory_catalog(ctx)
7171
return Catalog(catalog)
7272

7373
def schema(self, name: str = "public") -> Schema:
@@ -112,9 +112,9 @@ def __repr__(self) -> str:
112112
return self._raw_schema.__repr__()
113113

114114
@staticmethod
115-
def memory_schema() -> Schema:
115+
def memory_schema(ctx: SessionContext | None = None) -> Schema:
116116
"""Create an in-memory schema provider."""
117-
schema = df_internal.catalog.RawSchema.memory_schema()
117+
schema = df_internal.catalog.RawSchema.memory_schema(ctx)
118118
return Schema(schema)
119119

120120
def names(self) -> set[str]:
@@ -163,10 +163,12 @@ class Table:
163163
__slots__ = ("_inner",)
164164

165165
def __init__(
166-
self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
166+
self,
167+
table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset,
168+
ctx: SessionContext | None = None,
167169
) -> None:
168170
"""Constructor."""
169-
self._inner = df_internal.catalog.RawTable(table)
171+
self._inner = df_internal.catalog.RawTable(table, ctx)
170172

171173
def __repr__(self) -> str:
172174
"""Print a string representation of the table."""

python/datafusion/user_defined.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import pyarrow as pa
2828

2929
import datafusion._internal as df_internal
30+
from datafusion import SessionContext
3031
from datafusion.expr import Expr
3132

3233
if TYPE_CHECKING:
@@ -923,16 +924,14 @@ class TableFunction:
923924
"""
924925

925926
def __init__(
926-
self,
927-
name: str,
928-
func: Callable[[], any],
927+
self, name: str, func: Callable[[], any], ctx: SessionContext | None = None
929928
) -> None:
930929
"""Instantiate a user-defined table function (UDTF).
931930
932931
See :py:func:`udtf` for a convenience function and argument
933932
descriptions.
934933
"""
935-
self._udtf = df_internal.TableFunction(name, func)
934+
self._udtf = df_internal.TableFunction(name, func, ctx)
936935

937936
def __call__(self, *args: Expr) -> Any:
938937
"""Execute the UDTF and return a table provider."""

0 commit comments

Comments
 (0)