Skip to content

Commit f8c42af

Browse files
committed
Initial commit of udtf work
1 parent 7d8bcd8 commit f8c42af

File tree

6 files changed

+178
-2
lines changed

6 files changed

+178
-2
lines changed

python/datafusion/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,16 @@
5151
from .io import read_avro, read_csv, read_json, read_parquet
5252
from .plan import ExecutionPlan, LogicalPlan
5353
from .record_batch import RecordBatch, RecordBatchStream
54-
from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF, udaf, udf, udwf
54+
from .udf import (
55+
Accumulator,
56+
AggregateUDF,
57+
ScalarUDF,
58+
TableFunction,
59+
WindowUDF,
60+
udaf,
61+
udf,
62+
udwf,
63+
)
5564

5665
__version__ = importlib_metadata.version(__name__)
5766

@@ -74,6 +83,7 @@
7483
"SessionConfig",
7584
"SessionContext",
7685
"Table",
86+
"TableFunction",
7787
"WindowFrame",
7888
"WindowUDF",
7989
"col",

python/datafusion/context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from datafusion.dataframe import DataFrame
3131
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
3232
from datafusion.record_batch import RecordBatchStream
33-
from datafusion.udf import AggregateUDF, ScalarUDF, WindowUDF
33+
from datafusion.udf import AggregateUDF, ScalarUDF, TableFunction, WindowUDF
3434

