Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,31 @@ struct ProjectedCaseBody {
/// [WHEN ...]
/// [ELSE result]
/// END
#[derive(Debug, Hash, PartialEq, Eq)]
#[derive(Debug)]
pub struct CaseExpr {
/// The case expression body
body: CaseBody,
/// Evaluation method to use
eval_method: EvalMethod,
}

// eval_method is functionally derived from body, so excluding it from
// Hash/Eq avoids redundantly hashing the expression tree twice. For
// nested CASE chains this prevents exponential blowup
// (see https://github.com/apache/datafusion/issues/22173).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// nested CASE chains this prevents exponential blowup (see #22173).
// nested CASE chains this prevents exponential blowup (see https://github.com/apache/datafusion/issues/22173).

impl Hash for CaseExpr {
    /// Feeds only the `body` field into `state`; `eval_method` is
    /// deliberately left out of the hash.
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        let Self { body, .. } = self;
        body.hash(state);
    }
}

impl PartialEq for CaseExpr {
    /// Two `CaseExpr` values compare equal when their `body` fields are
    /// equal; `eval_method` is not compared.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.body, &other.body);
        lhs == rhs
    }
}

// Marker impl with no methods: equality itself is defined by the
// `PartialEq` impl for `CaseExpr`, which compares only `body`.
impl Eq for CaseExpr {}

impl std::fmt::Display for CaseExpr {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "CASE ")?;
Expand Down Expand Up @@ -3120,4 +3137,59 @@ mod tests {
Arc::new(expected_with_else),
);
}

/// Regression test for https://github.com/apache/datafusion/issues/22173
///
/// Builds a chain of CASE expressions where each level uses the previous
/// level as its ELSE branch, hashes the outermost expression with a hasher
/// that only counts its `write` calls, and asserts that the number of calls
/// stays below a fixed bound.
#[test]
fn nested_self_referential_case_hash_stays_bounded() -> Result<()> {
    use std::hash::Hasher;

    /// Hasher that records how much data it was fed instead of hashing it.
    #[derive(Default)]
    struct CountingHasher {
        write_calls: usize,
        bytes_written: usize,
    }

    impl Hasher for CountingHasher {
        // The hash value itself is irrelevant here; only the counters matter.
        fn finish(&self) -> u64 {
            0
        }

        fn write(&mut self, bytes: &[u8]) {
            self.bytes_written += bytes.len();
            self.write_calls += 1;
        }
    }

    // Number of CASE expressions nested inside each other.
    let num_levels = 18;

    let fields = vec![Field::new("kind", DataType::Utf8, true)];
    let schema = Arc::new(Schema::new(fields));
    let kind = col("kind", &schema)?;

    // Start from the bare column and wrap it once per level:
    //   CASE WHEN kind = '<level>' THEN 'label' ELSE <previous level> END
    let label = (0..num_levels).try_fold(Arc::clone(&kind), |else_branch, level| {
        let when: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::clone(&kind),
            Operator::Eq,
            lit(level.to_string()),
        ));
        case(None, vec![(when, lit("label"))], Some(else_branch))
    })?;

    let mut hasher = CountingHasher::default();
    label.hash(&mut hasher);

    let CountingHasher {
        write_calls,
        bytes_written,
    } = hasher;
    assert!(
        write_calls < 50_000,
        "hashing nested CASE expression took {} hasher writes and {} bytes",
        write_calls,
        bytes_written
    );

    Ok(())
}
}
Loading