diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 6a3aa31a8609f..f9b7b72c341f8 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -929,9 +929,12 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any {
     fn propagate_constraints(
         &self,
         _interval: &Interval,
-        _inputs: &[&Interval],
+        inputs: &[&Interval],
     ) -> Result<Option<Vec<Interval>>> {
-        Ok(Some(vec![]))
+        // Conservative default: return inputs unchanged (no narrowing).
+        // The returned vec must have the same length as `inputs` to satisfy
+        // the interval solver contract.
+        Ok(Some(inputs.iter().map(|i| (*i).clone()).collect()))
     }
 
     /// Calculates the [`SortProperties`] of this function based on its children's properties.
diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs
index 395cb4eae03f5..e4914a89b9423 100644
--- a/datafusion/functions/src/math/ceil.rs
+++ b/datafusion/functions/src/math/ceil.rs
@@ -191,11 +191,238 @@ impl ScalarUDFImpl for CeilFunc {
     }
 
     fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
-        let data_type = inputs[0].data_type();
-        Interval::make_unbounded(&data_type)
+        let [input] = inputs else {
+            return exec_err!(
+                "ceil expected 1 argument for bounds evaluation, got {}",
+                inputs.len()
+            );
+        };
+        let data_type = input.data_type();
+        match (ceil_scalar(input.lower()), ceil_scalar(input.upper())) {
+            (Some(lo), Some(hi)) => Interval::try_new(lo, hi)
+                .or_else(|_| Interval::make_unbounded(&data_type)),
+            _ => Interval::make_unbounded(&data_type),
+        }
+    }
+
+    fn propagate_constraints(
+        &self,
+        interval: &Interval,
+        inputs: &[&Interval],
+    ) -> Result<Option<Vec<Interval>>> {
+        let [input_interval] = inputs else {
+            return exec_err!(
+                "ceil expected 1 argument for constraint propagation, got {}",
+                inputs.len()
+            );
+        };
+        // ceil(x) ∈ [N, M] → x ∈ (ceil(N)−1, floor(M)]
+        // Normalize bounds to integers ceil can actually take before mapping back.
+        let lo = match interval.lower() {
+            ScalarValue::Float64(Some(n)) if n.is_finite() => {
+                Some(ScalarValue::Float64(Some(n.ceil() - 1.0)))
+            }
+            ScalarValue::Float32(Some(n)) if n.is_finite() => {
+                Some(ScalarValue::Float32(Some(n.ceil() - 1.0)))
+            }
+            _ => None,
+        };
+        let hi = match interval.upper() {
+            ScalarValue::Float64(Some(n)) if n.is_finite() => {
+                Some(ScalarValue::Float64(Some(n.floor())))
+            }
+            ScalarValue::Float32(Some(n)) if n.is_finite() => {
+                Some(ScalarValue::Float32(Some(n.floor())))
+            }
+            _ => None,
+        };
+        match (lo, hi) {
+            (Some(lo), Some(hi)) => {
+                let constraint = Interval::try_new(lo, hi)?;
+                Ok(input_interval.intersect(constraint)?.map(|r| vec![r]))
+            }
+            _ => Ok(Some(vec![(*input_interval).clone()])),
+        }
     }
 
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
 }
