diff --git a/crates/core/src/udtf.rs b/crates/core/src/udtf.rs index b3de25e52..3a244a417 100644 --- a/crates/core/src/udtf.rs +++ b/crates/core/src/udtf.rs @@ -18,20 +18,33 @@ use std::ptr::NonNull; use std::sync::Arc; -use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider}; -use datafusion::error::Result as DataFusionResult; +use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl, TableProvider}; +use datafusion::error::{DataFusionError, Result as DataFusionResult}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::session_state::SessionState; use datafusion::logical_expr::Expr; use datafusion_ffi::udtf::FFI_TableFunction; use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyImportError, PyTypeError}; use pyo3::prelude::*; -use pyo3::types::{PyCapsule, PyTuple, PyType}; +use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType}; use crate::context::PySessionContext; use crate::errors::{py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; use crate::table::PyTable; +/// A pure-Python UDTF callable plus the metadata we discovered about it +/// at registration time. +#[derive(Debug, Clone)] +pub(crate) struct PythonTableFunctionCallable { + pub(crate) callable: Arc>, + /// Whether the callable's signature accepts a ``session`` keyword + /// argument (or ``**kwargs``). When true the calling + /// :class:`SessionContext` is threaded through on each invocation. + pub(crate) accepts_session: bool, +} + /// Represents a user defined table function #[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")] #[derive(Debug, Clone)] @@ -40,21 +53,21 @@ pub struct PyTableFunction { pub(crate) inner: PyTableFunctionInner, } -// TODO: Implement pure python based user defined table functions #[derive(Debug, Clone)] pub(crate) enum PyTableFunctionInner { - PythonFunction(Arc>), + PythonFunction(PythonTableFunctionCallable), FFIFunction(Arc), } #[pymethods] impl PyTableFunction { #[new] - #[pyo3(signature=(name, func, session))] + #[pyo3(signature=(name, func, session, accepts_session=false))] pub fn new( name: &str, func: Bound<'_, PyAny>, session: Option>, + accepts_session: bool, ) -> PyResult { let inner = if func.hasattr("__datafusion_table_function__")? { let py = func.py(); @@ -80,8 +93,10 @@ impl PyTableFunction { PyTableFunctionInner::FFIFunction(foreign_func) } else { - let py_obj = Arc::new(func.unbind()); - PyTableFunctionInner::PythonFunction(py_obj) + PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable { + callable: Arc::new(func.unbind()), + accepts_session, + }) }; Ok(Self { @@ -107,20 +122,59 @@ impl PyTableFunction { } } +/// Materialize a fresh :class:`PySessionContext` from the borrowed +/// ``&dyn Session`` handed in at call time. +/// +/// Upstream invokes ``call_with_args`` with a trait-object reference +/// rather than an owned context; we downcast it to the canonical +/// :class:`SessionState` impl and rebuild a :class:`SessionContext` +/// (sharing the same registries via the Arc-heavy interior of +/// :class:`SessionState`). Returns an error if the trait object is a +/// non-:class:`SessionState` implementation (e.g. a foreign FFI +/// session) — those are not exposed to Python today. +fn py_session_from_session(session: &dyn Session) -> DataFusionResult { + let state = session + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Execution( + "Cannot expose this UDTF's calling session to Python: \ + the session is not a SessionState. Drop the `session` \ + keyword from the callback signature to fall back to the \ + expression-only call form." + .to_string(), + ) + })?; + Ok(PySessionContext::from(SessionContext::new_with_state( + state.clone(), + ))) +} + #[allow(clippy::result_large_err)] fn call_python_table_function( - func: &Arc>, - args: &[Expr], + func: &PythonTableFunctionCallable, + args: TableFunctionArgs, ) -> DataFusionResult> { - let args = args + let py_session = if func.accepts_session { + Some(py_session_from_session(args.session())?) + } else { + None + }; + let py_exprs = args + .exprs() .iter() .map(|arg| PyExpr::from(arg.clone())) .collect::>(); - // move |args: &[ArrayRef]| -> Result { Python::attach(|py| { - let py_args = PyTuple::new(py, args)?; - let provider_obj = func.call1(py, py_args)?; + let py_args = PyTuple::new(py, py_exprs)?; + let provider_obj = if let Some(session) = py_session { + let kwargs = PyDict::new(py); + kwargs.set_item("session", session.into_pyobject(py)?)?; + func.callable.call(py, py_args, Some(&kwargs))? + } else { + func.callable.call1(py, py_args)? + }; let provider = provider_obj.bind(py).clone(); Ok::, PyErr>(PyTable::new(provider, None)?.table) @@ -132,8 +186,8 @@ impl TableFunctionImpl for PyTableFunction { fn call_with_args(&self, args: TableFunctionArgs) -> DataFusionResult> { match &self.inner { PyTableFunctionInner::FFIFunction(func) => func.call_with_args(args), - PyTableFunctionInner::PythonFunction(obj) => { - call_python_table_function(obj, args.exprs()) + PyTableFunctionInner::PythonFunction(callable) => { + call_python_table_function(callable, args) } } } diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 3eb50a094..c524ac4e1 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -1054,6 +1054,47 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF: ) +def _callable_accepts_session_kwarg(func: object) -> bool: + """Return True if ``func`` accepts a ``session`` keyword argument. + + Used to opt a Python UDTF callback into receiving the calling + :class:`SessionContext` at invocation time. ``**kwargs`` callables + are treated as accepting it; built-ins and objects without an + introspectable signature fall back to ``False``. + """ + import inspect # noqa: PLC0415 + + try: + signature = inspect.signature(func) + except (TypeError, ValueError): + return False + + for parameter in signature.parameters.values(): + if parameter.name == "session": + return True + if parameter.kind is inspect.Parameter.VAR_KEYWORD: + return True + return False + + +def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]: + """Adapt the raw internal session pyo3 object back to a Python wrapper. + + The Rust call site forwards a ``datafusion._internal.SessionContext``, + but UDTF authors expect to interact with the public + :class:`datafusion.SessionContext` wrapper. This closure wraps the + internal object once per call before delegating to ``func``. + """ + + @functools.wraps(func, updated=()) + def adapter(*args: Any, session: Any, **kwargs: Any) -> Any: + wrapped = SessionContext.__new__(SessionContext) + wrapped.ctx = session + return func(*args, session=wrapped, **kwargs) + + return adapter + + class TableFunction: """Class for performing user-defined table functions (UDTF). @@ -1066,10 +1107,19 @@ def __init__( ) -> None: """Instantiate a user-defined table function (UDTF). + If ``func``'s signature accepts a ``session`` keyword (or + ``**kwargs``), the calling :class:`SessionContext` is threaded + through to it on each invocation. Use it inside the body to look + up registered tables, UDFs, or session configuration. Callables + whose signatures do not declare ``session`` are invoked with the + positional expression arguments only. + See :py:func:`udtf` for a convenience function and argument descriptions. """ - self._udtf = df_internal.TableFunction(name, func, ctx) + accepts_session = _callable_accepts_session_kwarg(func) + registered = _wrap_session_kwarg_for_udtf(func) if accepts_session else func + self._udtf = df_internal.TableFunction(name, registered, ctx, accepts_session) def __call__(self, *args: Expr) -> Any: """Execute the UDTF and return a table provider.""" diff --git a/python/tests/test_udtf.py b/python/tests/test_udtf.py index 925a8ba01..7a1b128bf 100644 --- a/python/tests/test_udtf.py +++ b/python/tests/test_udtf.py @@ -134,3 +134,68 @@ def string_arg_func(prefix: Expr) -> TableProviderExportable: result = ctx.sql("SELECT * FROM string_arg_func('test')").collect() assert len(result) == 1 assert result[0].schema.names == ["test_a", "test_b"] + + +def test_python_table_function_receives_session() -> None: + """A UDTF whose signature declares ``session`` gets the calling ctx.""" + ctx = SessionContext() + captured: list[SessionContext] = [] + + @udtf("session_aware_func") + def session_aware_func(*, session: SessionContext) -> TableProviderExportable: + captured.append(session) + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(session_aware_func) + result = ctx.sql("SELECT * FROM session_aware_func()").collect() + + assert len(captured) == 1 + assert isinstance(captured[0], SessionContext) + # Sharing the same catalog confirms the wrapper points at the caller's state. + assert captured[0].catalog().schema().names() == ctx.catalog().schema().names() + assert result[0].column(0).to_pylist() == [1, 2, 3] + + +def test_python_table_function_session_used_for_metadata() -> None: + """The UDTF can inspect session state through the passed-in context.""" + ctx = SessionContext() + base_batch = pa.RecordBatch.from_pydict({"x": [10, 20, 30]}) + ctx.register_batch("base_tbl", base_batch) + + seen_tables: list[set[str]] = [] + + @udtf("table_inventory") + def table_inventory(*, session: SessionContext) -> TableProviderExportable: + # Stash the visible tables to verify the session wired through. + seen_tables.append(session.catalog().schema().names()) + batch = pa.RecordBatch.from_pydict({"name": ["base_tbl"]}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(table_inventory) + result = ctx.sql("SELECT * FROM table_inventory()").collect() + + assert seen_tables == [{"base_tbl"}] + assert result[0].column(0).to_pylist() == ["base_tbl"] + + +def test_python_table_function_class_callable_session_kwarg() -> None: + """Class-based UDTFs whose __call__ accepts ``session`` get it too.""" + ctx = SessionContext() + captured: list[SessionContext] = [] + + class SessionAware: + def __call__( + self, n: Expr, *, session: SessionContext + ) -> TableProviderExportable: + captured.append(session) + count = n.to_variant().value_i64() + batch = pa.RecordBatch.from_pydict({"a": list(range(count))}) + return Table(ds.dataset([batch])) + + ctx.register_udtf(udtf(SessionAware(), "session_class_func")) + result = ctx.sql("SELECT * FROM session_class_func(3)").collect() + + assert len(captured) == 1 + assert isinstance(captured[0], SessionContext) + assert result[0].column(0).to_pylist() == [0, 1, 2]