Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion datafusion/expr/src/higher_order_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ pub enum HigherOrderTypeSignature {
VariadicAny,
/// The specified number of lambdas or arguments with arbitrary types.
Any(usize),
/// Exactly the specified arguments in the given order, with arbitrary types.
/// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value
/// argument types.
Exact(Vec<ValueOrLambda<(), ()>>),
}

/// Provides information necessary for calling a higher order function.
Expand Down Expand Up @@ -138,6 +142,28 @@ impl HigherOrderSignature {
}
}

/// Builds a signature that accepts exactly the arguments listed in `args`, in the
/// given order, where each position is either a value or a lambda. The value
/// arguments may have arbitrary types.
/// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value
/// argument types.
///
/// The returned signature has `coerce_values_for_lambdas` set to `false` and
/// `lambda_parameters_max_iterations` set to `LAMBDA_PARAMETERS_MAX_ITERATIONS`.
///
/// # Examples
/// A function that takes one value argument followed by one lambda:
/// ```
/// # use datafusion_expr::{HigherOrderSignature, ValueOrLambda, Volatility};
/// let sig = HigherOrderSignature::exact(
///     vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
///     Volatility::Immutable,
/// );
/// ```
pub fn exact(args: Vec<ValueOrLambda<(), ()>>, volatility: Volatility) -> Self {
    // Wrap the positional description first so the struct literal can use
    // field-init shorthand.
    let type_signature = HigherOrderTypeSignature::Exact(args);

    Self {
        lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
        coerce_values_for_lambdas: false,
        volatility,
        type_signature,
    }
}

/// Set [Self::coerce_values_for_lambdas] to true to indicate that [HigherOrderUDF::coerce_values_for_lambdas]
/// should be called
pub fn with_coerce_values_for_lambdas(mut self) -> Self {
Expand Down Expand Up @@ -406,7 +432,7 @@ pub struct HigherOrderReturnFieldArgs<'a> {
}

/// An argument to a higher order function
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub enum ValueOrLambda<V, L> {
/// A value with associated data
Value(V),
Expand Down
156 changes: 156 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,78 @@ pub fn value_fields_with_higher_order_udf<L: Clone>(

Ok(current_fields.to_vec())
}
HigherOrderTypeSignature::Exact(ref expected) => {
if current_fields.len() != expected.len() {
let name = func.name();
let expected_len = expected.len();
let actual_len = current_fields.len();
return plan_err!(
"The function '{name}' expected {expected_len} argument(s) but received {actual_len}"
);
}

for (i, (actual, expected)) in
current_fields.iter().zip(expected.iter()).enumerate()
{
match (actual, expected) {
(ValueOrLambda::Value(_), ValueOrLambda::Value(_)) => {}
(ValueOrLambda::Lambda(_), ValueOrLambda::Lambda(_)) => {}
(ValueOrLambda::Value(_), ValueOrLambda::Lambda(_)) => {
let name = func.name();
return plan_err!(
"The function '{name}' expected a lambda at position {i} but received a value"
);
}
(ValueOrLambda::Lambda(_), ValueOrLambda::Value(_)) => {
let name = func.name();
return plan_err!(
"The function '{name}' expected a value at position {i} but received a lambda"
);
}
}
}

let arg_types = current_fields
.iter()
.filter_map(|p| match p {
ValueOrLambda::Value(field) => Some(field.data_type().clone()),
ValueOrLambda::Lambda(_) => None,
})
.collect::<Vec<_>>();

let coerced_types = func.coerce_value_types(&arg_types)?;

if coerced_types.len() != arg_types.len() {
return plan_err!(
"{} coerce_value_types should have returned {} items but returned {}",
func.name(),
arg_types.len(),
coerced_types.len()
);
}

let mut coerced_types = coerced_types.into_iter();

current_fields
.iter()
.map(|current_field| match current_field {
ValueOrLambda::Value(field) => {
let data_type = coerced_types.next().ok_or_else(|| {
internal_datafusion_err!(
"coerced_types len should have been checked above"
)
})?;

Ok(ValueOrLambda::Value(Arc::new(
field.as_ref().clone().with_data_type(data_type),
)))
}
ValueOrLambda::Lambda(lambda) => {
Ok(ValueOrLambda::Lambda(lambda.clone()))
}
})
.collect()
}
}
}

Expand Down Expand Up @@ -2026,4 +2098,88 @@ mod tests {
"The function 'mock_higher_order_function' expected 1 arguments but received 0"
);
}