+
+fn ceil_scalar(v: &ScalarValue) -> Option<ScalarValue> {
+    match v {
+        ScalarValue::Float64(Some(f)) if f.is_finite() => {
+            Some(ScalarValue::Float64(Some(f.ceil())))
+        }
+        ScalarValue::Float32(Some(f)) if f.is_finite() => {
+            Some(ScalarValue::Float32(Some(f.ceil())))
+        }
+        _ => None,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn ceil() -> CeilFunc {
+        CeilFunc::new()
+    }
+
+    fn f64_interval(lo: f64, hi: f64) -> Interval {
+        Interval::try_new(
+            ScalarValue::Float64(Some(lo)),
+            ScalarValue::Float64(Some(hi)),
+        )
+        .unwrap()
+    }
+
+    fn f32_interval(lo: f32, hi: f32) -> Interval {
+        Interval::try_new(
+            ScalarValue::Float32(Some(lo)),
+            ScalarValue::Float32(Some(hi)),
+        )
+        .unwrap()
+    }
+
+    fn unbounded_f64() -> Interval {
+        Interval::make_unbounded(&DataType::Float64).unwrap()
+    }
+
+    fn unbounded_f32() -> Interval {
+        Interval::make_unbounded(&DataType::Float32).unwrap()
+    }
+
+    // --- evaluate_bounds ---
+
+    #[test]
+    fn test_evaluate_bounds_basic() {
+        // ceil([1.2, 3.7]) = [2.0, 4.0]
+        let input = f64_interval(1.2, 3.7);
+        let result = ceil().evaluate_bounds(&[&input]).unwrap();
+        assert_eq!(result, f64_interval(2.0, 4.0));
+    }
+
+    #[test]
+    fn test_evaluate_bounds_already_integer() {
+        // ceil([2.0, 4.0]) = [2.0, 4.0]
+        let input = f64_interval(2.0, 4.0);
+        let result = ceil().evaluate_bounds(&[&input]).unwrap();
+        assert_eq!(result, f64_interval(2.0, 4.0));
+    }
+
+    #[test]
+    fn test_evaluate_bounds_f32() {
+        // ceil([1.1f32, 2.9f32]) = [2.0f32, 3.0f32]
+        let input = f32_interval(1.1, 2.9);
+        let result = ceil().evaluate_bounds(&[&input]).unwrap();
+        assert_eq!(result, f32_interval(2.0, 3.0));
+    }
+
+    #[test]
+    fn test_evaluate_bounds_unbounded_returns_unbounded() {
+        let input = unbounded_f64();
+        let result = ceil().evaluate_bounds(&[&input]).unwrap();
+        assert_eq!(result, unbounded_f64());
+    }
+
+    #[test]
+    fn test_evaluate_bounds_negative() {
+        // ceil([-3.7, -1.2]) = [-3.0, -1.0]
+        let input = f64_interval(-3.7, -1.2);
+        let result = ceil().evaluate_bounds(&[&input]).unwrap();
+        assert_eq!(result, f64_interval(-3.0, -1.0));
+    }
+
+    // --- propagate_constraints ---
+
+    #[test]
+    fn test_propagate_constraints_basic() {
+        // ceil(x) ∈ [13.0, 15.0] → x ∈ (12.0, 15.0]
+        let output = f64_interval(13.0, 15.0);
+        let input = unbounded_f64();
+        let result = ceil()
+            .propagate_constraints(&output, &[&input])
+            .unwrap()
+            .unwrap();
+        assert_eq!(result[0], f64_interval(12.0, 15.0));
+    }
+
+    #[test]
+    fn test_propagate_constraints_non_integer_bounds() {
+        // ceil(x) ∈ [12.3, 14.7] — non-integer bounds are normalized:
+        // lower: ceil(12.3)-1 = 13-1 = 12.0, upper: floor(14.7) = 14.0
+        // → x ∈ (12.0, 14.0]
+        let output = f64_interval(12.3, 14.7);
+        let input = unbounded_f64();
+        let result = ceil()
+            .propagate_constraints(&output, &[&input])
+            .unwrap()
+            .unwrap();
+        assert_eq!(result[0], f64_interval(12.0, 14.0));
+    }
+
+    #[test]
+    fn test_propagate_constraints_f32() {
+        // Same as basic but with Float32
+        let output = f32_interval(5.0, 8.0);
+        let input = unbounded_f32();
+        let result = ceil()
+            .propagate_constraints(&output, &[&input])
+            .unwrap()
+            .unwrap();
+        assert_eq!(result[0], f32_interval(4.0, 8.0));
+    }
+
+    #[test]
+    fn test_propagate_constraints_unbounded_output_no_change() {
+        // No output constraint → input unchanged
+        let output = unbounded_f64();
+        let input = f64_interval(1.0, 10.0);
+        let result = ceil()
+            .propagate_constraints(&output, &[&input])
+            .unwrap()
+            .unwrap();
+        assert_eq!(result[0], input);
+    }
+
+    #[test]
+    fn test_propagate_constraints_nan_output_no_change() {
+        // NaN bounds → conservative: input unchanged
+        let output = Interval::try_new(
+            ScalarValue::Float64(Some(f64::NAN)),
+            ScalarValue::Float64(Some(f64::NAN)),
+        )
+        .unwrap();
+        let input = f64_interval(0.0, 100.0);
+        let result = ceil()
+            .propagate_constraints(&output, &[&input])
+            .unwrap()
+            .unwrap();
+        assert_eq!(result[0], input);
+    }
+
+    #[test]
+    fn test_propagate_constraints_negative_range() {
+        // ceil(x) ∈ [-3.0, -1.0] → x ∈ (-4.0, -1.0]
+        let output = f64_interval(-3.0, -1.0);
+        let input = unbounded_f64();
+        let result = ceil()
+            .propagate_constraints(&output, &[&input])
+            .unwrap()
+            .unwrap();
+        assert_eq!(result[0], f64_interval(-4.0, -1.0));
+    }
+
+    #[test]
+    fn test_propagate_constraints_empty_intersection() {
+        // x ∈ [5.0, 7.0], constraint ceil(x) ∈ [20.0, 30.0]
+        // mapped input constraint: [19.0, 30.0] — no overlap with [5.0, 7.0]
+        // → intersect returns None → Ok(None) (branch pruned)
+        let output = f64_interval(20.0, 30.0);
+        let input = f64_interval(5.0, 7.0);
+        let result = ceil().propagate_constraints(&output, &[&input]).unwrap();
+        assert!(result.is_none());
+    }
+}
diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs
index 1090660a6b5e6..4027f65b1cdd7 100644
--- a/datafusion/physical-expr/src/intervals/utils.rs
+++ b/datafusion/physical-expr/src/intervals/utils.rs
@@ -22,13 +22,14 @@ use std::sync::Arc;
 use crate::{
     PhysicalExpr,
     expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr},
+    scalar_function::ScalarFunctionExpr,
 };
 use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
 use arrow::datatypes::{DataType, SchemaRef};
 use datafusion_common::{Result, ScalarValue, internal_err};