3535
from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
3636
from ._internal import SessionConfig as SessionConfigInternal
@@ -752,6 +752,10 @@ def register_table_provider(
752752
"""
753753
self.ctx.register_table_provider(name, provider)
754754

755+
def register_udtf(self, name: str, func: TableFunction) -> None:
    """Register a user defined table function (UDTF) with this context.

    Args:
        name: Kept for API compatibility. It is not forwarded, because the
            native ``register_udtf`` accepts only the function object and
            registers it under the name the function was constructed with.
        func: The table function to register.
    """
    # The native method takes a single argument; passing ``name`` as well
    # raises ``TypeError`` before anything is registered.
    self.ctx.register_udtf(func._udtf)
758+
755759
def register_record_batches(
756760
self, name: str, partitions: list[list[pa.RecordBatch]]
757761
) -> None:

python/datafusion/udf.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,8 +760,72 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
760760
return decorator
761761

762762

763+
class TableFunction:
    """Class for performing user-defined table functions (UDTF).

    Table functions generate new table providers based on the
    input expressions.
    """

    def __init__(
        self,
        name: str,
        func: Callable[[], Any],
    ) -> None:
        """Instantiate a user-defined table function (UDTF).

        See :py:func:`udtf` for a convenience function and argument
        descriptions.

        Args:
            name: Name of the table function.
            func: Object providing the table function implementation.
        """
        # NOTE(review): the native class is added to the root of the internal
        # module, so it is looked up there - confirm ``df_internal`` is that
        # module.
        self._udtf = df_internal.TableFunction(name, func)

    def __call__(self, *args: Expr) -> Any:
        """Execute the UDTF and return a table provider.

        Args:
            args: Expressions passed to the table function.

        Returns:
            The table provider produced by the native table function.
        """
        args_raw = [arg.expr for arg in args]
        # The native call already yields a table provider; it is returned
        # as-is instead of being wrapped in an ``Expr``.
        return self._udtf.__call__(*args_raw)

    @overload
    @staticmethod
    def udtf(
        name: str,
    ) -> Callable[..., Any]: ...

    @overload
    @staticmethod
    def udtf(
        func: Callable[[], Any],
        name: str,
    ) -> TableFunction: ...

    @staticmethod
    def udtf(*args: Any, **kwargs: Any):
        """Create a new User-Defined Table Function (UDTF).

        Usable either as a plain function, ``udtf(func, name)``, or as a
        decorator with parameters, ``@udtf(name)``.

        Returns:
            A :py:class:`TableFunction` in the function form, or a decorator
            producing one in the decorator form.
        """
        if args and callable(args[0]):
            # Case 1: Used as a function, require the first parameter to be callable
            return TableFunction._create_table_udf(*args, **kwargs)
        # Case 2: Used as a decorator with parameters
        return TableFunction._create_table_udf_decorator(*args, **kwargs)

    @staticmethod
    def _create_table_udf(
        func: Callable[..., Any],
        name: str,
    ) -> TableFunction:
        """Create a TableFunction instance from function arguments.

        Raises:
            TypeError: If ``func`` is not callable.
        """
        if not callable(func):
            msg = "`func` must be callable."
            raise TypeError(msg)

        return TableFunction(name, func)

    @staticmethod
    def _create_table_udf_decorator(
        name: str,
    ) -> Callable[[Callable[..., Any]], TableFunction]:
        """Create a decorator that turns a callable into a TableFunction."""

        def decorator(func: Callable[..., Any]) -> TableFunction:
            return TableFunction._create_table_udf(func, name)

        return decorator

    def __repr__(self) -> str:
        """User printable representation."""
        return self._udtf.__repr__()
824+
825+
763826
# Convenience exports so we can import instead of treating as
764827
# variables at the package root
765828
udf = ScalarUDF.udf
766829
udaf = AggregateUDF.udaf
767830
udwf = WindowUDF.udwf
831+
udtf = TableFunction.udtf

src/context.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use crate::sql::logical::PyLogicalPlan;
4343
use crate::store::StorageContexts;
4444
use crate::udaf::PyAggregateUDF;
4545
use crate::udf::PyScalarUDF;
46+
use crate::udtf::PyTableFunction;
4647
use crate::udwf::PyWindowUDF;
4748
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
4849
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
@@ -390,6 +391,12 @@ impl PySessionContext {
390391
Ok(())
391392
}
392393

394+
/// Registers a user defined table function with the wrapped context.
///
/// The function is registered under the name stored on `func` itself.
pub fn register_udtf(&mut self, func: PyTableFunction) {
    // Copy the name out first: `func` is moved into the `Arc` below.
    let udtf_name = func.name.clone();
    self.ctx.register_udtf(&udtf_name, Arc::new(func));
}
399+
393400
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
394401
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
395402
let result = self.ctx.sql(query);

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ pub mod substrait;
6060
mod udaf;
6161
#[allow(clippy::borrow_deref_ref)]
6262
mod udf;
63+
pub mod udtf;
6364
mod udwf;
6465
pub mod utils;
6566

@@ -88,6 +89,7 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
8889
m.add_class::<udf::PyScalarUDF>()?;
8990
m.add_class::<udaf::PyAggregateUDF>()?;
9091
m.add_class::<udwf::PyWindowUDF>()?;
92+
m.add_class::<udtf::PyTableFunction>()?;
9193
m.add_class::<config::PyConfig>()?;
9294
m.add_class::<sql::logical::PyLogicalPlan>()?;
9395
m.add_class::<physical_plan::PyExecutionPlan>()?;

src/udtf.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use pyo3::prelude::*;
21+
22+
use crate::dataframe::PyTableProvider;
23+
use crate::errors::py_datafusion_err;
24+
use crate::expr::PyExpr;
25+
use crate::utils::validate_pycapsule;
26+
use datafusion::catalog::{TableFunctionImpl, TableProvider};
27+
use datafusion::common::exec_err;
28+
use datafusion::logical_expr::Expr;
29+
use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
30+
use pyo3::types::PyCapsule;
31+
32+
/// Represents a user defined table function
33+
#[pyclass(name = "TableFunction", module = "datafusion")]
34+
#[derive(Debug, Clone)]
35+
pub struct PyTableFunction {
36+
pub(crate) name: String,
37+
pub(crate) inner: PyTableFunctionInner,
38+
}
39+
40+
// TODO: Implement pure python based user defined table functions
41+
#[derive(Debug, Clone)]
42+
pub(crate) enum PyTableFunctionInner {
43+
// PythonFunction(Arc<PyObject>),
44+
FFIFunction(Arc<dyn TableFunctionImpl>),
45+
}
46+
47+
#[pymethods]
48+
impl PyTableFunction {
49+
#[new]
50+
#[pyo3(signature=(name, func))]
51+
pub fn new(name: &str, func: Bound<'_, PyAny>) -> PyResult<Self> {
52+
if func.hasattr("__datafusion_table_function__")? {
53+
let capsule = func.getattr("__datafusion_table_function__")?.call0()?;
54+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
55+
validate_pycapsule(capsule, "datafusion_table_function")?;
56+
57+
let ffi_func = unsafe { capsule.reference::<FFI_TableFunction>() };
58+
let foreign_func: ForeignTableFunction = ffi_func.to_owned().into();
59+
60+
Ok(Self {
61+
name: name.to_string(),
62+
inner: PyTableFunctionInner::FFIFunction(Arc::new(foreign_func)),
63+
})
64+
} else {
65+
exec_err!("Python based Table Functions are not yet implemented")
66+
.map_err(py_datafusion_err)
67+
}
68+
}
69+
70+
#[pyo3(signature = (*args))]
71+
pub fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyTableProvider> {
72+
let args: Vec<Expr> = args.iter().map(|e| e.expr.clone()).collect();
73+
let table_provider = self.call(&args).map_err(py_datafusion_err)?;
74+
75+
Ok(PyTableProvider::new(table_provider))
76+
}
77+
78+
fn __repr__(&self) -> PyResult<String> {
79+
Ok(format!("TableUDF({})", self.name))
80+
}
81+
}
82+
83+
impl TableFunctionImpl for PyTableFunction {
84+
fn call(&self, args: &[Expr]) -> datafusion::common::Result<Arc<dyn TableProvider>> {
85+
match &self.inner {
86+
PyTableFunctionInner::FFIFunction(func) => func.call(args),
87+
}
88+
}
89+
}

0 commit comments

Comments
 (0)