/// Happy path for an `exact` signature: one list value followed by one lambda.
/// The value field is expected to come back with the type configured on the mock,
/// and the lambda is expected to pass through untouched.
#[test]
fn test_higher_order_function_exact_signature() {
    let signature = HigherOrderSignature::exact(
        vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
        Volatility::Immutable,
    );
    let fun = MockHigherOrderUDF {
        signature,
        coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)],
    };

    // Input: List(Int32) value, then a lambda placeholder.
    let input = [
        ValueOrLambda::Value(Arc::new(Field::new_list(
            "",
            Field::new_list_field(DataType::Int32, false),
            false,
        ))),
        ValueOrLambda::Lambda(()),
    ];

    // type coercion applied: List(Int32) -> LargeList(Int32)
    let expected = vec![
        ValueOrLambda::Value(Arc::new(Field::new_large_list(
            "",
            Field::new_list_field(DataType::Int32, false),
            false,
        ))),
        ValueOrLambda::Lambda(()),
    ];

    let new_fields = value_fields_with_higher_order_udf(&input, &fun).unwrap();

    assert_eq!(new_fields, expected);
}

/// An `exact` signature of (value, lambda) must reject (lambda, lambda): the
/// first position holds a lambda where a value is declared.
#[test]
fn test_higher_order_function_exact_signature_wrong_value_count() {
    let signature = HigherOrderSignature::exact(
        vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
        Volatility::Immutable,
    );
    let fun = MockHigherOrderUDF {
        signature,
        coerced_value_types: vec![],
    };

    let result = value_fields_with_higher_order_udf::<()>(
        &[ValueOrLambda::Lambda(()), ValueOrLambda::Lambda(())],
        &fun,
    );
    let message = result.unwrap_err().to_string();

    assert_contains!(
        message,
        "expected a value at position 0 but received a lambda"
    );
}