-use datafusion_expr::Operator;
 use datafusion_expr::interval_arithmetic::Interval;
+use datafusion_expr::{Operator, Volatility};
 
 /// Indicates whether interval arithmetic is supported for the given expression.
 /// Currently, we do not support all [`PhysicalExpr`]s for interval calculations.
@@ -56,6 +57,13 @@ pub fn check_support(expr: &Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> bool {
         check_support(cast.expr(), schema)
     } else if let Some(negative) = expr.downcast_ref::<NegativeExpr>() {
         check_support(negative.arg(), schema)
+    } else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
+        scalar_fn.fun().signature().volatility == Volatility::Immutable
+            && is_datatype_supported(scalar_fn.return_type())
+            && scalar_fn
+                .args()
+                .iter()
+                .all(|arg| check_support(arg, schema))
     } else {
         false
     }
@@ -193,3 +201,143 @@ fn interval_dt_to_duration_ms(dt: &IntervalDayTime) -> Result<i64> {
         )
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::expressions::{Column, Literal};
+    use crate::scalar_function::ScalarFunctionExpr;
+    use arrow::datatypes::{Field, Schema};
+    use datafusion_common::config::ConfigOptions;
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{
+        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+    };
+
+    fn f64_schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![Field::new("x", DataType::Float64, false)]))
+    }
+
+    fn utf8_schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, false)]))
+    }
+
+    fn col_x() -> Arc<dyn PhysicalExpr> {
+        Arc::new(Column::new("x", 0))
+    }
+
+    fn lit_f64(v: f64) -> Arc<dyn PhysicalExpr> {
+        Arc::new(Literal::new(ScalarValue::Float64(Some(v))))
+    }
+
+    fn scalar_fn_expr(
+        udf: Arc<datafusion_expr::ScalarUDF>,
+        args: Vec<Arc<dyn PhysicalExpr>>,
+        return_type: DataType,
+    ) -> Arc<dyn PhysicalExpr> {
+        let name = udf.name().to_string();
+        Arc::new(ScalarFunctionExpr::new(
+            &name,
+            udf,
+            args,
+            Field::new("result", return_type, true).into(),
+            Arc::new(ConfigOptions::default()),
+        ))
+    }
+
+    /// A minimal UDF whose declared return type is Utf8, used to test that
+    /// check_support rejects functions with unsupported return types without
+    /// relying on an invalid ceil-returns-Utf8 combination.
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct Utf8UDF {
+        signature: Signature,
+    }
+
+    impl Utf8UDF {
+        fn new() -> Self {
+            Self {
+                signature: Signature::uniform(
+                    1,
+                    vec![DataType::Float64],
+                    Volatility::Immutable,
+                ),
+            }
+        }
+    }
+
+    impl ScalarUDFImpl for Utf8UDF {
+        fn name(&self) -> &str {
+            "utf8_udf"
+        }
+        fn signature(&self) -> &Signature {
+            &self.signature
+        }
+        fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+            Ok(DataType::Utf8)
+        }
+        fn invoke_with_args(&self, _: ScalarFunctionArgs) -> Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
+        }
+    }
+
+    fn utf8_udf() -> Arc<datafusion_expr::ScalarUDF> {
+        Arc::new(datafusion_expr::ScalarUDF::from(Utf8UDF::new()))
+    }
+
+    #[test]
+    fn test_check_support_scalar_fn_supported_return_type() {
+        // ceil(x) returns Float64 — both return type and child are supported
+        let schema = f64_schema();
+        let expr = scalar_fn_expr(
+            datafusion_functions::math::ceil(),
+            vec![col_x()],
+            DataType::Float64,
+        );
+        assert!(check_support(&expr, &schema));
+    }
+
+    #[test]
+    fn test_check_support_scalar_fn_unsupported_return_type() {
+        // utf8_udf(x) returns Utf8 — not in is_datatype_supported
+        let schema = f64_schema();
+        let expr = scalar_fn_expr(utf8_udf(), vec![col_x()], DataType::Utf8);
+        assert!(!check_support(&expr, &schema));
+    }
+
+    #[test]
+    fn test_check_support_scalar_fn_unsupported_child() {
+        // ceil applied to a Utf8 column — child fails is_datatype_supported
+        let schema = utf8_schema();
+        let col_s = Arc::new(Column::new("s", 0)) as Arc<dyn PhysicalExpr>;
+        let expr = scalar_fn_expr(
+            datafusion_functions::math::ceil(),
+            vec![col_s],
+            DataType::Float64,
+        );
+        assert!(!check_support(&expr, &schema));
+    }
+
+    #[test]
+    fn test_check_support_scalar_fn_in_binary_expr() {
+        // ceil(x) > 5.0 — the main use case: ScalarFunctionExpr inside a BinaryExpr
+        let schema = f64_schema();
+        let ceil_x = scalar_fn_expr(
+            datafusion_functions::math::ceil(),
+            vec![col_x()],
+            DataType::Float64,
+        );
+        let expr: Arc<dyn PhysicalExpr> =
+            Arc::new(BinaryExpr::new(ceil_x, Operator::Gt, lit_f64(5.0)));
+        assert!(check_support(&expr, &schema));
+    }
+
+    #[test]
+    fn test_check_support_scalar_fn_in_binary_expr_unsupported_return() {
+        // utf8_udf(x) > 5.0 where f returns Utf8 — should be false
+        let schema = f64_schema();
+        let fn_expr = scalar_fn_expr(utf8_udf(), vec![col_x()], DataType::Utf8);
+        let expr: Arc<dyn PhysicalExpr> =
+            Arc::new(BinaryExpr::new(fn_expr, Operator::Gt, lit_f64(5.0)));
+        assert!(!check_support(&expr, &schema));
+    }
+}
diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs
index 418d005c971ea..b81d4254ede34 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -302,7 +302,12 @@ impl PhysicalExpr for ScalarFunctionExpr {
     }
 
     fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
