Skip to content

Commit f0f2e90

Browse files
fix: InList Dictionary filter pushdown type mismatch
Guard the dictionary unwrap in ArrayStaticFilter::contains() to only fire when the dictionary value type matches in_array's type. This prevents a make_comparator type mismatch when both the needle and in_array are Dictionary-encoded, which occurs in HashJoin dynamic filter pushdown with pushdown_filters enabled. Closes #20937
1 parent 3ece9ec commit f0f2e90

2 files changed

Lines changed: 210 additions & 4 deletions

File tree

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 204 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -99,11 +99,16 @@ impl StaticFilter for ArrayStaticFilter {
9999
));
100100
}
101101

102+
// Unwrap dictionary-encoded needles when the value type matches
103+
// in_array, evaluating against distinct values and mapping back
104+
// via keys.
102105
downcast_dictionary_array! {
103106
v => {
104-
let values_contains = self.contains(v.values().as_ref(), negated)?;
105-
let result = take(&values_contains, v.keys(), None)?;
106-
return Ok(downcast_array(result.as_ref()))
107+
if v.values().data_type() == self.in_array.data_type() {
108+
let values_contains = self.contains(v.values().as_ref(), negated)?;
109+
let result = take(&values_contains, v.keys(), None)?;
110+
return Ok(downcast_array(result.as_ref()));
111+
}
107112
}
108113
_ => {}
109114
}
@@ -3878,4 +3883,200 @@ mod tests {
38783883
);
38793884
Ok(())
38803885
}
3886+
3887+
// -----------------------------------------------------------------------
3888+
// Tests for try_new_from_array covering all (in_array, needle) type
3889+
// combinations that occur in HashJoin dynamic filter pushdown.
3890+
//
3891+
// try_new (used by SQL IN expressions) always produces a non-Dictionary
3892+
// in_array because evaluate_list() flattens Dictionary scalars to their
3893+
// value type. try_new_from_array passes the array directly, so it is
3894+
// the only path that can produce a Dictionary in_array.
3895+
// -----------------------------------------------------------------------
3896+
3897+
/// Test helper: wraps `array` in a `DictionaryArray` with `Int32` keys.
///
/// The keys are `0..array.len()`, so key `i` points at value `i`; the values
/// are passed through as-is (no de-duplication is performed here).
///
/// # Panics
///
/// Panics if `array.len()` does not fit in an `i32`. The previous `as i32`
/// cast would have truncated silently and produced a wrong key range instead.
fn wrap_in_dict(array: ArrayRef) -> ArrayRef {
3898+
// Checked conversion: an oversized input is a bug in the test, so fail loudly.
let len = i32::try_from(array.len())
.expect("array length must fit in i32 to build Int32 dictionary keys");
let keys = Int32Array::from((0..len).collect::<Vec<_>>());
3899+
Arc::new(DictionaryArray::new(keys, array))
3900+
}
3901+
3902+
/// Test helper: builds an `InListExpr` via `try_new_from_array` over column `a`
/// and evaluates it against a one-column batch holding `needle`.
///
/// * `needle_type` - data type declared for column `a`; the field is created
///   as non-nullable.
/// * `needle` - the single column of the batch that is evaluated.
/// * `in_array` - array handed to `try_new_from_array` as the set to test
///   membership against.
///
/// Returns the evaluation result as a `BooleanArray`. Errors from resolving
/// the column, constructing the expression, building the batch, or evaluating
/// it are propagated with `?`.
fn eval_in_list_from_array(
3903+
needle_type: DataType,
3904+
needle: ArrayRef,
3905+
in_array: ArrayRef,
3906+
) -> Result<BooleanArray> {
3907+
let schema = Schema::new(vec![Field::new("a", needle_type, false)]);
3908+
let col_a = col("a", &schema)?;
3909+
// NOTE(review): the trailing `false` is presumably the `negated` flag —
// confirm against the signature of `try_new_from_array`.
let expr = Arc::new(InListExpr::try_new_from_array(
3910+
col_a, in_array, false,
3911+
)?) as Arc<dyn PhysicalExpr>;
3912+
let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?;
3913+
// Materialize the evaluation result as an array with one entry per batch row.
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3914+
Ok(as_boolean_array(&result).clone())
3915+
}
3916+
3917+
/// Runs `eval_in_list_from_array` over several (in_array, needle) type
/// pairings and asserts the same `[true, false, true]` result for each:
///
/// * every listed primitive type, with a plain needle and with an
///   `Int32`-keyed dictionary needle;
/// * `Utf8` with a plain needle and with a dictionary needle;
/// * dictionary-encoded `Utf8` on both sides;
/// * a `Struct` of (`Utf8`, `Int64`) on both sides;
/// * a `Struct` whose first field is dictionary-encoded `Utf8`, on both sides.
#[test]
3918+
fn test_in_list_from_array_type_combinations() -> Result<()> {
3919+
use arrow::compute::cast;
3920+
3921+
// All cases: needle[0] and needle[2] match, needle[1] does not.
3922+
let expected =
3923+
BooleanArray::from(vec![Some(true), Some(false), Some(true)]);
3924+
3925+
// Base arrays cast to each target type
3926+
let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef;
3927+
let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef;
3928+
3929+
// Test all specializations in instantiate_static_filter
3930+
let primitive_types = vec![
3931+
DataType::Int8,
3932+
DataType::Int16,
3933+
DataType::Int32,
3934+
DataType::Int64,
3935+
DataType::UInt8,
3936+
DataType::UInt16,
3937+
DataType::UInt32,
3938+
DataType::UInt64,
3939+
DataType::Float32,
3940+
DataType::Float64,
3941+
];
3942+
3943+
for dt in &primitive_types {
3944+
// The base values (1..=4) are representable in every type listed above.
let in_array = cast(&base_in, dt)?;
3945+
let needle = cast(&base_needle, dt)?;
3946+
3947+
// T in_array, T needle
3948+
assert_eq!(expected, eval_in_list_from_array(
3949+
dt.clone(), Arc::clone(&needle), Arc::clone(&in_array),
3950+
)?, "same-type failed for {dt:?}");
3951+
3952+
// T in_array, Dict(Int32, T) needle
3953+
let dict_dt = DataType::Dictionary(
3954+
Box::new(DataType::Int32),
3955+
Box::new(dt.clone()),
3956+
);
3957+
// Last use of `needle` and `in_array` in this iteration, so both are moved.
assert_eq!(expected, eval_in_list_from_array(
3958+
dict_dt, wrap_in_dict(needle), in_array,
3959+
)?, "dict-needle failed for {dt:?}");
3960+
}
3961+
3962+
// Utf8 (falls through to ArrayStaticFilter)
3963+
let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
3964+
let utf8_needle =
3965+
Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;
3966+
let dict_utf8 = DataType::Dictionary(
3967+
Box::new(DataType::Int32),
3968+
Box::new(DataType::Utf8),
3969+
);
3970+
3971+
// Utf8 in_array, Utf8 needle
3972+
assert_eq!(expected, eval_in_list_from_array(
3973+
DataType::Utf8,
3974+
Arc::clone(&utf8_needle),
3975+
Arc::clone(&utf8_in),
3976+
)?);
3977+
3978+
// Utf8 in_array, Dict(Utf8) needle
3979+
assert_eq!(expected, eval_in_list_from_array(
3980+
dict_utf8.clone(),
3981+
wrap_in_dict(Arc::clone(&utf8_needle)),
3982+
Arc::clone(&utf8_in),
3983+
)?);
3984+
3985+
// Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
3986+
assert_eq!(expected, eval_in_list_from_array(
3987+
dict_utf8,
3988+
wrap_in_dict(Arc::clone(&utf8_needle)),
3989+
wrap_in_dict(Arc::clone(&utf8_in)),
3990+
)?);
3991+
3992+
// Struct in_array, Struct needle: multi-column join
3993+
let struct_fields = Fields::from(vec![
3994+
Field::new("c0", DataType::Utf8, true),
3995+
Field::new("c1", DataType::Int64, true),
3996+
]);
3997+
let struct_type = DataType::Struct(struct_fields.clone());
3998+
// Pairs each declared field with the matching child array, in order (c0, c1).
let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
3999+
let pairs: Vec<(FieldRef, ArrayRef)> =
4000+
struct_fields.iter().cloned().zip([c0, c1]).collect();
4001+
Arc::new(StructArray::from(pairs))
4002+
};
4003+
assert_eq!(expected, eval_in_list_from_array(
4004+
struct_type,
4005+
make_struct(
4006+
Arc::clone(&utf8_needle),
4007+
Arc::new(Int64Array::from(vec![1, 4, 2])),
4008+
),
4009+
make_struct(
4010+
Arc::clone(&utf8_in),
4011+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4012+
),
4013+
)?);
4014+
4015+
// Struct with Dict fields: multi-column Dict join
4016+
let dict_struct_fields = Fields::from(vec![
4017+
Field::new("c0", DataType::Dictionary(
4018+
Box::new(DataType::Int32),
4019+
Box::new(DataType::Utf8),
4020+
), true),
4021+
Field::new("c1", DataType::Int64, true),
4022+
]);
4023+
let dict_struct_type = DataType::Struct(dict_struct_fields.clone());
4024+
// Same pairing as `make_struct`, but against the dictionary-typed field list.
let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
4025+
let pairs: Vec<(FieldRef, ArrayRef)> =
4026+
dict_struct_fields.iter().cloned().zip([c0, c1]).collect();
4027+
Arc::new(StructArray::from(pairs))
4028+
};
4029+
assert_eq!(expected, eval_in_list_from_array(
4030+
dict_struct_type,
4031+
make_dict_struct(
4032+
wrap_in_dict(Arc::clone(&utf8_needle)),
4033+
Arc::new(Int64Array::from(vec![1, 4, 2])),
4034+
),
4035+
make_dict_struct(
4036+
wrap_in_dict(Arc::clone(&utf8_in)),
4037+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4038+
),
4039+
)?);
4040+
4041+
Ok(())
4042+
}
4043+
4044+
#[test]
4045+
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
4046+
use arrow::compute::cast;
4047+
4048+
// Utf8 needle, Dict(Utf8) in_array
4049+
let err = eval_in_list_from_array(
4050+
DataType::Utf8,
4051+
Arc::new(StringArray::from(vec!["a", "d", "b"])),
4052+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
4053+
).unwrap_err().to_string();
4054+
assert!(err.contains("Can't compare arrays of different types"), "{err}");
4055+
4056+
// Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
4057+
// rejects the Utf8 dictionary values at construction time
4058+
let err = eval_in_list_from_array(
4059+
DataType::Dictionary(
4060+
Box::new(DataType::Int32),
4061+
Box::new(DataType::Utf8),
4062+
),
4063+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
4064+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4065+
).unwrap_err().to_string();
4066+
assert!(err.contains("Failed to downcast"), "{err}");
4067+
4068+
// Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
4069+
// value types, make_comparator rejects the comparison
4070+
let err = eval_in_list_from_array(
4071+
DataType::Dictionary(
4072+
Box::new(DataType::Int32),
4073+
Box::new(DataType::Int64),
4074+
),
4075+
wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
4076+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
4077+
).unwrap_err().to_string();
4078+
assert!(err.contains("Can't compare arrays of different types"), "{err}");
4079+
4080+
Ok(())
4081+
}
38814082
}

datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -918,13 +918,18 @@ CREATE EXTERNAL TABLE dict_filter_bug
918918
STORED AS PARQUET
919919
LOCATION 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet';
920920

921-
query error Can't compare arrays of different types
921+
query TR
922922
SELECT t.tag1, t.value
923923
FROM dict_filter_bug t
924924
JOIN (VALUES ('A'), ('B')) AS v(c1)
925925
ON t.tag1 = v.c1
926926
ORDER BY t.tag1, t.value
927927
LIMIT 4;
928+
----
929+
A 0
930+
A 26
931+
A 52
932+
A 78
928933

929934
# Cleanup
930935
statement ok

0 commit comments

Comments
 (0)