/// An `exact` signature of (value, lambda) must reject (value, value): the
/// second position holds a value where a lambda is declared.
#[test]
fn test_higher_order_function_exact_signature_wrong_lambda_count() {
    let signature = HigherOrderSignature::exact(
        vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
        Volatility::Immutable,
    );
    let fun = MockHigherOrderUDF {
        signature,
        coerced_value_types: vec![],
    };

    let result = value_fields_with_higher_order_udf::<()>(
        &[
            ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))),
            ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))),
        ],
        &fun,
    );
    let message = result.unwrap_err().to_string();

    assert_contains!(
        message,
        "expected a lambda at position 1 but received a value"
    );
}
}
40 changes: 22 additions & 18 deletions datafusion/functions-nested/src/array_any_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ impl Default for ArrayAnyMatch {
impl ArrayAnyMatch {
pub fn new() -> Self {
Self {
signature: HigherOrderSignature::user_defined(Volatility::Immutable),
signature: HigherOrderSignature::exact(
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
Volatility::Immutable,
),
aliases: vec![String::from("any_match"), String::from("list_any_match")],
}
}
Expand Down Expand Up @@ -117,9 +120,7 @@ impl HigherOrderUDF for ArrayAnyMatch {
}

fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let list = if arg_types.len() == 1 {
&arg_types[0]
} else {
let [list] = arg_types else {
return plan_err!(
"{} function requires 1 value argument, got {}",
self.name(),
Expand Down Expand Up @@ -150,15 +151,15 @@ impl HigherOrderUDF for ArrayAnyMatch {
_step: usize,
fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
) -> Result<LambdaParametersProgress> {
let [list, _lambda] = take_function_args(self.name(), fields)?;

let field = match list {
ValueOrLambda::Value(f) => match f.data_type() {
DataType::List(field) => field,
DataType::LargeList(field) => field,
other => return plan_err!("expected list, got {other}"),
},
_ => return plan_err!("{} expected a value as first argument", self.name()),
let [list, _] = take_function_args(self.name(), fields)?;
let ValueOrLambda::Value(list) = list else {
return plan_err!("{} expects a value as first argument", self.name());
};

let field = match list.data_type() {
DataType::List(field) => field,
DataType::LargeList(field) => field,
other => return plan_err!("expected list, got {other}"),
};

Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone(
Expand All @@ -170,15 +171,18 @@ impl HigherOrderUDF for ArrayAnyMatch {
&self,
args: HigherOrderReturnFieldArgs,
) -> Result<Arc<Field>> {
let [list, _lambda] = take_function_args(self.name(), args.arg_fields)?;
let nullable = matches!(list, ValueOrLambda::Value(f) if f.is_nullable());
let [ValueOrLambda::Value(list), _] =
take_function_args(self.name(), args.arg_fields)?
else {
return plan_err!("{} expects a value as first argument", self.name());
};
let nullable = list.is_nullable();
Ok(Arc::new(Field::new("", DataType::Boolean, nullable)))
}

fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
let [list, lambda] = take_function_args(self.name(), &args.args)?;

let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
take_function_args(self.name(), &args.args)?
else {
return exec_err!("{} expects a value followed by a lambda", self.name());
};
Expand Down
44 changes: 20 additions & 24 deletions datafusion/functions-nested/src/array_transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ impl Default for ArrayTransform {
impl ArrayTransform {
pub fn new() -> Self {
Self {
signature: HigherOrderSignature::user_defined(Volatility::Immutable),
signature: HigherOrderSignature::exact(
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
Volatility::Immutable,
),
aliases: vec![String::from("list_transform")],
}
}
Expand All @@ -97,11 +100,9 @@ impl HigherOrderUDF for ArrayTransform {
}

fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let list = if arg_types.len() == 1 {
&arg_types[0]
} else {
let [list] = arg_types else {
return plan_err!(
"{} function requires 1 value arguments, got {}",
"{} function requires 1 value argument, got {}",
self.name(),
arg_types.len()
);
Expand Down Expand Up @@ -130,7 +131,10 @@ impl HigherOrderUDF for ArrayTransform {
_step: usize,
fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
) -> Result<LambdaParametersProgress> {
let (list, _lambda) = value_lambda_pair(self.name(), fields)?;
let [list, _] = take_function_args(self.name(), fields)?;
let ValueOrLambda::Value(list) = list else {
return plan_err!("{} expects a value as first argument", self.name());
};

let field = match list.data_type() {
DataType::List(field) => field,
Expand All @@ -149,7 +153,11 @@ impl HigherOrderUDF for ArrayTransform {
&self,
args: HigherOrderReturnFieldArgs,
) -> Result<Arc<Field>> {
let (list, lambda) = value_lambda_pair(self.name(), args.arg_fields)?;
let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
take_function_args(self.name(), args.arg_fields)?
else {
return plan_err!("{} expects a value followed by a lambda", self.name());
};

//TODO: should metadata be copied into the transformed array?

Expand All @@ -171,7 +179,11 @@ impl HigherOrderUDF for ArrayTransform {
}

fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
let (list, lambda) = value_lambda_pair(self.name(), &args.args)?;
let [list, lambda] = take_function_args(self.name(), &args.args)?;
let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
else {
return plan_err!("{} expects a value followed by a lambda", self.name());
};

let list_array = list.to_array(args.number_rows)?;

Expand Down Expand Up @@ -265,22 +277,6 @@ impl HigherOrderUDF for ArrayTransform {
}
}

/// Destructures `args` into a `(value, lambda)` pair of references.
///
/// `take_function_args` is asked for exactly two arguments and its error is
/// propagated unchanged. If the two arguments are not a value followed by a
/// lambda, a planning error naming both arguments is returned.
fn value_lambda_pair<'a, V: Debug, L: Debug>(
    name: &str,
    args: &'a [ValueOrLambda<V, L>],
) -> Result<(&'a V, &'a L)> {
    match take_function_args(name, args)? {
        [ValueOrLambda::Value(value), ValueOrLambda::Lambda(lambda)] => {
            Ok((value, lambda))
        }
        // Any other combination of kinds is rejected, echoing what was received.
        [value, lambda] => plan_err!(
            "{name} expects a value followed by a lambda, got {value:?} and {lambda:?}"
        ),
    }
}

#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/array/array_transform.slt
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,13 @@ physical_plan
query error
select array_transform();
----
DataFusion error: Error during planning: array_transform function requires 1 value arguments, got 0
DataFusion error: Error during planning: The function 'array_transform' expected 2 argument(s) but received 0


query error DataFusion error: Error during planning: array_transform expected a list as first argument, got Int64
select array_transform(1, v -> v*2);

query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(None\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\)
query error DataFusion error: Error during planning: The function 'array_transform' expected a value at position 0 but received a lambda
select array_transform(v -> v*2, [1, 2]);

query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 1
Expand Down
Loading