diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index c47316bccc7c1..86c5aac3d0192 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -2038,6 +2038,61 @@ impl TreeNodeRewriter for Simplifier<'_> {
                     Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right }))
                 }
             }
+            // For case:
+            // date_part('YEAR', expr) IN (literal1, literal2, ...)
+            //
+            // Each list item is rewritten through its preimage into a range
+            // predicate; the per-item predicates are combined with OR for `IN`
+            // and with AND for `NOT IN`.
+            Expr::InList(InList {
+                expr,
+                list,
+                negated,
+            }) => {
+                // Lists longer than the threshold or containing NULL are left untouched.
+                if list.len() > THRESHOLD_INLINE_INLIST || list.iter().any(is_null) {
+                    return Ok(Transformed::no(Expr::InList(InList {
+                        expr,
+                        list,
+                        negated,
+                    })));
+                }
+
+                let (op, combiner): (Operator, fn(Expr, Expr) -> Expr) =
+                    if negated { (NotEq, and) } else { (Eq, or) };
+
+                let mut rewritten: Option<Expr> = None;
+                for item in &list {
+                    // Every item must yield a preimage range; otherwise the
+                    // original expression is returned unchanged.
+                    let PreimageResult::Range { interval, expr } =
+                        get_preimage(expr.as_ref(), item, info)?
+                    else {
+                        return Ok(Transformed::no(Expr::InList(InList {
+                            expr,
+                            list,
+                            negated,
+                        })));
+                    };
+
+                    let range_expr = rewrite_with_preimage(*interval, op, expr)?.data;
+                    rewritten = Some(match rewritten {
+                        None => range_expr,
+                        Some(acc) => combiner(acc, range_expr),
+                    });
+                }
+
+                // `rewritten` is still `None` only when the list was empty.
+                if let Some(rewritten) = rewritten {
+                    Transformed::yes(rewritten)
+                } else {
+                    Transformed::no(Expr::InList(InList {
+                        expr,
+                        list,
+                        negated,
+                    }))
+                }
+            }
 
             // no additional rewrites possible
             expr => Transformed::no(expr),
diff --git a/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs
index e0837196ca990..11c80c62e1b3f 100644
--- a/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs
+++ b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs
@@ -76,7 +76,7 @@ mod test {
     use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
     use datafusion_expr::{
         ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
-        Signature, Volatility, and, binary_expr, col, lit, preimage::PreimageResult,
+        Signature, Volatility, and, binary_expr, col, lit, or, preimage::PreimageResult,
         simplify::SimplifyContext,
     };
 
@@ -165,6 +165,15 @@ mod test {
                         )?),
                     })
                 }
+                Expr::Literal(ScalarValue::Int32(Some(600)), _) => {
+                    Ok(PreimageResult::Range {
+                        expr,
+                        interval: Box::new(Interval::try_new(
+                            ScalarValue::Int32(Some(300)),
+                            ScalarValue::Int32(Some(400)),
+                        )?),
+                    })
+                }
                 _ => Ok(PreimageResult::None),
             }
         }
@@ -312,6 +321,41 @@ mod test {
         assert_eq!(optimize_test(expr, &schema), expected);
     }
 
+    #[test]
+    fn test_preimage_in_list_rewrite() {
+        // `IN` should become an OR of the per-item range predicates.
+        let schema = test_schema();
+        let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], false);
+        let expected = or(
+            and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))),
+            and(col("x").gt_eq(lit(300)), col("x").lt(lit(400))),
+        );
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_not_in_list_rewrite() {
+        // `NOT IN` should become an AND of the per-item excluded ranges.
+        let schema = test_schema();
+        let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], true);
+        let expected = and(
+            or(col("x").lt(lit(100)), col("x").gt_eq(lit(200))),
+            or(col("x").lt(lit(300)), col("x").gt_eq(lit(400))),
+        );
+
+        assert_eq!(optimize_test(expr, &schema), expected);
+    }
+
+    #[test]
+    fn test_preimage_in_list_long_list_no_rewrite() {
+        // A long list (99 items) should not be rewritten.
+        let schema = test_schema();
+        let expr = preimage_udf_expr().in_list((1..100).map(lit).collect(), false);
+
+        assert_eq!(optimize_test(expr.clone(), &schema), expr);
+    }
+
     #[test]
     fn test_preimage_non_literal_rhs_no_rewrite() {
         // Non-literal RHS should not be rewritten.