Commit 57f147b

Revert "Implement metadata-aware PySimpleScalarUDF"
This reverts commit 24549bd.

3 files changed (+27, −185 lines)

python/datafusion/user_defined.py

Lines changed: 20 additions & 51 deletions
@@ -22,13 +22,17 @@
 import functools
 from abc import ABCMeta, abstractmethod
 from enum import Enum
-from typing import Any, Callable, Optional, Protocol, Sequence, overload
+from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload

 import pyarrow as pa

 import datafusion._internal as df_internal
 from datafusion.expr import Expr

+if TYPE_CHECKING:
+    _R = TypeVar("_R", bound=pa.DataType)
+
+
 class Volatility(Enum):
     """Defines how stable or volatile a function is.

@@ -73,40 +77,6 @@ def __str__(self) -> str:
         return self.name.lower()


-def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field:
-    if isinstance(value, pa.Field):
-        return value
-    if isinstance(value, pa.DataType):
-        return pa.field(default_name, value)
-    msg = "Expected a pyarrow.DataType or pyarrow.Field"
-    raise TypeError(msg)
-
-
-def _normalize_input_fields(
-    values: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
-) -> list[pa.Field]:
-    if isinstance(values, (pa.DataType, pa.Field)):
-        sequence: Sequence[pa.DataType | pa.Field] = [values]
-    elif isinstance(values, Sequence) and not isinstance(values, (str, bytes)):
-        sequence = values
-    else:
-        msg = "input_types must be a DataType, Field, or a sequence of them"
-        raise TypeError(msg)
-
-    return [
-        _normalize_field(value, default_name=f"arg_{idx}") for idx, value in enumerate(sequence)
-    ]
-
-
-def _normalize_return_field(
-    value: pa.DataType | pa.Field,
-    *,
-    name: str,
-) -> pa.Field:
-    default_name = f"{name}_result" if name else "result"
-    return _normalize_field(value, default_name=default_name)
-
-
 class ScalarUDFExportable(Protocol):
     """Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""

@@ -123,9 +93,9 @@ class ScalarUDF:
     def __init__(
         self,
         name: str,
-        func: Callable[..., Any],
-        input_types: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
-        return_type: pa.DataType | pa.Field,
+        func: Callable[..., _R],
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: _R,
         volatility: Volatility | str,
     ) -> None:
         """Instantiate a scalar user-defined function (UDF).
@@ -135,10 +105,10 @@ def __init__(
         if hasattr(func, "__datafusion_scalar_udf__"):
             self._udf = df_internal.ScalarUDF.from_pycapsule(func)
             return
-        normalized_inputs = _normalize_input_fields(input_types)
-        normalized_return = _normalize_return_field(return_type, name=name)
+        if isinstance(input_types, pa.DataType):
+            input_types = [input_types]
         self._udf = df_internal.ScalarUDF(
-            name, func, normalized_inputs, normalized_return, str(volatility)
+            name, func, input_types, return_type, str(volatility)
         )

     def __repr__(self) -> str:
@@ -157,18 +127,18 @@ def __call__(self, *args: Expr) -> Expr:
     @overload
     @staticmethod
     def udf(
-        input_types: list[pa.DataType | pa.Field],
-        return_type: pa.DataType | pa.Field,
+        input_types: list[pa.DataType],
+        return_type: _R,
         volatility: Volatility | str,
         name: Optional[str] = None,
     ) -> Callable[..., ScalarUDF]: ...

     @overload
     @staticmethod
     def udf(
-        func: Callable[..., Any],
-        input_types: list[pa.DataType | pa.Field],
-        return_type: pa.DataType | pa.Field,
+        func: Callable[..., _R],
+        input_types: list[pa.DataType],
+        return_type: _R,
         volatility: Volatility | str,
         name: Optional[str] = None,
     ) -> ScalarUDF: ...
@@ -194,11 +164,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
                 backed ScalarUDF within a PyCapsule, you can pass this parameter
                 and ignore the rest. They will be determined directly from the
                 underlying function. See the online documentation for more information.
-            input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
-                This list must be of the same length as the number of arguments. Pass
-                :class:`pyarrow.Field` instances to preserve extension metadata.
-            return_type (pa.DataType | pa.Field): The return type of the function. Use a
-                :class:`pyarrow.Field` to preserve metadata on extension arrays.
+            input_types (list[pa.DataType]): The data types of the arguments
+                to ``func``. This list must be of the same length as the number of
+                arguments.
+            return_type (_R): The data type of the return value from the function.
             volatility (Volatility | str): See `Volatility` for allowed values.
             name (Optional[str]): A descriptive name for the function.
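
For reference, a minimal sketch of declaring a scalar UDF against the restored DataType-based signature shown above. The `from datafusion import udf` import path and the example function are illustrative assumptions, not part of this commit:

    import pyarrow as pa
    import pyarrow.compute as pc

    from datafusion import udf  # assumed import path, for illustration

    def is_positive(values: pa.Array) -> pa.Array:
        # Called once per batch; receives and returns pyarrow arrays.
        return pc.greater(values, 0)

    # After this revert, input_types and return_type are plain pyarrow
    # DataType values again (pa.Field is no longer accepted here).
    is_positive_udf = udf(
        is_positive,
        [pa.int64()],  # one DataType per argument of `func`
        pa.bool_(),    # bare return DataType, bound to the _R TypeVar
        volatility="immutable",
        name="is_positive",
    )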

python/tests/test_udf.py

Lines changed: 0 additions & 55 deletions
@@ -124,58 +124,3 @@ def udf_with_param(values: pa.Array) -> pa.Array:
     result = df2.collect()[0].column(0)

     assert result == pa.array([False, True, True])
-
-
-def test_uuid_extension_chain(ctx) -> None:
-    uuid_type = pa.uuid()
-    uuid_field = pa.field("uuid_col", uuid_type)
-
-    first = udf(
-        lambda values: values,
-        [uuid_field],
-        uuid_field,
-        volatility="immutable",
-        name="uuid_identity",
-    )
-
-    def ensure_extension(values: pa.Array) -> pa.Array:
-        assert isinstance(values, pa.ExtensionArray)
-        return values
-
-    second = udf(
-        ensure_extension,
-        [uuid_field],
-        uuid_field,
-        volatility="immutable",
-        name="uuid_assert",
-    )
-
-    batch = pa.RecordBatch.from_arrays(
-        [
-            pa.array(
-                [
-                    "00000000-0000-0000-0000-000000000000",
-                    "00000000-0000-0000-0000-000000000001",
-                ],
-                type=uuid_type,
-            )
-        ],
-        names=["uuid_col"],
-    )
-
-    df = ctx.create_dataframe([[batch]])
-    result = (
-        df.select(second(first(column("uuid_col"))))
-        .collect()[0]
-        .column(0)
-    )
-
-    expected = pa.array(
-        [
-            "00000000-0000-0000-0000-000000000000",
-            "00000000-0000-0000-0000-000000000001",
-        ],
-        type=uuid_type,
-    )
-
-    assert result.equals(expected)
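
The deleted test exercised the Field-based declarations this commit removes. As a hedged sketch (not in the diff), the same identity UDF under the restored API would pass the extension DataType directly; whether extension metadata survives the round trip is exactly what the deleted assertions checked and is no longer covered:

    import pyarrow as pa
    from datafusion import udf  # assumed import path

    # pa.uuid() is the canonical UUID extension type; requires a pyarrow
    # release that ships it.
    uuid_type = pa.uuid()

    # DataType in, DataType out; no pa.field("uuid_col", ...) wrapper anymore.
    uuid_identity = udf(
        lambda values: values,
        [uuid_type],
        uuid_type,
        volatility="immutable",
        name="uuid_identity",
    )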

src/udf.rs

Lines changed: 7 additions & 79 deletions
@@ -22,16 +22,13 @@ use pyo3::types::PyCapsule;
 use pyo3::{prelude::*, types::PyTuple};

 use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
-use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::error::DataFusionError;
 use datafusion::logical_expr::function::ScalarFunctionImplementation;
-use datafusion::logical_expr::ptr_eq::PtrEq;
-use datafusion::logical_expr::{
-    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
-    Volatility,
-};
+use datafusion::logical_expr::ScalarUDF;
+use datafusion::logical_expr::{create_udf, ColumnarValue};

 use crate::errors::to_datafusion_err;
 use crate::errors::{py_datafusion_err, PyDataFusionResult};
@@ -83,73 +80,6 @@ fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
     })
 }

-#[derive(Debug, PartialEq, Eq, Hash)]
-struct PySimpleScalarUDF {
-    name: String,
-    signature: Signature,
-    return_field: Arc<Field>,
-    fun: PtrEq<ScalarFunctionImplementation>,
-}
-
-impl PySimpleScalarUDF {
-    fn new(
-        name: impl Into<String>,
-        input_fields: Vec<Field>,
-        return_field: Field,
-        volatility: Volatility,
-        fun: ScalarFunctionImplementation,
-    ) -> Self {
-        let signature_types = input_fields
-            .into_iter()
-            .map(|field| field.data_type().clone())
-            .collect();
-        let signature = Signature::exact(signature_types, volatility);
-        Self {
-            name: name.into(),
-            signature,
-            return_field: Arc::new(return_field),
-            fun: fun.into(),
-        }
-    }
-}
-
-impl ScalarUDFImpl for PySimpleScalarUDF {
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
-    }
-
-    fn name(&self) -> &str {
-        &self.name
-    }
-
-    fn signature(&self) -> &Signature {
-        &self.signature
-    }
-
-    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
-        Ok(self.return_field.data_type().clone())
-    }
-
-    fn return_field_from_args(
-        &self,
-        _args: ReturnFieldArgs,
-    ) -> datafusion::error::Result<Arc<Field>> {
-        Ok(Arc::new(
-            self.return_field
-                .as_ref()
-                .clone()
-                .with_name(self.name.clone()),
-        ))
-    }
-
-    fn invoke_with_args(
-        &self,
-        args: ScalarFunctionArgs,
-    ) -> datafusion::error::Result<ColumnarValue> {
-        (self.fun)(&args.args)
-    }
-}
-
 /// Represents a PyScalarUDF
 #[pyclass(frozen, name = "ScalarUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -164,19 +94,17 @@ impl PyScalarUDF {
     fn new(
         name: &str,
         func: PyObject,
-        input_types: PyArrowType<Vec<Field>>,
-        return_type: PyArrowType<Field>,
+        input_types: PyArrowType<Vec<DataType>>,
+        return_type: PyArrowType<DataType>,
         volatility: &str,
     ) -> PyResult<Self> {
-        let volatility = parse_volatility(volatility)?;
-        let scalar_impl = PySimpleScalarUDF::new(
+        let function = create_udf(
             name,
             input_types.0,
             return_type.0,
-            volatility,
+            parse_volatility(volatility)?,
             to_scalar_function_impl(func),
         );
-        let function = ScalarUDF::new_from_impl(scalar_impl);
         Ok(Self { function })
     }
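
The restored Rust constructor mirrors the Python wrapper one-to-one. As an illustrative sketch only (this is a private binding that `ScalarUDF.__init__` calls for you; the argument values here are made up):

    import pyarrow as pa
    import datafusion._internal as df_internal

    # Mirrors PyScalarUDF::new: name, func, Vec<DataType>, DataType, volatility.
    internal_udf = df_internal.ScalarUDF(
        "identity",             # name: &str
        lambda values: values,  # func: PyObject called with pyarrow arrays
        [pa.int64()],           # input_types -> PyArrowType<Vec<DataType>>
        pa.int64(),             # return_type -> PyArrowType<DataType>
        "immutable",            # volatility: &str, parsed via parse_volatility
    )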