-        self.fun.evaluate_bounds(children)
+        let result = self.fun.evaluate_bounds(children)?;
+        if result.data_type() == DataType::Null {
+            Interval::make_unbounded(self.return_type())
+        } else {
+            Ok(result)
+        }
     }
 
     fn propagate_constraints(
diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs
index c485e181f3826..c39f05d05ab77 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -1209,6 +1209,66 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_filter_statistics_ceil_scalar_fn() -> Result<()> {
+        // Table: x Float64, min=8.0, max=16.0, 100 rows.
+        // Filter: ceil(x) > 12.0
+        //
+        // The range [8.0, 16.0] lies within a single IEEE-754 binade so
+        // Float64 cardinality is proportional to the value range, making
+        // the selectivity estimate predictable.
+        //
+        // With check_support recognising ScalarFunctionExpr and CeilFunc
+        // implementing evaluate_bounds/propagate_constraints the solver
+        // narrows x to roughly [11.0, 16.0]:
+        //     ceil(x) > 12 → x ∈ (11, 16] → conservative [11, 16]
+        //     selectivity ≈ (16−11)/(16−8) = 5/8 = 0.625 → ~62 rows
+        //
+        // Without the fix the estimate stays at 100 (no interval analysis).
+        let schema = Schema::new(vec![Field::new("x", DataType::Float64, false)]);
+        let input = Arc::new(StatisticsExec::new(
+            Statistics {
+                num_rows: Precision::Inexact(100),
+                total_byte_size: Precision::Absent,
+                column_statistics: vec![ColumnStatistics {
+                    min_value: Precision::Inexact(ScalarValue::Float64(Some(8.0))),
+                    max_value: Precision::Inexact(ScalarValue::Float64(Some(16.0))),
+                    ..Default::default()
+                }],
+            },
+            schema.clone(),
+        ));
+
+        let x = col("x", &schema)?;
+        let ceil_udf = datafusion_functions::math::ceil();
+        let config = Arc::new(ConfigOptions::new());
+        let ceil_x: Arc<dyn PhysicalExpr> =
+            Arc::new(datafusion_physical_expr::ScalarFunctionExpr::try_new(
+                Arc::clone(&ceil_udf),
+                vec![x],
+                &schema,
+                config,
+            )?);
+        let predicate = binary(ceil_x, Operator::Gt, lit(12.0f64), &schema)?;
+
+        let filter = Arc::new(FilterExec::try_new(predicate, input)?);
+        let statistics = filter.partition_statistics(None)?;
+
+        let num_rows = statistics.num_rows.get_value().copied().unwrap_or(100);
+        // Interval analysis must narrow the estimate below the full 100-row input.
+        assert!(
+            num_rows < 100,
+            "expected interval analysis to narrow row estimate, got {num_rows}"
+        );
+        // The conservative bound is x ∈ [11, 16] out of [8, 16] → ~62 rows.
+        // Allow a generous range to be robust to float-cardinality rounding.
+        assert!(
+            num_rows >= 50,
+            "expected at least 50 rows after ceil(x) > 12.0 on [8,16], got {num_rows}"
+        );
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_filter_statistics_basic_expr() -> Result<()> {
         // Table: