diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 20d0a9e97e833..568ecb9cf336b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -268,7 +268,7 @@ struct ProjectedCaseBody { /// [WHEN ...] /// [ELSE result] /// END -#[derive(Debug, Hash, PartialEq, Eq)] +#[derive(Debug)] pub struct CaseExpr { /// The case expression body body: CaseBody, @@ -276,6 +276,23 @@ pub struct CaseExpr { eval_method: EvalMethod, } +// eval_method is functionally derived from body, so excluding it from +// Hash/Eq avoids redundantly hashing the expression tree twice. For +// nested CASE chains this prevents exponential blowup (see #22173). +impl Hash for CaseExpr { + fn hash(&self, state: &mut H) { + self.body.hash(state); + } +} + +impl PartialEq for CaseExpr { + fn eq(&self, other: &Self) -> bool { + self.body == other.body + } +} + +impl Eq for CaseExpr {} + impl std::fmt::Display for CaseExpr { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!(f, "CASE ")?; @@ -3120,4 +3137,59 @@ mod tests { Arc::new(expected_with_else), ); } + + /// Reproduces https://github.com/apache/datafusion/issues/22173 + /// + /// Nested self-referential CASE chains (common in rewrite-style projections) + /// should not cause exponential hashing work during physical planning. + #[test] + fn nested_self_referential_case_hash_stays_bounded() -> Result<()> { + use std::hash::Hasher; + + #[derive(Default)] + struct CountingHasher { + write_calls: usize, + bytes_written: usize, + } + + impl Hasher for CountingHasher { + fn finish(&self) -> u64 { + 0 + } + + fn write(&mut self, bytes: &[u8]) { + self.write_calls += 1; + self.bytes_written += bytes.len(); + } + } + + let schema = + Arc::new(Schema::new(vec![Field::new("kind", DataType::Utf8, true)])); + + let kind = col("kind", &schema)?; + let mut label = Arc::clone(&kind); + + let num_levels = 18; + for idx in 0..num_levels { + let predicate = Arc::new(BinaryExpr::new( + Arc::clone(&kind), + Operator::Eq, + lit(idx.to_string()), + )) as Arc; + + label = case(None, vec![(predicate, lit("label"))], Some(label))?; + } + + let mut hasher = CountingHasher::default(); + label.hash(&mut hasher); + + assert!( + hasher.write_calls < 50_000, + "hashing nested CASE expression took {} hasher writes and {} bytes", + hasher.write_calls, + hasher.bytes_written + ); + + Ok(()) + } }