From c981498c84068fac3f99bd4aa1129e3fe10041d6 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Wed, 28 Jan 2026 20:02:59 -0800 Subject: [PATCH 1/4] feat: Binary CASE WHEN expression with support for nested conditions (#13) * feat: implement binary CASE WHEN expression with support for nested conditions --- vortex-array/Cargo.toml | 5 + vortex-array/benches/expr/case_when_bench.rs | 107 +++ vortex-array/src/expr/exprs/case_when.rs | 667 +++++++++++++++++++ vortex-array/src/expr/exprs/mod.rs | 2 + vortex-array/src/expr/session.rs | 2 + vortex-datafusion/src/convert/exprs.rs | 110 ++- vortex-proto/proto/expr.proto | 6 + vortex-proto/src/generated/vortex.expr.rs | 8 + 8 files changed, 902 insertions(+), 5 deletions(-) create mode 100644 vortex-array/benches/expr/case_when_bench.rs create mode 100644 vortex-array/src/expr/exprs/case_when.rs diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index d71d8eadd81..fae0d10c6c2 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -129,6 +129,11 @@ name = "expr_large_struct_pack" path = "benches/expr/large_struct_pack.rs" harness = false +[[bench]] +name = "expr_case_when" +path = "benches/expr/case_when_bench.rs" +harness = false + [[bench]] name = "chunked_dict_builder" harness = false diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs new file mode 100644 index 00000000000..95c4c97a2b9 --- /dev/null +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::unwrap_used)] +#![allow(clippy::cast_possible_truncation)] + +use divan::Bencher; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::StructArray; +use vortex_array::expr::case_when; +use vortex_array::expr::get_item; +use vortex_array::expr::gt; +use vortex_array::expr::lit; +use 
vortex_array::expr::nested_case_when; +use vortex_array::expr::root; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_dtype::FieldNames; + +fn main() { + divan::main(); +} + +fn make_struct_array(size: usize) -> ArrayRef { + let data: Buffer = (0..size as i32).collect(); + let field = data.into_array(); + StructArray::try_new( + FieldNames::from(["value"]), + vec![field], + size, + Validity::NonNullable, + ) + .unwrap() + .into_array() +} + +/// Benchmark a simple binary CASE WHEN with varying array sizes. +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_simple(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // CASE WHEN value > 500 THEN 100 ELSE 0 END + let expr = case_when( + gt(get_item("value", root()), lit(500i32)), + lit(100i32), + lit(0i32), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| expr.evaluate(array).unwrap()); +} + +/// Benchmark nested CASE WHEN with multiple conditions. +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_nested_3_conditions(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // CASE WHEN value > 750 THEN 3 WHEN value > 500 THEN 2 WHEN value > 250 THEN 1 ELSE 0 END + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(750i32)), lit(3i32)), + (gt(get_item("value", root()), lit(500i32)), lit(2i32)), + (gt(get_item("value", root()), lit(250i32)), lit(1i32)), + ], + Some(lit(0i32)), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| expr.evaluate(array).unwrap()); +} + +/// Benchmark CASE WHEN where all conditions are true (short-circuit path). 
+#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_all_true(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // CASE WHEN value >= 0 THEN 100 ELSE 0 END (always true for our data) + let expr = case_when( + gt(get_item("value", root()), lit(-1i32)), + lit(100i32), + lit(0i32), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| expr.evaluate(array).unwrap()); +} + +/// Benchmark CASE WHEN where all conditions are false (short-circuit path). +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_all_false(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // CASE WHEN value > 1000000 THEN 100 ELSE 0 END (always false for our data) + let expr = case_when( + gt(get_item("value", root()), lit(1_000_000i32)), + lit(100i32), + lit(0i32), + ); + + bencher + .with_inputs(|| (&expr, &array)) + .bench_refs(|(expr, array)| expr.evaluate(array).unwrap()); +} diff --git a/vortex-array/src/expr/exprs/case_when.rs b/vortex-array/src/expr/exprs/case_when.rs new file mode 100644 index 00000000000..319210137a6 --- /dev/null +++ b/vortex-array/src/expr/exprs/case_when.rs @@ -0,0 +1,667 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Binary CASE WHEN expression for conditional value selection. +//! +//! This expression is a simple wrapper around the `zip` compute function: +//! `CASE WHEN condition THEN value ELSE else_value END` +//! +//! For n-ary CASE WHEN expressions (multiple WHEN clauses), use the +//! [`nested_case_when`] convenience function which converts to nested binary expressions: +//! `CASE WHEN a THEN x WHEN b THEN y ELSE z END` becomes +//! 
`CASE WHEN a THEN x ELSE (CASE WHEN b THEN y ELSE z END) END` + +use std::fmt; +use std::fmt::Formatter; +use std::hash::Hash; + +use prost::Message; +use vortex_dtype::DType; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_panic; +use vortex_proto::expr as pb; +use vortex_scalar::Scalar; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::BoolArray; +use crate::arrays::ConstantArray; +use crate::compute::zip; +use crate::expr::Arity; +use crate::expr::ChildName; +use crate::expr::ExecutionArgs; +use crate::expr::ExecutionResult; +use crate::expr::ExprId; +use crate::expr::VTable; +use crate::expr::VTableExt; +use crate::expr::expression::Expression; + +/// Options for the binary CaseWhen expression. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct CaseWhenOptions { + /// Whether an ELSE clause is present. + /// If false, unmatched rows return NULL. + pub has_else: bool, +} + +impl fmt::Display for CaseWhenOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "case_when(else={})", self.has_else) + } +} + +/// A binary CASE WHEN expression. +/// +/// This is a simple conditional select: `CASE WHEN cond THEN value ELSE else_value END` +/// which is equivalent to `zip(value, else_value, cond)`. +/// +/// Children are always in order: [condition, then_value, else_value?] 
+pub struct CaseWhen; + +impl VTable for CaseWhen { + type Options = CaseWhenOptions; + + fn id(&self) -> ExprId { + ExprId::from("vortex.case_when") + } + + fn serialize(&self, options: &Self::Options) -> VortexResult>> { + Ok(Some( + pb::CaseWhenOpts { + // For backwards compatibility, binary is num_when_then_pairs=1 + num_when_then_pairs: 1, + has_else: options.has_else, + } + .encode_to_vec(), + )) + } + + fn deserialize(&self, metadata: &[u8]) -> VortexResult { + let opts = pb::CaseWhenOpts::decode(metadata)?; + // We only support binary (1 when/then pair) now + if opts.num_when_then_pairs != 1 { + vortex_bail!( + "CaseWhen only supports binary form (1 when/then pair), got {}", + opts.num_when_then_pairs + ); + } + Ok(CaseWhenOptions { + has_else: opts.has_else, + }) + } + + fn arity(&self, options: &Self::Options) -> Arity { + // Binary: condition + then + optional else + let num_children = 2 + if options.has_else { 1 } else { 0 }; + Arity::Exact(num_children) + } + + fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("when"), + 1 => ChildName::from("then"), + 2 if options.has_else => ChildName::from("else"), + _ => unreachable!("Invalid child index {} for binary CaseWhen", child_idx), + } + } + + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> fmt::Result { + write!(f, "CASE WHEN {} THEN {}", expr.child(0), expr.child(1))?; + if options.has_else { + write!(f, " ELSE {}", expr.child(2))?; + } + write!(f, " END") + } + + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + // The return dtype is based on the THEN expression (index 1) + let then_dtype = &arg_dtypes[1]; + + // If there's no ELSE, the result is always nullable (unmatched rows are NULL) + if !options.has_else { + Ok(then_dtype.as_nullable()) + } else { + Ok(then_dtype.clone()) + } + } + + fn execute( + &self, + _options: &Self::Options, + 
args: ExecutionArgs, + ) -> VortexResult { + let row_count = args.row_count; + + // Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value] + let (condition, then_value, else_value) = match args.inputs.len() { + 2 => { + let [condition, then_value]: [ArrayRef; 2] = args + .inputs + .try_into() + .map_err(|_| vortex_error::vortex_err!("Expected 2 inputs"))?; + (condition, then_value, None) + } + 3 => { + let [condition, then_value, else_value]: [ArrayRef; 3] = args + .inputs + .try_into() + .map_err(|_| vortex_error::vortex_err!("Expected 3 inputs"))?; + (condition, then_value, Some(else_value)) + } + n => vortex_bail!("CaseWhen expects 2 or 3 inputs, got {}", n), + }; + + // Execute condition to get a BoolArray + let cond_bool = condition.execute::(args.ctx)?; + // SQL semantics: NULL condition is treated as FALSE (i.e., we take the ELSE branch) + let mask = cond_bool.to_mask_fill_null_false(); + + // Short-circuit: all true -> just return THEN value + if mask.all_true() { + return then_value.execute::(args.ctx); + } + + // Short-circuit: all false -> return ELSE value or NULL + if mask.all_false() { + return match else_value { + Some(else_value) => else_value.execute::(args.ctx), + None => { + // Create NULL constant of appropriate type + let then_dtype = then_value.dtype().as_nullable(); + Ok(ExecutionResult::constant( + Scalar::null(then_dtype), + row_count, + )) + } + }; + } + + // Get else value for zip (create NULL constant if no else clause) + let else_value = else_value.unwrap_or_else(|| { + let then_dtype = then_value.dtype().as_nullable(); + ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() + }); + + // Use zip to select: where mask is true, take then_value; else take else_value + let result = zip(then_value.as_ref(), else_value.as_ref(), &mask)?; + + result.execute::(args.ctx) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + // CaseWhen is null-sensitive because NULL 
conditions are treated as false + true + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +/// Creates a binary CASE WHEN expression with an ELSE clause. +/// +/// # Arguments +/// - `condition`: Boolean expression for the WHEN clause +/// - `then_value`: Value to return when condition is true +/// - `else_value`: Value to return when condition is false +/// +/// # Example +/// ```ignore +/// // CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END +/// case_when(gt(col("x"), lit(0)), lit("positive"), lit("non-positive")) +/// ``` +pub fn case_when( + condition: Expression, + then_value: Expression, + else_value: Expression, +) -> Expression { + let options = CaseWhenOptions { has_else: true }; + CaseWhen.new_expr(options, [condition, then_value, else_value]) +} + +/// Creates a binary CASE WHEN expression without an ELSE clause. +/// +/// Returns NULL when the condition is false. +/// +/// # Arguments +/// - `condition`: Boolean expression for the WHEN clause +/// - `then_value`: Value to return when condition is true +/// +/// # Example +/// ```ignore +/// // CASE WHEN x > 0 THEN 'positive' END +/// case_when_no_else(gt(col("x"), lit(0)), lit("positive")) +/// ``` +pub fn case_when_no_else(condition: Expression, then_value: Expression) -> Expression { + let options = CaseWhenOptions { has_else: false }; + CaseWhen.new_expr(options, [condition, then_value]) +} + +/// Creates a nested CASE WHEN expression from multiple WHEN/THEN pairs. 
+/// +/// This is a convenience function that converts n-ary CASE WHEN to nested binary expressions: +/// `CASE WHEN a THEN x WHEN b THEN y ELSE z END` becomes +/// `CASE WHEN a THEN x ELSE (CASE WHEN b THEN y ELSE z END) END` +/// +/// # Arguments +/// - `when_then_pairs`: Vec of (condition, value) pairs +/// - `else_value`: Optional else expression (if None, unmatched rows return NULL) +/// +/// # Example +/// ```ignore +/// // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END +/// nested_case_when( +/// vec![ +/// (gt(col("x"), lit(10)), lit("high")), +/// (gt(col("x"), lit(5)), lit("medium")), +/// ], +/// Some(lit("low")), +/// ) +/// ``` +pub fn nested_case_when( + when_then_pairs: Vec<(Expression, Expression)>, + else_value: Option, +) -> Expression { + assert!( + !when_then_pairs.is_empty(), + "nested_case_when requires at least one when/then pair" + ); + + // Build from right to left (innermost first) using rfold + when_then_pairs + .into_iter() + .rfold(else_value, |acc, (condition, then_value)| { + Some(match acc { + Some(else_expr) => case_when(condition, then_value, else_expr), + None => case_when_no_else(condition, then_value), + }) + }) + .unwrap_or_else(|| vortex_panic!("rfold on non-empty iterator always produces Some")) +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_dtype::DType; + use vortex_dtype::Nullability; + use vortex_dtype::PType; + use vortex_error::VortexExpect as _; + use vortex_scalar::Scalar; + + use super::*; + use crate::IntoArray; + use crate::ToCanonical; + use crate::arrays::BoolArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::StructArray; + use crate::expr::exprs::binary::eq; + use crate::expr::exprs::binary::gt; + use crate::expr::exprs::get_item::col; + use crate::expr::exprs::get_item::get_item; + use crate::expr::exprs::literal::lit; + use crate::expr::exprs::root::root; + use crate::expr::test_harness; + + // ==================== Serialization Tests 
==================== + + #[test] + fn test_serialization_roundtrip() { + let options = CaseWhenOptions { has_else: true }; + let serialized = CaseWhen.serialize(&options).unwrap().unwrap(); + let deserialized = CaseWhen.deserialize(&serialized).unwrap(); + assert_eq!(options, deserialized); + } + + #[test] + fn test_serialization_no_else() { + let options = CaseWhenOptions { has_else: false }; + let serialized = CaseWhen.serialize(&options).unwrap().unwrap(); + let deserialized = CaseWhen.deserialize(&serialized).unwrap(); + assert_eq!(options, deserialized); + } + + // ==================== Display Tests ==================== + + #[test] + fn test_display_with_else() { + let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32)); + let display = format!("{}", expr); + assert!(display.contains("CASE")); + assert!(display.contains("WHEN")); + assert!(display.contains("THEN")); + assert!(display.contains("ELSE")); + assert!(display.contains("END")); + } + + #[test] + fn test_display_no_else() { + let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32)); + let display = format!("{}", expr); + assert!(display.contains("CASE")); + assert!(display.contains("WHEN")); + assert!(display.contains("THEN")); + assert!(!display.contains("ELSE")); + assert!(display.contains("END")); + } + + #[test] + fn test_display_nested_nary() { + // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END + // Becomes nested: CASE WHEN x>10 THEN 'high' ELSE (CASE WHEN x>5 THEN 'medium' ELSE 'low' END) END + let expr = nested_case_when( + vec![ + (gt(col("x"), lit(10i32)), lit("high")), + (gt(col("x"), lit(5i32)), lit("medium")), + ], + Some(lit("low")), + ); + let display = format!("{}", expr); + // Should contain nested CASE statements + assert_eq!(display.matches("CASE").count(), 2); + assert_eq!(display.matches("WHEN").count(), 2); + assert_eq!(display.matches("THEN").count(), 2); + } + + // ==================== DType Tests ==================== + + 
#[test] + fn test_return_dtype_with_else() { + let expr = case_when(lit(true), lit(100i32), lit(0i32)); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let result_dtype = expr.return_dtype(&input_dtype).unwrap(); + assert_eq!( + result_dtype, + DType::Primitive(PType::I32, Nullability::NonNullable) + ); + } + + #[test] + fn test_return_dtype_without_else_is_nullable() { + let expr = case_when_no_else(lit(true), lit(100i32)); + let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let result_dtype = expr.return_dtype(&input_dtype).unwrap(); + assert_eq!( + result_dtype, + DType::Primitive(PType::I32, Nullability::Nullable) + ); + } + + #[test] + fn test_return_dtype_with_struct_input() { + let dtype = test_harness::struct_dtype(); + let expr = case_when( + gt(get_item("col1", root()), lit(10u16)), + lit(100i32), + lit(0i32), + ); + let result_dtype = expr.return_dtype(&dtype).unwrap(); + assert_eq!( + result_dtype, + DType::Primitive(PType::I32, Nullability::NonNullable) + ); + } + + // ==================== Arity Tests ==================== + + #[test] + fn test_arity_with_else() { + let options = CaseWhenOptions { has_else: true }; + assert_eq!(CaseWhen.arity(&options), Arity::Exact(3)); + } + + #[test] + fn test_arity_without_else() { + let options = CaseWhenOptions { has_else: false }; + assert_eq!(CaseWhen.arity(&options), Arity::Exact(2)); + } + + // ==================== Child Name Tests ==================== + + #[test] + fn test_child_names() { + let options = CaseWhenOptions { has_else: true }; + assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when"); + assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then"); + assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else"); + } + + // ==================== Expression Manipulation Tests ==================== + + #[test] + fn test_replace_children() { + let expr = case_when(lit(true), lit(1i32), lit(0i32)); + expr.with_children([lit(false), 
lit(2i32), lit(3i32)]) + .vortex_expect("operation should succeed in test"); + } + + // ==================== Evaluate Tests ==================== + + #[test] + fn test_evaluate_simple_condition() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(2i32)), + lit(100i32), + lit(0i32), + ); + + let result = expr.evaluate(&test_array).unwrap().to_primitive(); + assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + } + + #[test] + fn test_evaluate_nary_multiple_conditions() { + // Test n-ary via nested_case_when + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = nested_case_when( + vec![ + (eq(get_item("value", root()), lit(1i32)), lit(10i32)), + (eq(get_item("value", root()), lit(3i32)), lit(30i32)), + ], + Some(lit(0i32)), + ); + + let result = expr.evaluate(&test_array).unwrap().to_primitive(); + assert_eq!(result.as_slice::(), &[10, 0, 30, 0, 0]); + } + + #[test] + fn test_evaluate_nary_first_match_wins() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + // Both conditions match for values > 3, but first one wins + let expr = nested_case_when( + vec![ + (gt(get_item("value", root()), lit(2i32)), lit(100i32)), + (gt(get_item("value", root()), lit(3i32)), lit(200i32)), + ], + Some(lit(0i32)), + ); + + let result = expr.evaluate(&test_array).unwrap().to_primitive(); + assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); + } + + #[test] + fn test_evaluate_no_else_returns_null() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32)); + + let result = expr.evaluate(&test_array).unwrap(); + 
assert!(result.dtype().is_nullable()); + + assert_eq!( + result.scalar_at(0).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(1).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(2).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(3).unwrap(), + Scalar::from(100i32).cast(result.dtype()).unwrap() + ); + assert_eq!( + result.scalar_at(4).unwrap(), + Scalar::from(100i32).cast(result.dtype()).unwrap() + ); + } + + #[test] + fn test_evaluate_all_conditions_false() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(100i32)), + lit(1i32), + lit(0i32), + ); + + let result = expr.evaluate(&test_array).unwrap().to_primitive(); + assert_eq!(result.as_slice::(), &[0, 0, 0, 0, 0]); + } + + #[test] + fn test_evaluate_all_conditions_true() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(0i32)), + lit(100i32), + lit(0i32), + ); + + let result = expr.evaluate(&test_array).unwrap().to_primitive(); + assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); + } + + #[test] + fn test_evaluate_with_literal_condition() { + let test_array = buffer![1i32, 2, 3].into_array(); + let expr = case_when(lit(true), lit(100i32), lit(0i32)); + let result = expr.evaluate(&test_array).unwrap(); + + if let Some(constant) = result.as_constant() { + assert_eq!(constant, Scalar::from(100i32)); + } else { + let prim = result.to_primitive(); + assert_eq!(prim.as_slice::(), &[100, 100, 100]); + } + } + + #[test] + fn test_evaluate_with_bool_column_result() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when( 
+ gt(get_item("value", root()), lit(2i32)), + lit(true), + lit(false), + ); + + let result = expr.evaluate(&test_array).unwrap().to_bool(); + assert_eq!( + result.bit_buffer().iter().collect::>(), + vec![false, false, true, true, true] + ); + } + + #[test] + fn test_evaluate_with_nullable_condition() { + let test_array = StructArray::from_fields(&[( + "cond", + BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(), + )]) + .unwrap() + .into_array(); + + let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); + + let result = expr.evaluate(&test_array).unwrap().to_primitive(); + assert_eq!(result.as_slice::(), &[100, 0, 0, 0, 100]); + } + + #[test] + fn test_evaluate_with_nullable_result_values() { + let test_array = StructArray::from_fields(&[ + ("value", buffer![1i32, 2, 3, 4, 5].into_array()), + ( + "result", + PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)]) + .into_array(), + ), + ]) + .unwrap() + .into_array(); + + let expr = case_when( + gt(get_item("value", root()), lit(2i32)), + get_item("result", root()), + lit(0i32), + ); + + let result = expr.evaluate(&test_array).unwrap(); + let prim = result.to_primitive(); + assert_eq!(prim.as_slice::(), &[0, 0, 30, 40, 50]); + } + + #[test] + fn test_evaluate_with_all_null_condition() { + let test_array = StructArray::from_fields(&[( + "cond", + BoolArray::from_iter([None, None, None]).into_array(), + )]) + .unwrap() + .into_array(); + + let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); + + let result = expr.evaluate(&test_array).unwrap().to_primitive(); + assert_eq!(result.as_slice::(), &[0, 0, 0]); + } + + // Note: Direct execute tests are covered through evaluate tests above, + // since evaluate() calls execute() internally. + + // Note: The binary CASE WHEN implementation using `zip` does NOT provide + // short-circuit/lazy evaluation. 
All child expressions are evaluated first, + // then zip selects the result based on the condition. This means expressions + // like divide-by-zero will still fail even if protected by a condition. + // This is intentional - lazy evaluation would require a more complex + // implementation that filters the input before evaluating children. +} diff --git a/vortex-array/src/expr/exprs/mod.rs b/vortex-array/src/expr/exprs/mod.rs index 145d225bcae..fa3e8766f31 100644 --- a/vortex-array/src/expr/exprs/mod.rs +++ b/vortex-array/src/expr/exprs/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod between; pub(crate) mod binary; +pub(crate) mod case_when; pub(crate) mod cast; pub(crate) mod dynamic; pub(crate) mod get_item; @@ -19,6 +20,7 @@ pub(crate) mod root; pub(crate) mod select; pub use between::*; pub use binary::*; +pub use case_when::*; pub use cast::*; pub use dynamic::*; pub use get_item::*; diff --git a/vortex-array/src/expr/session.rs b/vortex-array/src/expr/session.rs index 13106c354c4..f46546481d7 100644 --- a/vortex-array/src/expr/session.rs +++ b/vortex-array/src/expr/session.rs @@ -8,6 +8,7 @@ use vortex_session::registry::Registry; use crate::expr::ExprVTable; use crate::expr::exprs::between::Between; use crate::expr::exprs::binary::Binary; +use crate::expr::exprs::case_when::CaseWhen; use crate::expr::exprs::cast::Cast; use crate::expr::exprs::get_item::GetItem; use crate::expr::exprs::is_null::IsNull; @@ -55,6 +56,7 @@ impl Default for ExprSession { for expr in [ ExprVTable::new_static(&Between), ExprVTable::new_static(&Binary), + ExprVTable::new_static(&CaseWhen), ExprVTable::new_static(&Cast), ExprVTable::new_static(&GetItem), ExprVTable::new_static(&IsNull), diff --git a/vortex-datafusion/src/convert/exprs.rs b/vortex-datafusion/src/convert/exprs.rs index 4ac557f00fd..cfb8503810a 100644 --- a/vortex-datafusion/src/convert/exprs.rs +++ b/vortex-datafusion/src/convert/exprs.rs @@ -34,6 +34,7 @@ use vortex::expr::get_item; use vortex::expr::is_null; use 
vortex::expr::list_contains; use vortex::expr::lit; +use vortex::expr::nested_case_when; use vortex::expr::not; use vortex::expr::pack; use vortex::expr::root; @@ -114,6 +115,45 @@ impl DefaultExpressionConvertor { scalar_fn.name() )) } + + /// Attempts to convert a DataFusion CaseExpr to a Vortex expression. + fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult { + // DataFusion CaseExpr has: + // - expr(): Optional base expression (for "CASE expr WHEN ..." form) + // - when_then_expr(): Vec of (when, then) pairs + // - else_expr(): Optional else expression + + // We don't support the "CASE expr WHEN value1 THEN result1" form yet + if case_expr.expr().is_some() { + return Err(exec_datafusion_err!( + "CASE expr WHEN form is not yet supported, only searched CASE is supported" + )); + } + + let when_then_pairs = case_expr.when_then_expr(); + if when_then_pairs.is_empty() { + return Err(exec_datafusion_err!( + "CASE expression must have at least one WHEN clause" + )); + } + + // Convert all when/then pairs to (condition, value) tuples + let mut pairs = Vec::with_capacity(when_then_pairs.len()); + for (when_expr, then_expr) in when_then_pairs { + let condition = self.convert(when_expr.as_ref())?; + let value = self.convert(then_expr.as_ref())?; + pairs.push((condition, value)); + } + + // Convert optional else expression + let else_value = case_expr + .else_expr() + .map(|e| self.convert(e.as_ref())) + .transpose()?; + + // Use nested_case_when which converts to nested binary case_when expressions + Ok(nested_case_when(pairs, else_value)) + } } impl ExpressionConvertor for DefaultExpressionConvertor { @@ -205,6 +245,10 @@ impl ExpressionConvertor for DefaultExpressionConvertor { return self.try_convert_scalar_function(scalar_fn); } + if let Some(case_expr) = df.as_any().downcast_ref::() { + return self.try_convert_case_expr(case_expr); + } + Err(exec_datafusion_err!( "Couldn't convert DataFusion physical {df} expression to a vortex expression" )) 
@@ -350,10 +394,12 @@ fn can_be_pushed_down_impl(df_expr: &Arc, schema: &Schema) -> && can_be_pushed_down_impl(like.pattern(), schema) } else if let Some(lit) = expr.downcast_ref::() { supported_data_types(&lit.value().data_type()) - } else if expr.downcast_ref::().is_some() - || expr.downcast_ref::().is_some() - { - true + } else if let Some(cast_expr) = expr.downcast_ref::() { + // CastExpr child must be an expression type that convert() can handle + is_convertible_expr(cast_expr.expr()) + } else if let Some(cast_col_expr) = expr.downcast_ref::() { + // CastColumnExpr child must be an expression type that convert() can handle + is_convertible_expr(cast_col_expr.expr()) } else if let Some(is_null) = expr.downcast_ref::() { can_be_pushed_down_impl(is_null.arg(), schema) } else if let Some(is_not_null) = expr.downcast_ref::() { @@ -366,12 +412,39 @@ fn can_be_pushed_down_impl(df_expr: &Arc, schema: &Schema) -> .all(|e| can_be_pushed_down_impl(e, schema)) } else if let Some(scalar_fn) = expr.downcast_ref::() { can_scalar_fn_be_pushed_down(scalar_fn) + } else if let Some(case_expr) = expr.downcast_ref::() { + can_case_be_pushed_down(case_expr, schema) } else { tracing::debug!(%df_expr, "DataFusion expression can't be pushed down"); false } } +/// Checks if an expression type is one that convert() can handle. +/// This is less restrictive than can_be_pushed_down since it only checks +/// expression types, not data type support. 
+fn is_convertible_expr(df_expr: &Arc) -> bool { + let expr = df_expr.as_any(); + + // Expression types that convert() handles + expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr + .downcast_ref::() + .is_some_and(|e| is_convertible_expr(e.expr())) + || expr + .downcast_ref::() + .is_some_and(|e| is_convertible_expr(e.expr())) + || expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() + || expr + .downcast_ref::() + .is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::(sf).is_some()) +} + fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool { let is_op_supported = try_operator_from_df(binary.op()).is_ok(); is_op_supported @@ -379,6 +452,32 @@ fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> b && can_be_pushed_down_impl(binary.right(), schema) } +fn can_case_be_pushed_down(case_expr: &df_expr::CaseExpr, schema: &Schema) -> bool { + // We only support the "searched CASE" form (CASE WHEN cond THEN result ...) + // not the "simple CASE" form (CASE expr WHEN value THEN result ...) + if case_expr.expr().is_some() { + return false; + } + + // Check all when/then pairs + for (when_expr, then_expr) in case_expr.when_then_expr() { + if !can_be_pushed_down_impl(when_expr, schema) + || !can_be_pushed_down_impl(then_expr, schema) + { + return false; + } + } + + // Check the optional else clause + if let Some(else_expr) = case_expr.else_expr() + && !can_be_pushed_down_impl(else_expr, schema) + { + return false; + } + + true +} + fn supported_data_types(dt: &DataType) -> bool { use DataType::*; @@ -412,7 +511,8 @@ fn supported_data_types(dt: &DataType) -> bool { is_supported } -/// Checks if a GetField scalar function can be pushed down. +/// Checks if a scalar function can be pushed down. +/// Currently only GetFieldFunc is supported. 
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool { ScalarFunctionExpr::try_downcast_func::(scalar_fn).is_some() } diff --git a/vortex-proto/proto/expr.proto b/vortex-proto/proto/expr.proto index 4540bce0a63..7d713a3afb3 100644 --- a/vortex-proto/proto/expr.proto +++ b/vortex-proto/proto/expr.proto @@ -80,3 +80,9 @@ message SelectOpts { FieldNames exclude = 2; } } + +// Options for `vortex.case_when` +message CaseWhenOpts { + uint32 num_when_then_pairs = 1; + bool has_else = 2; +} diff --git a/vortex-proto/src/generated/vortex.expr.rs b/vortex-proto/src/generated/vortex.expr.rs index f3b6d2cf624..180e693f269 100644 --- a/vortex-proto/src/generated/vortex.expr.rs +++ b/vortex-proto/src/generated/vortex.expr.rs @@ -145,3 +145,11 @@ pub mod select_opts { Exclude(super::FieldNames), } } +/// Options for `vortex.case_when` +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct CaseWhenOpts { + #[prost(uint32, tag = "1")] + pub num_when_then_pairs: u32, + #[prost(bool, tag = "2")] + pub has_else: bool, +} From 5d0079e1074cc96f861a0d4709fe0d1a7a6dbe52 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Thu, 29 Jan 2026 16:56:28 -0800 Subject: [PATCH 2/4] CASE WHEN expression with execution context and DataFusion equivalence tests (#17) --- vortex-array/benches/expr/case_when_bench.rs | 45 +++++++++- vortex-array/src/expr/exprs/case_when.rs | 47 +++++++--- vortex-datafusion/src/convert/exprs.rs | 92 ++++++++++++++++++++ 3 files changed, 166 insertions(+), 18 deletions(-) diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index 95c4c97a2b9..183b56c178c 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -4,9 +4,13 @@ #![allow(clippy::unwrap_used)] #![allow(clippy::cast_possible_truncation)] +use std::sync::LazyLock; + use divan::Bencher; use vortex_array::ArrayRef; +use 
vortex_array::Canonical; use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; use vortex_array::arrays::StructArray; use vortex_array::expr::case_when; use vortex_array::expr::get_item; @@ -14,9 +18,14 @@ use vortex_array::expr::gt; use vortex_array::expr::lit; use vortex_array::expr::nested_case_when; use vortex_array::expr::root; +use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_dtype::FieldNames; +use vortex_session::VortexSession; + +static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); fn main() { divan::main(); @@ -49,7 +58,14 @@ fn case_when_simple(bencher: Bencher, size: usize) { bencher .with_inputs(|| (&expr, &array)) - .bench_refs(|(expr, array)| expr.evaluate(array).unwrap()); + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); } /// Benchmark nested CASE WHEN with multiple conditions. @@ -69,7 +85,14 @@ fn case_when_nested_3_conditions(bencher: Bencher, size: usize) { bencher .with_inputs(|| (&expr, &array)) - .bench_refs(|(expr, array)| expr.evaluate(array).unwrap()); + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); } /// Benchmark CASE WHEN where all conditions are true (short-circuit path). @@ -86,7 +109,14 @@ fn case_when_all_true(bencher: Bencher, size: usize) { bencher .with_inputs(|| (&expr, &array)) - .bench_refs(|(expr, array)| expr.evaluate(array).unwrap()); + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); } /// Benchmark CASE WHEN where all conditions are false (short-circuit path). 
@@ -103,5 +133,12 @@ fn case_when_all_false(bencher: Bencher, size: usize) { bencher .with_inputs(|| (&expr, &array)) - .bench_refs(|(expr, array)| expr.evaluate(array).unwrap()); + .bench_refs(|(expr, array)| { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); } diff --git a/vortex-array/src/expr/exprs/case_when.rs b/vortex-array/src/expr/exprs/case_when.rs index 319210137a6..faace3c9062 100644 --- a/vortex-array/src/expr/exprs/case_when.rs +++ b/vortex-array/src/expr/exprs/case_when.rs @@ -287,16 +287,21 @@ pub fn nested_case_when( #[cfg(test)] mod tests { + use std::sync::LazyLock; + use vortex_buffer::buffer; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; use vortex_error::VortexExpect as _; use vortex_scalar::Scalar; + use vortex_session::VortexSession; use super::*; + use crate::Canonical; use crate::IntoArray; use crate::ToCanonical; + use crate::VortexSessionExecute as _; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; @@ -307,6 +312,21 @@ mod tests { use crate::expr::exprs::literal::lit; use crate::expr::exprs::root::root; use crate::expr::test_harness; + use crate::session::ArraySession; + + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + /// Helper to evaluate an expression using the apply+execute pattern + fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef { + let mut ctx = SESSION.create_execution_ctx(); + array + .apply(expr) + .unwrap() + .execute::(&mut ctx) + .unwrap() + .into_array() + } // ==================== Serialization Tests ==================== @@ -455,7 +475,7 @@ mod tests { lit(0i32), ); - let result = expr.evaluate(&test_array).unwrap().to_primitive(); + let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); } @@ -475,7 +495,7 @@ mod tests { Some(lit(0i32)), ); - let 
result = expr.evaluate(&test_array).unwrap().to_primitive(); + let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[10, 0, 30, 0, 0]); } @@ -495,7 +515,7 @@ mod tests { Some(lit(0i32)), ); - let result = expr.evaluate(&test_array).unwrap().to_primitive(); + let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); } @@ -508,7 +528,7 @@ mod tests { let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32)); - let result = expr.evaluate(&test_array).unwrap(); + let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); assert_eq!( @@ -546,7 +566,7 @@ mod tests { lit(0i32), ); - let result = expr.evaluate(&test_array).unwrap().to_primitive(); + let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[0, 0, 0, 0, 0]); } @@ -563,7 +583,7 @@ mod tests { lit(0i32), ); - let result = expr.evaluate(&test_array).unwrap().to_primitive(); + let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); } @@ -571,7 +591,7 @@ mod tests { fn test_evaluate_with_literal_condition() { let test_array = buffer![1i32, 2, 3].into_array(); let expr = case_when(lit(true), lit(100i32), lit(0i32)); - let result = expr.evaluate(&test_array).unwrap(); + let result = evaluate_expr(&expr, &test_array); if let Some(constant) = result.as_constant() { assert_eq!(constant, Scalar::from(100i32)); @@ -594,9 +614,9 @@ mod tests { lit(false), ); - let result = expr.evaluate(&test_array).unwrap().to_bool(); + let result = evaluate_expr(&expr, &test_array).to_bool(); assert_eq!( - result.bit_buffer().iter().collect::>(), + result.to_bit_buffer().iter().collect::>(), vec![false, false, true, true, true] ); } @@ -612,7 +632,7 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = 
expr.evaluate(&test_array).unwrap().to_primitive(); + let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[100, 0, 0, 0, 100]); } @@ -635,7 +655,7 @@ mod tests { lit(0i32), ); - let result = expr.evaluate(&test_array).unwrap(); + let result = evaluate_expr(&expr, &test_array); let prim = result.to_primitive(); assert_eq!(prim.as_slice::(), &[0, 0, 30, 40, 50]); } @@ -651,12 +671,11 @@ mod tests { let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); - let result = expr.evaluate(&test_array).unwrap().to_primitive(); + let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[0, 0, 0]); } - // Note: Direct execute tests are covered through evaluate tests above, - // since evaluate() calls execute() internally. + // Note: Direct execute tests are covered through apply+execute tests above. // Note: The binary CASE WHEN implementation using `zip` does NOT provide // short-circuit/lazy evaluation. All child expressions are evaluated first, diff --git a/vortex-datafusion/src/convert/exprs.rs b/vortex-datafusion/src/convert/exprs.rs index cfb8503810a..2c83f476260 100644 --- a/vortex-datafusion/src/convert/exprs.rs +++ b/vortex-datafusion/src/convert/exprs.rs @@ -847,4 +847,96 @@ mod tests { assert!(!can_be_pushed_down_impl(&like_expr, &test_schema)); } + + /// Test that applying a CASE expression to an Arrow RecordBatch using DataFusion + /// matches the result of applying the converted Vortex expression. 
+ #[test] + fn test_case_when_datafusion_vortex_equivalence() { + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::array::RecordBatch; + use datafusion_physical_expr::expressions::CaseExpr; + use vortex::VortexSessionDefault; + use vortex::array::ArrayRef; + use vortex::array::Canonical; + use vortex::array::VortexSessionExecute as _; + use vortex::array::arrow::FromArrowArray; + use vortex::session::VortexSession; + + // Create test data + let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20])); + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![values]).unwrap(); + + // Build a DataFusion CASE expression: + // CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END + let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc; + let lit_10 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc; + let lit_5 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc; + let lit_100 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc; + let lit_50 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc; + let lit_0 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc; + + // WHEN value > 10 THEN 100 + let when1 = Arc::new(df_expr::BinaryExpr::new( + col_value.clone(), + DFOperator::Gt, + lit_10, + )) as Arc; + // WHEN value > 5 THEN 50 + let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5)) + as Arc; + + let case_expr = + CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap(); + + // Apply DataFusion expression + let df_result = case_expr.evaluate(&batch).unwrap(); + let df_array = df_result.into_array(batch.num_rows()).unwrap(); + + // Convert to Vortex expression + let expr_convertor = DefaultExpressionConvertor::default(); + let vortex_expr = 
expr_convertor.try_convert_case_expr(&case_expr).unwrap(); + + // Convert batch to Vortex array + let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap(); + + // Apply Vortex expression + let session = VortexSession::default(); + let mut ctx = session.create_execution_ctx(); + let vortex_result = vortex_array + .apply(&vortex_expr) + .unwrap() + .execute::(&mut ctx) + .unwrap(); + + // Convert back to Arrow for comparison + let vortex_as_arrow = vortex_result.into_primitive().as_slice::().to_vec(); + + // Convert DataFusion result to Vec for comparison + let df_as_arrow: Vec = df_array + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec(); + + // Compare results + // Expected: [0, 0, 50, 100, 100] for values [1, 5, 10, 15, 20] + // value=1: not > 10, not > 5 -> ELSE 0 + // value=5: not > 10, not > 5 -> ELSE 0 + // value=10: not > 10, > 5 -> 50 + // value=15: > 10 -> 100 + // value=20: > 10 -> 100 + assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]); + assert_eq!(vortex_as_arrow, df_as_arrow); + } } From 7fa584ca4ec8f6bbfbf33fc0f3f5de9234d64121 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Thu, 5 Feb 2026 12:52:08 -0800 Subject: [PATCH 3/4] feat: N-ary CASE WHEN expressions (#20) * N-ary * fix: update threshold and value literals in n-ary CASE WHEN benchmark * refactor: streamline context creation in benchmarks and improve test readability --- vortex-array/benches/expr/case_when_bench.rs | 138 ++- vortex-array/src/expr/exprs/case_when.rs | 876 ++++++++++++++----- vortex-datafusion/src/convert/exprs.rs | 217 ++++- 3 files changed, 974 insertions(+), 257 deletions(-) diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index 183b56c178c..3d795c841bd 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -16,7 +16,6 @@ use vortex_array::expr::case_when; use vortex_array::expr::get_item; use 
vortex_array::expr::gt; use vortex_array::expr::lit; -use vortex_array::expr::nested_case_when; use vortex_array::expr::root; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; @@ -45,100 +44,167 @@ fn make_struct_array(size: usize) -> ArrayRef { } /// Benchmark a simple binary CASE WHEN with varying array sizes. -#[divan::bench(args = [1000, 10000, 100000])] +#[divan::bench(args = [10000, 100000, 1000000])] fn case_when_simple(bencher: Bencher, size: usize) { let array = make_struct_array(size); // CASE WHEN value > 500 THEN 100 ELSE 0 END - let expr = case_when( + let expr = case_when([ gt(get_item("value", root()), lit(500i32)), lit(100i32), lit(0i32), - ); + ]); bencher - .with_inputs(|| (&expr, &array)) - .bench_refs(|(expr, array)| { - let mut ctx = SESSION.create_execution_ctx(); + .with_inputs(|| (&expr, &array, SESSION.create_execution_ctx())) + .bench_refs(|(expr, array, ctx)| { array .apply(expr) .unwrap() - .execute::(&mut ctx) + .execute::(ctx) .unwrap() }); } -/// Benchmark nested CASE WHEN with multiple conditions. +/// Benchmark n-ary CASE WHEN with multiple conditions. 
#[divan::bench(args = [1000, 10000, 100000])] -fn case_when_nested_3_conditions(bencher: Bencher, size: usize) { +fn case_when_nary_3_conditions(bencher: Bencher, size: usize) { let array = make_struct_array(size); // CASE WHEN value > 750 THEN 3 WHEN value > 500 THEN 2 WHEN value > 250 THEN 1 ELSE 0 END - let expr = nested_case_when( - vec![ - (gt(get_item("value", root()), lit(750i32)), lit(3i32)), - (gt(get_item("value", root()), lit(500i32)), lit(2i32)), - (gt(get_item("value", root()), lit(250i32)), lit(1i32)), - ], - Some(lit(0i32)), - ); + let expr = case_when([ + gt(get_item("value", root()), lit(750i32)), + lit(3i32), + gt(get_item("value", root()), lit(500i32)), + lit(2i32), + gt(get_item("value", root()), lit(250i32)), + lit(1i32), + lit(0i32), + ]); bencher - .with_inputs(|| (&expr, &array)) - .bench_refs(|(expr, array)| { - let mut ctx = SESSION.create_execution_ctx(); + .with_inputs(|| (&expr, &array, SESSION.create_execution_ctx())) + .bench_refs(|(expr, array, ctx)| { array .apply(expr) .unwrap() - .execute::(&mut ctx) + .execute::(ctx) .unwrap() }); } /// Benchmark CASE WHEN where all conditions are true (short-circuit path). -#[divan::bench(args = [1000, 10000, 100000])] +#[divan::bench(args = [10000, 100000, 1000000])] fn case_when_all_true(bencher: Bencher, size: usize) { let array = make_struct_array(size); // CASE WHEN value >= 0 THEN 100 ELSE 0 END (always true for our data) - let expr = case_when( + let expr = case_when([ gt(get_item("value", root()), lit(-1i32)), lit(100i32), lit(0i32), - ); + ]); bencher - .with_inputs(|| (&expr, &array)) - .bench_refs(|(expr, array)| { - let mut ctx = SESSION.create_execution_ctx(); + .with_inputs(|| (&expr, &array, SESSION.create_execution_ctx())) + .bench_refs(|(expr, array, ctx)| { array .apply(expr) .unwrap() - .execute::(&mut ctx) + .execute::(ctx) .unwrap() }); } /// Benchmark CASE WHEN where all conditions are false (short-circuit path). 
-#[divan::bench(args = [1000, 10000, 100000])] +#[divan::bench(args = [10000, 100000, 1000000])] fn case_when_all_false(bencher: Bencher, size: usize) { let array = make_struct_array(size); // CASE WHEN value > 1000000 THEN 100 ELSE 0 END (always false for our data) - let expr = case_when( + let expr = case_when([ gt(get_item("value", root()), lit(1_000_000i32)), lit(100i32), lit(0i32), - ); + ]); + + bencher + .with_inputs(|| (&expr, &array, SESSION.create_execution_ctx())) + .bench_refs(|(expr, array, ctx)| { + array + .apply(expr) + .unwrap() + .execute::(ctx) + .unwrap() + }); +} + +/// Benchmark n-ary CASE WHEN with 10 conditions. +#[divan::bench(args = [1000, 10000, 100000])] +fn case_when_nary_10_conditions(bencher: Bencher, size: usize) { + let array = make_struct_array(size); + + // Build 10 WHEN/THEN pairs with decreasing thresholds + let expr = case_when([ + gt(get_item("value", root()), lit(900i32)), + lit(10i32), + gt(get_item("value", root()), lit(800i32)), + lit(9i32), + gt(get_item("value", root()), lit(700i32)), + lit(8i32), + gt(get_item("value", root()), lit(600i32)), + lit(7i32), + gt(get_item("value", root()), lit(500i32)), + lit(6i32), + gt(get_item("value", root()), lit(400i32)), + lit(5i32), + gt(get_item("value", root()), lit(300i32)), + lit(4i32), + gt(get_item("value", root()), lit(200i32)), + lit(3i32), + gt(get_item("value", root()), lit(100i32)), + lit(2i32), + gt(get_item("value", root()), lit(0i32)), + lit(1i32), + lit(0i32), // else + ]); + + bencher + .with_inputs(|| (&expr, &array, SESSION.create_execution_ctx())) + .bench_refs(|(expr, array, ctx)| { + array + .apply(expr) + .unwrap() + .execute::(ctx) + .unwrap() + }); +} + +/// Benchmark n-ary CASE WHEN with 100 conditions. 
+#[divan::bench(args = [10000, 100000, 1000000])] +fn case_when_nary_100_conditions(bencher: Bencher, size: usize) { + use vortex_array::expr::Expression; + + let array = make_struct_array(size); + + // Build 100 WHEN/THEN pairs programmatically + let mut children: Vec = Vec::with_capacity(201); + for i in (1..=100).rev() { + let threshold = i * 10; // thresholds: 1000, 990, 980, ..., 10 + children.push(gt(get_item("value", root()), lit(threshold))); + children.push(lit(i)); + } + children.push(lit(0i32)); // else + + let expr = case_when(children); bencher - .with_inputs(|| (&expr, &array)) - .bench_refs(|(expr, array)| { - let mut ctx = SESSION.create_execution_ctx(); + .with_inputs(|| (&expr, &array, SESSION.create_execution_ctx())) + .bench_refs(|(expr, array, ctx)| { array .apply(expr) .unwrap() - .execute::(&mut ctx) + .execute::(ctx) .unwrap() }); } diff --git a/vortex-array/src/expr/exprs/case_when.rs b/vortex-array/src/expr/exprs/case_when.rs index faace3c9062..7c1b8d90098 100644 --- a/vortex-array/src/expr/exprs/case_when.rs +++ b/vortex-array/src/expr/exprs/case_when.rs @@ -1,27 +1,33 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Binary CASE WHEN expression for conditional value selection. +//! N-ary CASE WHEN expression for conditional value selection. //! -//! This expression is a simple wrapper around the `zip` compute function: -//! `CASE WHEN condition THEN value ELSE else_value END` +//! This expression evaluates a series of WHEN conditions and returns the corresponding +//! THEN value for the first condition that evaluates to true. If no conditions match +//! and an ELSE clause is provided, the ELSE value is returned; otherwise, NULL is returned. //! -//! For n-ary CASE WHEN expressions (multiple WHEN clauses), use the -//! [`nested_case_when`] convenience function which converts to nested binary expressions: -//! `CASE WHEN a THEN x WHEN b THEN y ELSE z END` becomes -//! 
`CASE WHEN a THEN x ELSE (CASE WHEN b THEN y ELSE z END) END` +//! # Structure +//! +//! The expression has children in the following order: +//! - pairs of (condition, value) for each WHEN/THEN clause +//! - optionally, a final ELSE value +//! +//! For example, `CASE WHEN a THEN 1 WHEN b THEN 2 ELSE 3 END` has children: +//! `[a, 1, b, 2, 3]` use std::fmt; use std::fmt::Formatter; use std::hash::Hash; +use std::sync::Arc; use prost::Message; use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_error::vortex_panic; use vortex_proto::expr as pb; use vortex_scalar::Scalar; +use vortex_session::VortexSession; use crate::ArrayRef; use crate::IntoArray; @@ -37,26 +43,32 @@ use crate::expr::VTable; use crate::expr::VTableExt; use crate::expr::expression::Expression; -/// Options for the binary CaseWhen expression. +/// Options for the N-ary CaseWhen expression. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct CaseWhenOptions { - /// Whether an ELSE clause is present. + /// Number of WHEN/THEN pairs (each pair contributes 2 children) + pub num_when_then_pairs: u32, + /// Whether an ELSE clause is present (contributes 1 child at the end). /// If false, unmatched rows return NULL. pub has_else: bool, } impl fmt::Display for CaseWhenOptions { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "case_when(else={})", self.has_else) + write!( + f, + "case_when(pairs={}, else={})", + self.num_when_then_pairs, self.has_else + ) } } -/// A binary CASE WHEN expression. +/// An N-ary CASE WHEN expression. /// -/// This is a simple conditional select: `CASE WHEN cond THEN value ELSE else_value END` -/// which is equivalent to `zip(value, else_value, cond)`. +/// Evaluates conditions in order and returns the value corresponding to the +/// first matching condition. /// -/// Children are always in order: [condition, then_value, else_value?] +/// Children are in order: [when_0, then_0, when_1, then_1, ..., else?] 
pub struct CaseWhen; impl VTable for CaseWhen { @@ -69,40 +81,49 @@ impl VTable for CaseWhen { fn serialize(&self, options: &Self::Options) -> VortexResult>> { Ok(Some( pb::CaseWhenOpts { - // For backwards compatibility, binary is num_when_then_pairs=1 - num_when_then_pairs: 1, + num_when_then_pairs: options.num_when_then_pairs, has_else: options.has_else, } .encode_to_vec(), )) } - fn deserialize(&self, metadata: &[u8]) -> VortexResult { + fn deserialize( + &self, + metadata: &[u8], + _session: &VortexSession, + ) -> VortexResult { let opts = pb::CaseWhenOpts::decode(metadata)?; - // We only support binary (1 when/then pair) now - if opts.num_when_then_pairs != 1 { - vortex_bail!( - "CaseWhen only supports binary form (1 when/then pair), got {}", - opts.num_when_then_pairs - ); - } Ok(CaseWhenOptions { + num_when_then_pairs: opts.num_when_then_pairs, has_else: opts.has_else, }) } fn arity(&self, options: &Self::Options) -> Arity { - // Binary: condition + then + optional else - let num_children = 2 + if options.has_else { 1 } else { 0 }; + let num_children = + options.num_when_then_pairs as usize * 2 + if options.has_else { 1 } else { 0 }; Arity::Exact(num_children) } fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName { - match child_idx { - 0 => ChildName::from("when"), - 1 => ChildName::from("then"), - 2 if options.has_else => ChildName::from("else"), - _ => unreachable!("Invalid child index {} for binary CaseWhen", child_idx), + let pair_count = options.num_when_then_pairs as usize; + let num_when_then_children = pair_count * 2; + + if child_idx < num_when_then_children { + let pair_idx = child_idx / 2; + if child_idx % 2 == 0 { + ChildName::from(Arc::from(format!("when_{}", pair_idx))) + } else { + ChildName::from(Arc::from(format!("then_{}", pair_idx))) + } + } else if options.has_else && child_idx == num_when_then_children { + ChildName::from("else") + } else { + unreachable!( + "Invalid child index {} for CaseWhen expression with 
{} pairs", + child_idx, pair_count + ) } } @@ -112,15 +133,27 @@ impl VTable for CaseWhen { expr: &Expression, f: &mut Formatter<'_>, ) -> fmt::Result { - write!(f, "CASE WHEN {} THEN {}", expr.child(0), expr.child(1))?; + write!(f, "CASE")?; + for i in 0..options.num_when_then_pairs as usize { + write!(f, " WHEN ")?; + expr.child(i * 2).fmt_sql(f)?; + write!(f, " THEN ")?; + expr.child(i * 2 + 1).fmt_sql(f)?; + } if options.has_else { - write!(f, " ELSE {}", expr.child(2))?; + let else_idx = options.num_when_then_pairs as usize * 2; + write!(f, " ELSE ")?; + expr.child(else_idx).fmt_sql(f)?; } write!(f, " END") } fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { - // The return dtype is based on the THEN expression (index 1) + if options.num_when_then_pairs == 0 { + vortex_bail!("CaseWhen must have at least one WHEN/THEN pair"); + } + + // The return dtype is based on the first THEN expression (index 1) let then_dtype = &arg_dtypes[1]; // If there's no ELSE, the result is always nullable (unmatched rows are NULL) @@ -133,63 +166,61 @@ impl VTable for CaseWhen { fn execute( &self, - _options: &Self::Options, + options: &Self::Options, args: ExecutionArgs, ) -> VortexResult { let row_count = args.row_count; + let num_pairs = options.num_when_then_pairs as usize; - // Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value] - let (condition, then_value, else_value) = match args.inputs.len() { - 2 => { - let [condition, then_value]: [ArrayRef; 2] = args - .inputs - .try_into() - .map_err(|_| vortex_error::vortex_err!("Expected 2 inputs"))?; - (condition, then_value, None) - } - 3 => { - let [condition, then_value, else_value]: [ArrayRef; 3] = args - .inputs - .try_into() - .map_err(|_| vortex_error::vortex_err!("Expected 3 inputs"))?; - (condition, then_value, Some(else_value)) - } - n => vortex_bail!("CaseWhen expects 2 or 3 inputs, got {}", n), + // Single pair case: use efficient binary 
implementation + if num_pairs == 1 { + return execute_binary_case_when(options.has_else, args); + } + + // N-ary case: evaluate from right to left (innermost first) + // CASE WHEN a THEN x WHEN b THEN y ELSE z END + // evaluates as: CASE WHEN a THEN x ELSE (CASE WHEN b THEN y ELSE z END) END + // + // We iterate from the last pair backwards, building up the result. + + // Start with the else value (or null if no else) + let mut result: ArrayRef = if options.has_else { + let else_idx = num_pairs * 2; + args.inputs[else_idx].clone() + } else { + // Need to determine the output dtype from the first THEN value + let first_then = &args.inputs[1]; + let then_dtype = first_then.dtype().as_nullable(); + ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() }; - // Execute condition to get a BoolArray - let cond_bool = condition.execute::(args.ctx)?; - // SQL semantics: NULL condition is treated as FALSE (i.e., we take the ELSE branch) - let mask = cond_bool.to_mask_fill_null_false(); + // Process pairs from right to left + for i in (0..num_pairs).rev() { + let cond_idx = i * 2; + let then_idx = i * 2 + 1; - // Short-circuit: all true -> just return THEN value - if mask.all_true() { - return then_value.execute::(args.ctx); - } + let condition = &args.inputs[cond_idx]; + let then_value = &args.inputs[then_idx]; - // Short-circuit: all false -> return ELSE value or NULL - if mask.all_false() { - return match else_value { - Some(else_value) => else_value.execute::(args.ctx), - None => { - // Create NULL constant of appropriate type - let then_dtype = then_value.dtype().as_nullable(); - Ok(ExecutionResult::constant( - Scalar::null(then_dtype), - row_count, - )) - } - }; - } + // Execute condition to get a BoolArray + let cond_bool = condition.clone().execute::(args.ctx)?; + // SQL semantics: NULL condition is treated as FALSE (i.e., we take the ELSE branch) + let mask = cond_bool.to_mask_fill_null_false(); - // Get else value for zip (create NULL constant if no 
else clause) - let else_value = else_value.unwrap_or_else(|| { - let then_dtype = then_value.dtype().as_nullable(); - ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() - }); + // Short-circuit: all true -> just return THEN value for this pair + if mask.all_true() { + result = then_value.clone(); + continue; + } - // Use zip to select: where mask is true, take then_value; else take else_value - let result = zip(then_value.as_ref(), else_value.as_ref(), &mask)?; + // Short-circuit: all false -> keep the current result (skip this pair) + if mask.all_false() { + continue; + } + + // Use zip to select: where mask is true, take then_value; else take result + result = zip(then_value.as_ref(), result.as_ref(), &mask)?; + } result.execute::(args.ctx) } @@ -204,91 +235,148 @@ impl VTable for CaseWhen { } } -/// Creates a binary CASE WHEN expression with an ELSE clause. -/// -/// # Arguments -/// - `condition`: Boolean expression for the WHEN clause -/// - `then_value`: Value to return when condition is true -/// - `else_value`: Value to return when condition is false -/// -/// # Example -/// ```ignore -/// // CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END -/// case_when(gt(col("x"), lit(0)), lit("positive"), lit("non-positive")) -/// ``` -pub fn case_when( - condition: Expression, - then_value: Expression, - else_value: Expression, -) -> Expression { - let options = CaseWhenOptions { has_else: true }; - CaseWhen.new_expr(options, [condition, then_value, else_value]) +/// Efficient implementation for binary CASE WHEN (single when/then pair) +fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResult { + let row_count = args.row_count; + + // Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value] + let (condition, then_value, else_value) = match args.inputs.len() { + 2 => { + let [condition, then_value]: [ArrayRef; 2] = args + .inputs + .try_into() + .map_err(|_| 
vortex_error::vortex_err!("Expected 2 inputs"))?; + (condition, then_value, None) + } + 3 => { + let [condition, then_value, else_value]: [ArrayRef; 3] = args + .inputs + .try_into() + .map_err(|_| vortex_error::vortex_err!("Expected 3 inputs"))?; + (condition, then_value, Some(else_value)) + } + n => vortex_bail!("Binary CaseWhen expects 2 or 3 inputs, got {}", n), + }; + + // Execute condition to get a BoolArray + let cond_bool = condition.execute::(args.ctx)?; + // SQL semantics: NULL condition is treated as FALSE (i.e., we take the ELSE branch) + let mask = cond_bool.to_mask_fill_null_false(); + + // Short-circuit: all true -> just return THEN value + if mask.all_true() { + return then_value.execute::(args.ctx); + } + + // Short-circuit: all false -> return ELSE value or NULL + if mask.all_false() { + return match else_value { + Some(else_value) => else_value.execute::(args.ctx), + None => { + // Create NULL constant of appropriate type + let then_dtype = then_value.dtype().as_nullable(); + Ok(ExecutionResult::constant( + Scalar::null(then_dtype), + row_count, + )) + } + }; + } + + // Get else value for zip (create NULL constant if no else clause) + let else_value = else_value.unwrap_or_else(|| { + let then_dtype = then_value.dtype().as_nullable(); + ConstantArray::new(Scalar::null(then_dtype), row_count).into_array() + }); + + // Use zip to select: where mask is true, take then_value; else take else_value + let result = zip(then_value.as_ref(), else_value.as_ref(), &mask)?; + + result.execute::(args.ctx) } -/// Creates a binary CASE WHEN expression without an ELSE clause. -/// -/// Returns NULL when the condition is false. +/// Creates an N-ary CASE WHEN expression from a flat list of children. 
/// /// # Arguments -/// - `condition`: Boolean expression for the WHEN clause -/// - `then_value`: Value to return when condition is true +/// - `children`: Iterator of expressions in order: [when_0, then_0, when_1, then_1, ..., else] +/// +/// The last element is always treated as the ELSE clause. +/// +/// # Panics +/// Panics if children has fewer than 3 elements (at least one when/then pair + else required). /// /// # Example /// ```ignore -/// // CASE WHEN x > 0 THEN 'positive' END -/// case_when_no_else(gt(col("x"), lit(0)), lit("positive")) +/// // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END +/// case_when([ +/// gt(col("x"), lit(10)), lit("high"), +/// gt(col("x"), lit(5)), lit("medium"), +/// lit("low"), +/// ]) /// ``` -pub fn case_when_no_else(condition: Expression, then_value: Expression) -> Expression { - let options = CaseWhenOptions { has_else: false }; - CaseWhen.new_expr(options, [condition, then_value]) +#[allow(clippy::cast_possible_truncation)] +pub fn case_when>(children: I) -> Expression { + let children: Vec<_> = children.into_iter().collect(); + assert!( + children.len() >= 3, + "case_when requires at least 3 children (one when/then pair + else)" + ); + assert!( + children.len() % 2 == 1, + "case_when with else must have odd number of children" + ); + + let num_when_then_pairs = (children.len() - 1) / 2; + let options = CaseWhenOptions { + num_when_then_pairs: num_when_then_pairs as u32, + has_else: true, + }; + CaseWhen.new_expr(options, children) } -/// Creates a nested CASE WHEN expression from multiple WHEN/THEN pairs. +/// Creates an N-ary CASE WHEN expression without an ELSE clause. /// -/// This is a convenience function that converts n-ary CASE WHEN to nested binary expressions: -/// `CASE WHEN a THEN x WHEN b THEN y ELSE z END` becomes -/// `CASE WHEN a THEN x ELSE (CASE WHEN b THEN y ELSE z END) END` +/// Returns NULL when no condition matches. 
/// /// # Arguments -/// - `when_then_pairs`: Vec of (condition, value) pairs -/// - `else_value`: Optional else expression (if None, unmatched rows return NULL) +/// - `children`: Iterator of expressions in order: [when_0, then_0, when_1, then_1, ...] +/// +/// # Panics +/// Panics if children has fewer than 2 elements (at least one when/then pair required). /// /// # Example /// ```ignore -/// // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END -/// nested_case_when( -/// vec![ -/// (gt(col("x"), lit(10)), lit("high")), -/// (gt(col("x"), lit(5)), lit("medium")), -/// ], -/// Some(lit("low")), -/// ) +/// // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' END +/// case_when_no_else([ +/// gt(col("x"), lit(10)), lit("high"), +/// gt(col("x"), lit(5)), lit("medium"), +/// ]) /// ``` -pub fn nested_case_when( - when_then_pairs: Vec<(Expression, Expression)>, - else_value: Option, -) -> Expression { +#[allow(clippy::cast_possible_truncation)] +pub fn case_when_no_else>(children: I) -> Expression { + let children: Vec<_> = children.into_iter().collect(); assert!( - !when_then_pairs.is_empty(), - "nested_case_when requires at least one when/then pair" + children.len() >= 2, + "case_when_no_else requires at least 2 children (one when/then pair)" + ); + assert!( + children.len() % 2 == 0, + "case_when_no_else must have even number of children" ); - // Build from right to left (innermost first) using rfold - when_then_pairs - .into_iter() - .rfold(else_value, |acc, (condition, then_value)| { - Some(match acc { - Some(else_expr) => case_when(condition, then_value, else_expr), - None => case_when_no_else(condition, then_value), - }) - }) - .unwrap_or_else(|| vortex_panic!("rfold on non-empty iterator always produces Some")) + let num_when_then_pairs = children.len() / 2; + let options = CaseWhenOptions { + num_when_then_pairs: num_when_then_pairs as u32, + has_else: false, + }; + CaseWhen.new_expr(options, children) } #[cfg(test)] mod tests { use 
std::sync::LazyLock; + use vortex_buffer::Buffer; use vortex_buffer::buffer; use vortex_dtype::DType; use vortex_dtype::Nullability; @@ -332,17 +420,27 @@ mod tests { #[test] fn test_serialization_roundtrip() { - let options = CaseWhenOptions { has_else: true }; + let options = CaseWhenOptions { + num_when_then_pairs: 2, + has_else: true, + }; let serialized = CaseWhen.serialize(&options).unwrap().unwrap(); - let deserialized = CaseWhen.deserialize(&serialized).unwrap(); + let deserialized = CaseWhen + .deserialize(&serialized, &VortexSession::empty()) + .unwrap(); assert_eq!(options, deserialized); } #[test] fn test_serialization_no_else() { - let options = CaseWhenOptions { has_else: false }; + let options = CaseWhenOptions { + num_when_then_pairs: 1, + has_else: false, + }; let serialized = CaseWhen.serialize(&options).unwrap().unwrap(); - let deserialized = CaseWhen.deserialize(&serialized).unwrap(); + let deserialized = CaseWhen + .deserialize(&serialized, &VortexSession::empty()) + .unwrap(); assert_eq!(options, deserialized); } @@ -350,7 +448,7 @@ mod tests { #[test] fn test_display_with_else() { - let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32)); + let expr = case_when([gt(col("value"), lit(0i32)), lit(100i32), lit(0i32)]); let display = format!("{}", expr); assert!(display.contains("CASE")); assert!(display.contains("WHEN")); @@ -361,7 +459,7 @@ mod tests { #[test] fn test_display_no_else() { - let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32)); + let expr = case_when_no_else([gt(col("value"), lit(0i32)), lit(100i32)]); let display = format!("{}", expr); assert!(display.contains("CASE")); assert!(display.contains("WHEN")); @@ -371,28 +469,27 @@ mod tests { } #[test] - fn test_display_nested_nary() { + fn test_display_nary() { // CASE WHEN x > 10 THEN 'high' WHEN x > 5 THEN 'medium' ELSE 'low' END - // Becomes nested: CASE WHEN x>10 THEN 'high' ELSE (CASE WHEN x>5 THEN 'medium' ELSE 'low' END) END - let expr = 
nested_case_when( - vec![ - (gt(col("x"), lit(10i32)), lit("high")), - (gt(col("x"), lit(5i32)), lit("medium")), - ], - Some(lit("low")), - ); + let expr = case_when([ + gt(col("x"), lit(10i32)), + lit("high"), + gt(col("x"), lit(5i32)), + lit("medium"), + lit("low"), + ]); let display = format!("{}", expr); - // Should contain nested CASE statements - assert_eq!(display.matches("CASE").count(), 2); + // Should contain both WHEN clauses assert_eq!(display.matches("WHEN").count(), 2); assert_eq!(display.matches("THEN").count(), 2); + assert!(display.contains("ELSE")); } // ==================== DType Tests ==================== #[test] fn test_return_dtype_with_else() { - let expr = case_when(lit(true), lit(100i32), lit(0i32)); + let expr = case_when([lit(true), lit(100i32), lit(0i32)]); let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); let result_dtype = expr.return_dtype(&input_dtype).unwrap(); assert_eq!( @@ -403,7 +500,7 @@ mod tests { #[test] fn test_return_dtype_without_else_is_nullable() { - let expr = case_when_no_else(lit(true), lit(100i32)); + let expr = case_when_no_else([lit(true), lit(100i32)]); let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); let result_dtype = expr.return_dtype(&input_dtype).unwrap(); assert_eq!( @@ -415,11 +512,11 @@ mod tests { #[test] fn test_return_dtype_with_struct_input() { let dtype = test_harness::struct_dtype(); - let expr = case_when( + let expr = case_when([ gt(get_item("col1", root()), lit(10u16)), lit(100i32), lit(0i32), - ); + ]); let result_dtype = expr.return_dtype(&dtype).unwrap(); assert_eq!( result_dtype, @@ -430,32 +527,64 @@ mod tests { // ==================== Arity Tests ==================== #[test] - fn test_arity_with_else() { - let options = CaseWhenOptions { has_else: true }; + fn test_arity_single_pair_with_else() { + let options = CaseWhenOptions { + num_when_then_pairs: 1, + has_else: true, + }; assert_eq!(CaseWhen.arity(&options), Arity::Exact(3)); } #[test] 
- fn test_arity_without_else() { - let options = CaseWhenOptions { has_else: false }; + fn test_arity_single_pair_without_else() { + let options = CaseWhenOptions { + num_when_then_pairs: 1, + has_else: false, + }; assert_eq!(CaseWhen.arity(&options), Arity::Exact(2)); } + #[test] + fn test_arity_multiple_pairs() { + let options = CaseWhenOptions { + num_when_then_pairs: 3, + has_else: true, + }; + // 3 pairs * 2 children + 1 else = 7 + assert_eq!(CaseWhen.arity(&options), Arity::Exact(7)); + } + // ==================== Child Name Tests ==================== #[test] - fn test_child_names() { - let options = CaseWhenOptions { has_else: true }; - assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when"); - assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then"); + fn test_child_names_single_pair() { + let options = CaseWhenOptions { + num_when_then_pairs: 1, + has_else: true, + }; + assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0"); + assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0"); assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else"); } + #[test] + fn test_child_names_multiple_pairs() { + let options = CaseWhenOptions { + num_when_then_pairs: 2, + has_else: true, + }; + assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0"); + assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0"); + assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1"); + assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1"); + assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "else"); + } + // ==================== Expression Manipulation Tests ==================== #[test] fn test_replace_children() { - let expr = case_when(lit(true), lit(1i32), lit(0i32)); + let expr = case_when([lit(true), lit(1i32), lit(0i32)]); expr.with_children([lit(false), lit(2i32), lit(3i32)]) .vortex_expect("operation should succeed in test"); } @@ -469,11 +598,11 @@ mod tests { .unwrap() .into_array(); 
- let expr = case_when( + let expr = case_when([ gt(get_item("value", root()), lit(2i32)), lit(100i32), lit(0i32), - ); + ]); let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); @@ -481,19 +610,18 @@ mod tests { #[test] fn test_evaluate_nary_multiple_conditions() { - // Test n-ary via nested_case_when let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) .unwrap() .into_array(); - let expr = nested_case_when( - vec![ - (eq(get_item("value", root()), lit(1i32)), lit(10i32)), - (eq(get_item("value", root()), lit(3i32)), lit(30i32)), - ], - Some(lit(0i32)), - ); + let expr = case_when([ + eq(get_item("value", root()), lit(1i32)), + lit(10i32), + eq(get_item("value", root()), lit(3i32)), + lit(30i32), + lit(0i32), + ]); let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[10, 0, 30, 0, 0]); @@ -507,13 +635,13 @@ mod tests { .into_array(); // Both conditions match for values > 3, but first one wins - let expr = nested_case_when( - vec![ - (gt(get_item("value", root()), lit(2i32)), lit(100i32)), - (gt(get_item("value", root()), lit(3i32)), lit(200i32)), - ], - Some(lit(0i32)), - ); + let expr = case_when([ + gt(get_item("value", root()), lit(2i32)), + lit(100i32), + gt(get_item("value", root()), lit(3i32)), + lit(200i32), + lit(0i32), + ]); let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[0, 0, 100, 100, 100]); @@ -526,7 +654,7 @@ mod tests { .unwrap() .into_array(); - let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32)); + let expr = case_when_no_else([gt(get_item("value", root()), lit(3i32)), lit(100i32)]); let result = evaluate_expr(&expr, &test_array); assert!(result.dtype().is_nullable()); @@ -560,11 +688,11 @@ mod tests { .unwrap() .into_array(); - let expr = case_when( + let expr = case_when([ gt(get_item("value", root()), 
lit(100i32)), lit(1i32), lit(0i32), - ); + ]); let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[0, 0, 0, 0, 0]); @@ -577,11 +705,11 @@ mod tests { .unwrap() .into_array(); - let expr = case_when( + let expr = case_when([ gt(get_item("value", root()), lit(0i32)), lit(100i32), lit(0i32), - ); + ]); let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); @@ -590,7 +718,7 @@ mod tests { #[test] fn test_evaluate_with_literal_condition() { let test_array = buffer![1i32, 2, 3].into_array(); - let expr = case_when(lit(true), lit(100i32), lit(0i32)); + let expr = case_when([lit(true), lit(100i32), lit(0i32)]); let result = evaluate_expr(&expr, &test_array); if let Some(constant) = result.as_constant() { @@ -608,11 +736,11 @@ mod tests { .unwrap() .into_array(); - let expr = case_when( + let expr = case_when([ gt(get_item("value", root()), lit(2i32)), lit(true), lit(false), - ); + ]); let result = evaluate_expr(&expr, &test_array).to_bool(); assert_eq!( @@ -630,7 +758,7 @@ mod tests { .unwrap() .into_array(); - let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); + let expr = case_when([get_item("cond", root()), lit(100i32), lit(0i32)]); let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[100, 0, 0, 0, 100]); @@ -649,11 +777,11 @@ mod tests { .unwrap() .into_array(); - let expr = case_when( + let expr = case_when([ gt(get_item("value", root()), lit(2i32)), get_item("result", root()), lit(0i32), - ); + ]); let result = evaluate_expr(&expr, &test_array); let prim = result.to_primitive(); @@ -669,18 +797,352 @@ mod tests { .unwrap() .into_array(); - let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32)); + let expr = case_when([get_item("cond", root()), lit(100i32), lit(0i32)]); let result = evaluate_expr(&expr, &test_array).to_primitive(); assert_eq!(result.as_slice::(), &[0, 0, 0]); 
} - // Note: Direct execute tests are covered through apply+execute tests above. + #[test] + fn test_nary_no_else() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); - // Note: The binary CASE WHEN implementation using `zip` does NOT provide - // short-circuit/lazy evaluation. All child expressions are evaluated first, - // then zip selects the result based on the condition. This means expressions - // like divide-by-zero will still fail even if protected by a condition. - // This is intentional - lazy evaluation would require a more complex - // implementation that filters the input before evaluating children. + // CASE WHEN value = 1 THEN 10 WHEN value = 3 THEN 30 END (no else) + let expr = case_when_no_else([ + eq(get_item("value", root()), lit(1i32)), + lit(10i32), + eq(get_item("value", root()), lit(3i32)), + lit(30i32), + ]); + + let result = evaluate_expr(&expr, &test_array); + assert!(result.dtype().is_nullable()); + + // Values 1 -> 10, 3 -> 30, others -> NULL + assert_eq!( + result.scalar_at(0).unwrap(), + Scalar::from(10i32).cast(result.dtype()).unwrap() + ); + assert_eq!( + result.scalar_at(1).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(2).unwrap(), + Scalar::from(30i32).cast(result.dtype()).unwrap() + ); + assert_eq!( + result.scalar_at(3).unwrap(), + Scalar::null(result.dtype().clone()) + ); + assert_eq!( + result.scalar_at(4).unwrap(), + Scalar::null(result.dtype().clone()) + ); + } + + // ==================== Advanced N-ary Tests ==================== + + #[test] + fn test_nary_5_conditions() { + // Test with 5 when/then pairs to stress the n-ary implementation + let test_array = StructArray::from_fields(&[( + "value", + buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array(), + )]) + .unwrap() + .into_array(); + + // CASE WHEN value=1 THEN 100 WHEN value=3 THEN 300 WHEN value=5 THEN 500 + // WHEN value=7 THEN 700 WHEN value=9 
THEN 900 ELSE 0 END + let expr = case_when([ + eq(get_item("value", root()), lit(1i32)), + lit(100i32), + eq(get_item("value", root()), lit(3i32)), + lit(300i32), + eq(get_item("value", root()), lit(5i32)), + lit(500i32), + eq(get_item("value", root()), lit(7i32)), + lit(700i32), + eq(get_item("value", root()), lit(9i32)), + lit(900i32), + lit(0i32), + ]); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!( + result.as_slice::(), + &[100, 0, 300, 0, 500, 0, 700, 0, 900, 0] + ); + } + + #[test] + fn test_nary_all_conditions_short_circuit_true() { + // First condition matches all rows + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when([ + gt(get_item("value", root()), lit(0i32)), // Always true + lit(100i32), + gt(get_item("value", root()), lit(3i32)), // Would match some + lit(200i32), + lit(0i32), + ]); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + // All rows should match first condition + assert_eq!(result.as_slice::(), &[100, 100, 100, 100, 100]); + } + + #[test] + fn test_nary_all_conditions_false() { + // No conditions match, should return else value + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when([ + gt(get_item("value", root()), lit(100i32)), // Never true + lit(100i32), + gt(get_item("value", root()), lit(200i32)), // Never true + lit(200i32), + lit(999i32), // Else + ]); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + assert_eq!(result.as_slice::(), &[999, 999, 999, 999, 999]); + } + + #[test] + fn test_nary_cascading_conditions() { + // Test cascading conditions where later conditions catch what earlier ones miss + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 5, 10, 15, 20].into_array())]) + .unwrap() + .into_array(); + + // CASE WHEN value > 15 THEN 4 
WHEN value > 10 THEN 3 WHEN value > 5 THEN 2 WHEN value > 1 THEN 1 ELSE 0 END + let expr = case_when([ + gt(get_item("value", root()), lit(15i32)), + lit(4i32), + gt(get_item("value", root()), lit(10i32)), + lit(3i32), + gt(get_item("value", root()), lit(5i32)), + lit(2i32), + gt(get_item("value", root()), lit(1i32)), + lit(1i32), + lit(0i32), + ]); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + // value=1 -> 0, value=5 -> 1, value=10 -> 2, value=15 -> 3, value=20 -> 4 + assert_eq!(result.as_slice::(), &[0, 1, 2, 3, 4]); + } + + #[test] + fn test_nary_with_nullable_conditions() { + // Test n-ary with nullable conditions in the middle + let test_array = StructArray::from_fields(&[ + ("value", buffer![1i32, 2, 3, 4, 5].into_array()), + ( + "cond1", + BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]) + .into_array(), + ), + ( + "cond2", + BoolArray::from_iter([ + Some(false), + Some(true), + Some(true), + Some(false), + Some(false), + ]) + .into_array(), + ), + ]) + .unwrap() + .into_array(); + + // CASE WHEN cond1 THEN 100 WHEN cond2 THEN 200 ELSE 0 END + let expr = case_when([ + get_item("cond1", root()), + lit(100i32), + get_item("cond2", root()), + lit(200i32), + lit(0i32), + ]); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + // row 0: cond1=true -> 100 + // row 1: cond1=null(false) -> cond2=true -> 200 + // row 2: cond1=false -> cond2=true -> 200 + // row 3: cond1=null(false) -> cond2=false -> 0 + // row 4: cond1=true -> 100 + assert_eq!(result.as_slice::(), &[100, 200, 200, 0, 100]); + } + + #[test] + fn test_nary_no_else_all_unmatched() { + // N-ary without else where no conditions match + let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())]) + .unwrap() + .into_array(); + + let expr = case_when_no_else([ + gt(get_item("value", root()), lit(100i32)), // Never true + lit(100i32), + gt(get_item("value", root()), lit(200i32)), // Never true + lit(200i32), + ]); + 
+ let result = evaluate_expr(&expr, &test_array); + assert!(result.dtype().is_nullable()); + + // All values should be NULL + for i in 0..3 { + assert_eq!( + result.scalar_at(i).unwrap(), + Scalar::null(result.dtype().clone()) + ); + } + } + + #[test] + fn test_large_array() { + // Test with a larger array to verify performance characteristics + let size: i32 = 10000; + let data: Vec = (0..size).collect(); + let test_array = StructArray::from_fields(&[("value", Buffer::from(data).into_array())]) + .unwrap() + .into_array(); + + // CASE WHEN value > 7500 THEN 3 WHEN value > 5000 THEN 2 WHEN value > 2500 THEN 1 ELSE 0 END + let expr = case_when([ + gt(get_item("value", root()), lit(7500i32)), + lit(3i32), + gt(get_item("value", root()), lit(5000i32)), + lit(2i32), + gt(get_item("value", root()), lit(2500i32)), + lit(1i32), + lit(0i32), + ]); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + let result_slice = result.as_slice::(); + + // Verify correctness at boundaries + assert_eq!(result_slice[0], 0); // value=0 -> 0 + assert_eq!(result_slice[2500], 0); // value=2500 -> 0 + assert_eq!(result_slice[2501], 1); // value=2501 -> 1 + assert_eq!(result_slice[5000], 1); // value=5000 -> 1 + assert_eq!(result_slice[5001], 2); // value=5001 -> 2 + assert_eq!(result_slice[7500], 2); // value=7500 -> 2 + assert_eq!(result_slice[7501], 3); // value=7501 -> 3 + assert_eq!(result_slice[9999], 3); // value=9999 -> 3 + } + + #[test] + fn test_nary_with_column_results() { + // Test n-ary where THEN values come from columns, not literals + let test_array = StructArray::from_fields(&[ + ("value", buffer![1i32, 2, 3, 4, 5].into_array()), + ("result_a", buffer![10i32, 20, 30, 40, 50].into_array()), + ("result_b", buffer![100i32, 200, 300, 400, 500].into_array()), + ]) + .unwrap() + .into_array(); + + // CASE WHEN value < 3 THEN result_a WHEN value > 3 THEN result_b ELSE 0 END + let expr = case_when([ + gt(lit(3i32), get_item("value", root())), // value < 3 + 
get_item("result_a", root()), + gt(get_item("value", root()), lit(3i32)), // value > 3 + get_item("result_b", root()), + lit(0i32), + ]); + + let result = evaluate_expr(&expr, &test_array).to_primitive(); + // value=1 -> result_a=10, value=2 -> result_a=20, value=3 -> else=0 + // value=4 -> result_b=400, value=5 -> result_b=500 + assert_eq!(result.as_slice::(), &[10, 20, 0, 400, 500]); + } + + #[test] + fn test_string_results() { + let test_array = + StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())]) + .unwrap() + .into_array(); + + // CASE WHEN value < 2 THEN 'low' WHEN value < 4 THEN 'medium' ELSE 'high' END + let expr = case_when([ + gt(lit(2i32), get_item("value", root())), + lit("low"), + gt(lit(4i32), get_item("value", root())), + lit("medium"), + lit("high"), + ]); + + let result = evaluate_expr(&expr, &test_array); + let varbinview = result.to_varbinview(); + + assert_eq!( + varbinview + .scalar_at(0) + .unwrap() + .as_utf8() + .value() + .unwrap() + .as_str(), + "low" + ); + assert_eq!( + varbinview + .scalar_at(1) + .unwrap() + .as_utf8() + .value() + .unwrap() + .as_str(), + "medium" + ); + assert_eq!( + varbinview + .scalar_at(2) + .unwrap() + .as_utf8() + .value() + .unwrap() + .as_str(), + "medium" + ); + assert_eq!( + varbinview + .scalar_at(3) + .unwrap() + .as_utf8() + .value() + .unwrap() + .as_str(), + "high" + ); + assert_eq!( + varbinview + .scalar_at(4) + .unwrap() + .as_utf8() + .value() + .unwrap() + .as_str(), + "high" + ); + } } diff --git a/vortex-datafusion/src/convert/exprs.rs b/vortex-datafusion/src/convert/exprs.rs index 2c83f476260..f32be050482 100644 --- a/vortex-datafusion/src/convert/exprs.rs +++ b/vortex-datafusion/src/convert/exprs.rs @@ -29,12 +29,13 @@ use vortex::expr::Like; use vortex::expr::Operator; use vortex::expr::VTableExt; use vortex::expr::and; +use vortex::expr::case_when; +use vortex::expr::case_when_no_else; use vortex::expr::cast; use vortex::expr::get_item; use vortex::expr::is_null; 
use vortex::expr::list_contains; use vortex::expr::lit; -use vortex::expr::nested_case_when; use vortex::expr::not; use vortex::expr::pack; use vortex::expr::root; @@ -137,22 +138,20 @@ impl DefaultExpressionConvertor { )); } - // Convert all when/then pairs to (condition, value) tuples - let mut pairs = Vec::with_capacity(when_then_pairs.len()); + // Convert all when/then pairs to flat list: [when1, then1, when2, then2, ...] + let mut children = Vec::with_capacity(when_then_pairs.len() * 2 + 1); for (when_expr, then_expr) in when_then_pairs { - let condition = self.convert(when_expr.as_ref())?; - let value = self.convert(then_expr.as_ref())?; - pairs.push((condition, value)); + children.push(self.convert(when_expr.as_ref())?); + children.push(self.convert(then_expr.as_ref())?); } - // Convert optional else expression - let else_value = case_expr - .else_expr() - .map(|e| self.convert(e.as_ref())) - .transpose()?; - - // Use nested_case_when which converts to nested binary case_when expressions - Ok(nested_case_when(pairs, else_value)) + // Handle optional else clause + if let Some(else_expr) = case_expr.else_expr() { + children.push(self.convert(else_expr.as_ref())?); + Ok(case_when(children)) + } else { + Ok(case_when_no_else(children)) + } } } @@ -939,4 +938,194 @@ mod tests { assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]); assert_eq!(vortex_as_arrow, df_as_arrow); } + + #[test] + fn test_case_when_nary_4_conditions() { + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::array::RecordBatch; + use datafusion_physical_expr::expressions::CaseExpr; + use vortex::VortexSessionDefault; + use vortex::array::ArrayRef; + use vortex::array::Canonical; + use vortex::array::VortexSessionExecute as _; + use vortex::array::arrow::FromArrowArray; + use vortex::session::VortexSession; + + // Test with 4 when/then pairs to exercise the n-ary implementation + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + false, + )])); + + 
let values = Int32Array::from(vec![1, 10, 25, 50, 75, 100]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(values)]).unwrap(); + + let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc; + + // Literal thresholds + let lit_75 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(75)))) as Arc; + let lit_50 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc; + let lit_25 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(25)))) as Arc; + let lit_10 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc; + + // Result values + let result_4 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(4)))) as Arc; + let result_3 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(3)))) as Arc; + let result_2 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let result_1 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let result_0 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc; + + // Build conditions + let when1 = Arc::new(df_expr::BinaryExpr::new( + col_value.clone(), + DFOperator::Gt, + lit_75, + )) as Arc; + let when2 = Arc::new(df_expr::BinaryExpr::new( + col_value.clone(), + DFOperator::Gt, + lit_50, + )) as Arc; + let when3 = Arc::new(df_expr::BinaryExpr::new( + col_value.clone(), + DFOperator::Gt, + lit_25, + )) as Arc; + let when4 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_10)) + as Arc; + + let case_expr = CaseExpr::try_new( + None, + vec![ + (when1, result_4), + (when2, result_3), + (when3, result_2), + (when4, result_1), + ], + Some(result_0), + ) + .unwrap(); + + // Apply DataFusion expression + let df_result = case_expr.evaluate(&batch).unwrap(); + let df_array = df_result.into_array(batch.num_rows()).unwrap(); + + // Convert to Vortex expression + let expr_convertor = DefaultExpressionConvertor::default(); + let vortex_expr = 
expr_convertor.try_convert_case_expr(&case_expr).unwrap(); + + // Convert batch to Vortex array + let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap(); + + // Apply Vortex expression + let session = VortexSession::default(); + let mut ctx = session.create_execution_ctx(); + let vortex_result = vortex_array + .apply(&vortex_expr) + .unwrap() + .execute::(&mut ctx) + .unwrap(); + + // Convert to Vec for comparison + let vortex_as_arrow = vortex_result.into_primitive().as_slice::().to_vec(); + let df_as_arrow: Vec = df_array + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec(); + + // Expected: for values [1, 10, 25, 50, 75, 100] + // 1: not > 75, not > 50, not > 25, not > 10 -> 0 + // 10: not > 75, not > 50, not > 25, not > 10 -> 0 + // 25: not > 75, not > 50, not > 25, > 10 -> 1 + // 50: not > 75, not > 50, > 25 -> 2 + // 75: not > 75, > 50 -> 3 + // 100: > 75 -> 4 + assert_eq!(df_as_arrow, vec![0, 0, 1, 2, 3, 4]); + assert_eq!(vortex_as_arrow, df_as_arrow); + } + + #[test] + fn test_case_when_no_else() { + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::array::RecordBatch; + use datafusion_physical_expr::expressions::CaseExpr; + use vortex::VortexSessionDefault; + use vortex::array::ArrayRef; + use vortex::array::Canonical; + use vortex::array::VortexSessionExecute as _; + use vortex::array::arrow::FromArrowArray; + use vortex::session::VortexSession; + + // Test CASE WHEN without ELSE clause + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + false, + )])); + + let values = Int32Array::from(vec![1, 5, 10, 15, 20]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(values)]).unwrap(); + + let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc; + let lit_10 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc; + let lit_100 = + Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc; + + // WHEN value > 10 THEN 100 (no else) + 
let when1 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_10)) + as Arc; + + let case_expr = CaseExpr::try_new(None, vec![(when1, lit_100)], None).unwrap(); + + // Convert to Vortex expression + let expr_convertor = DefaultExpressionConvertor::default(); + let vortex_expr = expr_convertor.try_convert_case_expr(&case_expr).unwrap(); + + // Verify the expression was created with no else + assert!(vortex_expr.to_string().contains("CASE")); + assert!(!vortex_expr.to_string().contains("ELSE")); + + // Convert batch to Vortex array + let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap(); + + // Apply Vortex expression - result should be nullable + let session = VortexSession::default(); + let mut ctx = session.create_execution_ctx(); + let vortex_result = vortex_array + .apply(&vortex_expr) + .unwrap() + .execute::(&mut ctx) + .unwrap(); + + // Result should be nullable + assert!(vortex_result.dtype().is_nullable()); + + // Convert result to primitive and check values + let prim = vortex_result.into_primitive(); + + // Check every row: NULL where the condition is false, 100 where it is true + // value=1: not > 10 -> NULL + // value=5: not > 10 -> NULL + // value=10: not > 10 -> NULL + // value=15: > 10 -> 100 + // value=20: > 10 -> 100 + assert!(prim.scalar_at(0).unwrap().is_null()); + assert!(prim.scalar_at(1).unwrap().is_null()); + assert!(prim.scalar_at(2).unwrap().is_null()); + assert_eq!(prim.as_slice::()[3], 100); + assert_eq!(prim.as_slice::()[4], 100); + } } From 0a97ebbc47383a501657f7fcfd76525ec7068411 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:30:14 -0800 Subject: [PATCH 4/4] refactor: update CASE WHEN options to use a single num_children field (#22) --- vortex-array/benches/expr/case_when_bench.rs | 8 ++-- vortex-array/src/expr/exprs/case_when.rs | 44 +++++++++++--------- vortex-proto/proto/expr.proto | 7 +++- vortex-proto/src/generated/vortex.expr.rs | 8 ++-- 4 files changed, 38 insertions(+), 29 deletions(-) 
diff --git a/vortex-array/benches/expr/case_when_bench.rs b/vortex-array/benches/expr/case_when_bench.rs index 3d795c841bd..862f3f1b53e 100644 --- a/vortex-array/benches/expr/case_when_bench.rs +++ b/vortex-array/benches/expr/case_when_bench.rs @@ -44,7 +44,7 @@ fn make_struct_array(size: usize) -> ArrayRef { } /// Benchmark a simple binary CASE WHEN with varying array sizes. -#[divan::bench(args = [10000, 100000, 1000000])] +#[divan::bench(args = [1000, 10000, 100000])] fn case_when_simple(bencher: Bencher, size: usize) { let array = make_struct_array(size); @@ -94,7 +94,7 @@ fn case_when_nary_3_conditions(bencher: Bencher, size: usize) { } /// Benchmark CASE WHEN where all conditions are true (short-circuit path). -#[divan::bench(args = [10000, 100000, 1000000])] +#[divan::bench(args = [1000, 10000, 100000])] fn case_when_all_true(bencher: Bencher, size: usize) { let array = make_struct_array(size); @@ -117,7 +117,7 @@ fn case_when_all_true(bencher: Bencher, size: usize) { } /// Benchmark CASE WHEN where all conditions are false (short-circuit path). -#[divan::bench(args = [10000, 100000, 1000000])] +#[divan::bench(args = [1000, 10000, 100000])] fn case_when_all_false(bencher: Bencher, size: usize) { let array = make_struct_array(size); @@ -181,7 +181,7 @@ fn case_when_nary_10_conditions(bencher: Bencher, size: usize) { } /// Benchmark n-ary CASE WHEN with 100 conditions. 
-#[divan::bench(args = [10000, 100000, 1000000])] +#[divan::bench(args = [1000, 10000, 100000])] fn case_when_nary_100_conditions(bencher: Bencher, size: usize) { use vortex_array::expr::Expression; diff --git a/vortex-array/src/expr/exprs/case_when.rs b/vortex-array/src/expr/exprs/case_when.rs index 7c1b8d90098..37915233398 100644 --- a/vortex-array/src/expr/exprs/case_when.rs +++ b/vortex-array/src/expr/exprs/case_when.rs @@ -37,7 +37,6 @@ use crate::compute::zip; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; -use crate::expr::ExecutionResult; use crate::expr::ExprId; use crate::expr::VTable; use crate::expr::VTableExt; @@ -79,12 +78,10 @@ impl VTable for CaseWhen { } fn serialize(&self, options: &Self::Options) -> VortexResult>> { + let num_children = + options.num_when_then_pairs * 2 + if options.has_else { 1 } else { 0 }; Ok(Some( - pb::CaseWhenOpts { - num_when_then_pairs: options.num_when_then_pairs, - has_else: options.has_else, - } - .encode_to_vec(), + pb::CaseWhenOpts { num_children }.encode_to_vec(), )) } @@ -95,8 +92,8 @@ impl VTable for CaseWhen { ) -> VortexResult { let opts = pb::CaseWhenOpts::decode(metadata)?; Ok(CaseWhenOptions { - num_when_then_pairs: opts.num_when_then_pairs, - has_else: opts.has_else, + num_when_then_pairs: opts.num_children / 2, + has_else: opts.num_children % 2 == 1, }) } @@ -156,6 +153,18 @@ impl VTable for CaseWhen { // The return dtype is based on the first THEN expression (index 1) let then_dtype = &arg_dtypes[1]; + // All THEN value dtypes must match (the ELSE dtype is not checked here) + debug_assert!( + (0..options.num_when_then_pairs as usize).all(|i| { + let idx = i * 2 + 1; + &arg_dtypes[idx] == then_dtype + }), + "All THEN expression dtypes must match, got {:?}", + (0..options.num_when_then_pairs as usize) + .map(|i| &arg_dtypes[i * 2 + 1]) + .collect::>() + ); + // If there's no ELSE, the result is always nullable (unmatched rows are NULL) if !options.has_else { Ok(then_dtype.as_nullable()) @@ -168,7 
+177,7 @@ impl VTable for CaseWhen { &self, options: &Self::Options, args: ExecutionArgs, - ) -> VortexResult { + ) -> VortexResult { let row_count = args.row_count; let num_pairs = options.num_when_then_pairs as usize; @@ -222,7 +231,7 @@ impl VTable for CaseWhen { result = zip(then_value.as_ref(), result.as_ref(), &mask)?; } - result.execute::(args.ctx) + Ok(result) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { @@ -236,7 +245,7 @@ impl VTable for CaseWhen { } /// Efficient implementation for binary CASE WHEN (single when/then pair) -fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResult { +fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResult { let row_count = args.row_count; // Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value] @@ -265,20 +274,17 @@ fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResul // Short-circuit: all true -> just return THEN value if mask.all_true() { - return then_value.execute::(args.ctx); + return Ok(then_value); } // Short-circuit: all false -> return ELSE value or NULL if mask.all_false() { return match else_value { - Some(else_value) => else_value.execute::(args.ctx), + Some(else_value) => Ok(else_value), None => { // Create NULL constant of appropriate type let then_dtype = then_value.dtype().as_nullable(); - Ok(ExecutionResult::constant( - Scalar::null(then_dtype), - row_count, - )) + Ok(ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()) } }; } @@ -290,9 +296,7 @@ fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResul }); // Use zip to select: where mask is true, take then_value; else take else_value - let result = zip(then_value.as_ref(), else_value.as_ref(), &mask)?; - - result.execute::(args.ctx) + zip(then_value.as_ref(), else_value.as_ref(), &mask) } /// Creates an N-ary CASE WHEN expression from a flat list of children. 
diff --git a/vortex-proto/proto/expr.proto b/vortex-proto/proto/expr.proto index 7d713a3afb3..3b47db2a756 100644 --- a/vortex-proto/proto/expr.proto +++ b/vortex-proto/proto/expr.proto @@ -82,7 +82,10 @@ message SelectOpts { } // Options for `vortex.case_when` +// Encodes num_when_then_pairs and has_else into a single u32 (num_children). +// num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0) +// has_else = num_children % 2 == 1 +// num_when_then_pairs = num_children / 2 message CaseWhenOpts { - uint32 num_when_then_pairs = 1; - bool has_else = 2; + uint32 num_children = 1; } diff --git a/vortex-proto/src/generated/vortex.expr.rs b/vortex-proto/src/generated/vortex.expr.rs index 180e693f269..9bc61475e59 100644 --- a/vortex-proto/src/generated/vortex.expr.rs +++ b/vortex-proto/src/generated/vortex.expr.rs @@ -146,10 +146,12 @@ pub mod select_opts { } } /// Options for `vortex.case_when` +/// Encodes num_when_then_pairs and has_else into a single u32 (num_children). +/// num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0) +/// has_else = num_children % 2 == 1 +/// num_when_then_pairs = num_children / 2 #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct CaseWhenOpts { #[prost(uint32, tag = "1")] - pub num_when_then_pairs: u32, - #[prost(bool, tag = "2")] - pub has_else: bool, + pub num_children: u32, }