diff --git a/datafusion/expr/src/higher_order_function.rs b/datafusion/expr/src/higher_order_function.rs index 3dc143b8e5211..c84d5688a06c4 100644 --- a/datafusion/expr/src/higher_order_function.rs +++ b/datafusion/expr/src/higher_order_function.rs @@ -73,6 +73,10 @@ pub enum HigherOrderTypeSignature { VariadicAny, /// The specified number of lambdas or arguments with arbitrary types. Any(usize), + /// Exactly the specified arguments in the given order, with arbitrary types. + /// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value + /// argument types. + Exact(Vec>), } /// Provides information necessary for calling a higher order function. @@ -138,6 +142,28 @@ impl HigherOrderSignature { } } + /// Exactly the specified arguments in the given order, with arbitrary types. + /// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value + /// argument types. + /// + /// # Example + /// A function that takes one value argument followed by one lambda: + /// ``` + /// # use datafusion_expr::{HigherOrderSignature, ValueOrLambda, Volatility}; + /// let sig = HigherOrderSignature::exact( + /// vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + /// Volatility::Immutable, + /// ); + /// ``` + pub fn exact(args: Vec>, volatility: Volatility) -> Self { + Self { + type_signature: HigherOrderTypeSignature::Exact(args), + volatility, + coerce_values_for_lambdas: false, + lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS, + } + } + /// Set [Self::coerce_values_for_lambdas] to true to indicate that [HigherOrderUDF::coerce_values_for_lambdas] /// should be called pub fn with_coerce_values_for_lambdas(mut self) -> Self { @@ -406,7 +432,7 @@ pub struct HigherOrderReturnFieldArgs<'a> { } /// An argument to a higher order function -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)] pub enum ValueOrLambda { /// A value with associated data Value(V), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 86616daf08c73..624f0da3fd901 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -230,6 +230,78 @@ pub fn value_fields_with_higher_order_udf( Ok(current_fields.to_vec()) } + HigherOrderTypeSignature::Exact(ref expected) => { + if current_fields.len() != expected.len() { + let name = func.name(); + let expected_len = expected.len(); + let actual_len = current_fields.len(); + return plan_err!( + "The function '{name}' expected {expected_len} argument(s) but received {actual_len}" + ); + } + + for (i, (actual, expected)) in + current_fields.iter().zip(expected.iter()).enumerate() + { + match (actual, expected) { + (ValueOrLambda::Value(_), ValueOrLambda::Value(_)) => {} + (ValueOrLambda::Lambda(_), ValueOrLambda::Lambda(_)) => {} + (ValueOrLambda::Value(_), ValueOrLambda::Lambda(_)) => { + let name = func.name(); + return plan_err!( + "The function '{name}' expected a lambda at position {i} but received a value" + ); + } + (ValueOrLambda::Lambda(_), ValueOrLambda::Value(_)) => { + let name = func.name(); + return plan_err!( + "The function '{name}' expected a value at position {i} but received a lambda" + ); + } + } + } + + let arg_types = current_fields + .iter() + .filter_map(|p| match p { + ValueOrLambda::Value(field) => Some(field.data_type().clone()), + ValueOrLambda::Lambda(_) => None, + }) + .collect::>(); + + let coerced_types = func.coerce_value_types(&arg_types)?; + + if coerced_types.len() != arg_types.len() { + return plan_err!( + "{} coerce_value_types should have returned {} items but returned {}", + func.name(), + arg_types.len(), + coerced_types.len() + ); + } + + let mut coerced_types = coerced_types.into_iter(); + + current_fields + .iter() + .map(|current_field| match current_field { + ValueOrLambda::Value(field) => { + let data_type = coerced_types.next().ok_or_else(|| { + internal_datafusion_err!( + "coerced_types len should have been checked above" + ) + })?; + + Ok(ValueOrLambda::Value(Arc::new( + field.as_ref().clone().with_data_type(data_type), + ))) + } + ValueOrLambda::Lambda(lambda) => { + Ok(ValueOrLambda::Lambda(lambda.clone())) + } + }) + .collect() + } } } @@ -2026,4 +2098,88 @@ mod tests { "The function 'mock_higher_order_function' expected 1 arguments but received 0" ); } + + #[test] + fn test_higher_order_function_exact_signature() { + let fun = MockHigherOrderUDF { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)], + }; + + let new_fields = value_fields_with_higher_order_udf( + &[ + ValueOrLambda::Value(Arc::new(Field::new_list( + "", + Field::new_list_field(DataType::Int32, false), + false, + ))), + ValueOrLambda::Lambda(()), + ], + &fun, + ) + .unwrap(); + + // type coercion applied: List(Int32) -> LargeList(Int32) + assert_eq!( + new_fields, + vec![ + ValueOrLambda::Value(Arc::new(Field::new_large_list( + "", + Field::new_list_field(DataType::Int32, false), + false + ))), + ValueOrLambda::Lambda(()), + ] + ) + } + + #[test] + fn test_higher_order_function_exact_signature_wrong_value_count() { + let fun = MockHigherOrderUDF { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + coerced_value_types: vec![], + }; + + let err = value_fields_with_higher_order_udf::<()>( + &[ValueOrLambda::Lambda(()), ValueOrLambda::Lambda(())], + &fun, + ) + .unwrap_err(); + + assert_contains!( + err.to_string(), + "expected a value at position 0 but received a lambda" + ); + } + + #[test] + fn test_higher_order_function_exact_signature_wrong_lambda_count() { + let fun = MockHigherOrderUDF { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + coerced_value_types: vec![], + }; + + let err = value_fields_with_higher_order_udf::<()>( + &[ + ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))), + ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))), + ], + &fun, + ) + .unwrap_err(); + + assert_contains!( + err.to_string(), + "expected a lambda at position 1 but received a value" + ); + } } diff --git a/datafusion/functions-nested/src/array_any_match.rs b/datafusion/functions-nested/src/array_any_match.rs index dce06bb2f2d3b..3ce43a23c2124 100644 --- a/datafusion/functions-nested/src/array_any_match.rs +++ b/datafusion/functions-nested/src/array_any_match.rs @@ -81,7 +81,10 @@ impl Default for ArrayAnyMatch { impl ArrayAnyMatch { pub fn new() -> Self { Self { - signature: HigherOrderSignature::user_defined(Volatility::Immutable), + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), aliases: vec![String::from("any_match"), String::from("list_any_match")], } } @@ -117,9 +120,7 @@ impl HigherOrderUDF for ArrayAnyMatch { } fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { - let list = if arg_types.len() == 1 { - &arg_types[0] - } else { + let [list] = arg_types else { return plan_err!( "{} function requires 1 value argument, got {}", self.name(), @@ -150,15 +151,15 @@ impl HigherOrderUDF for ArrayAnyMatch { _step: usize, fields: &[ValueOrLambda>], ) -> Result { - let [list, _lambda] = take_function_args(self.name(), fields)?; - - let field = match list { - ValueOrLambda::Value(f) => match f.data_type() { - DataType::List(field) => field, - DataType::LargeList(field) => field, - other => return plan_err!("expected list, got {other}"), - }, - _ => return plan_err!("{} expected a value as first argument", self.name()), + let [list, _] = take_function_args(self.name(), fields)?; + let ValueOrLambda::Value(list) = list else { + return plan_err!("{} expects a value as first argument", self.name()); + }; + + let field = match list.data_type() { + DataType::List(field) => field, + DataType::LargeList(field) => field, + other => return plan_err!("expected list, got {other}"), }; Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone( @@ -170,15 +171,18 @@ impl HigherOrderUDF for ArrayAnyMatch { &self, args: HigherOrderReturnFieldArgs, ) -> Result> { - let [list, _lambda] = take_function_args(self.name(), args.arg_fields)?; - let nullable = matches!(list, ValueOrLambda::Value(f) if f.is_nullable()); + let [ValueOrLambda::Value(list), _] = + take_function_args(self.name(), args.arg_fields)? + else { + return plan_err!("{} expects a value as first argument", self.name()); + }; + let nullable = list.is_nullable(); Ok(Arc::new(Field::new("", DataType::Boolean, nullable))) } fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result { - let [list, lambda] = take_function_args(self.name(), &args.args)?; - - let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda) + let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] = + take_function_args(self.name(), &args.args)? else { return exec_err!("{} expects a value followed by a lambda", self.name()); }; diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index aa8e2c5b46a2b..bfec3613b6b8a 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -77,7 +77,10 @@ impl Default for ArrayTransform { impl ArrayTransform { pub fn new() -> Self { Self { - signature: HigherOrderSignature::user_defined(Volatility::Immutable), + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), aliases: vec![String::from("list_transform")], } } @@ -97,11 +100,9 @@ impl HigherOrderUDF for ArrayTransform { } fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { - let list = if arg_types.len() == 1 { - &arg_types[0] - } else { + let [list] = arg_types else { return plan_err!( - "{} function requires 1 value arguments, got {}", + "{} function requires 1 value argument, got {}", self.name(), arg_types.len() ); @@ -130,7 +131,10 @@ impl HigherOrderUDF for ArrayTransform { _step: usize, fields: &[ValueOrLambda>], ) -> Result { - let (list, _lambda) = value_lambda_pair(self.name(), fields)?; + let [list, _] = take_function_args(self.name(), fields)?; + let ValueOrLambda::Value(list) = list else { + return plan_err!("{} expects a value as first argument", self.name()); + }; let field = match list.data_type() { DataType::List(field) => field, @@ -149,7 +153,11 @@ impl HigherOrderUDF for ArrayTransform { &self, args: HigherOrderReturnFieldArgs, ) -> Result> { - let (list, lambda) = value_lambda_pair(self.name(), args.arg_fields)?; + let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] = + take_function_args(self.name(), args.arg_fields)? + else { + return plan_err!("{} expects a value followed by a lambda", self.name()); + }; //TODO: should metadata be copied into the transformed array? @@ -171,7 +179,11 @@ impl HigherOrderUDF for ArrayTransform { } fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result { - let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; + let [list, lambda] = take_function_args(self.name(), &args.args)?; + let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda) + else { + return plan_err!("{} expects a value followed by a lambda", self.name()); + }; let list_array = list.to_array(args.number_rows)?; @@ -265,22 +277,6 @@ impl HigherOrderUDF for ArrayTransform { } } -fn value_lambda_pair<'a, V: Debug, L: Debug>( - name: &str, - args: &'a [ValueOrLambda], -) -> Result<(&'a V, &'a L)> { - let [value, lambda] = take_function_args(name, args)?; - - let (ValueOrLambda::Value(value), ValueOrLambda::Lambda(lambda)) = (value, lambda) - else { - return plan_err!( - "{name} expects a value followed by a lambda, got {value:?} and {lambda:?}" - ); - }; - - Ok((value, lambda)) -} - #[cfg(test)] mod tests { use std::{collections::HashMap, sync::Arc}; diff --git a/datafusion/sqllogictest/test_files/array/array_transform.slt b/datafusion/sqllogictest/test_files/array/array_transform.slt index f87253695d332..c8c43588c882c 100644 --- a/datafusion/sqllogictest/test_files/array/array_transform.slt +++ b/datafusion/sqllogictest/test_files/array/array_transform.slt @@ -396,13 +396,13 @@ physical_plan query error select array_transform(); ---- -DataFusion error: Error during planning: array_transform function requires 1 value arguments, got 0 +DataFusion error: Error during planning: The function 'array_transform' expected 2 argument(s) but received 0 query error DataFusion error: Error during planning: array_transform expected a list as first argument, got Int64 select array_transform(1, v -> v*2); -query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(None\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\) +query error DataFusion error: Error during planning: The function 'array_transform' expected a value at position 0 but received a lambda select array_transform(v -> v*2, [1, 2]); query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 1