@@ -99,11 +99,16 @@ impl StaticFilter for ArrayStaticFilter {
9999 ) ) ;
100100 }
101101
102+ // Unwrap dictionary-encoded needles when the value type matches
103+ // in_array, evaluating against distinct values and mapping back
104+ // via keys.
102105 downcast_dictionary_array ! {
103106 v => {
104- let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
105- let result = take( & values_contains, v. keys( ) , None ) ?;
106- return Ok ( downcast_array( result. as_ref( ) ) )
107+ if v. values( ) . data_type( ) == self . in_array. data_type( ) {
108+ let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
109+ let result = take( & values_contains, v. keys( ) , None ) ?;
110+ return Ok ( downcast_array( result. as_ref( ) ) ) ;
111+ }
107112 }
108113 _ => { }
109114 }
@@ -3878,4 +3883,200 @@ mod tests {
38783883 ) ;
38793884 Ok ( ( ) )
38803885 }
3886+
3887+ // -----------------------------------------------------------------------
3888+ // Tests for try_new_from_array covering all (in_array, needle) type
3889+ // combinations that occur in HashJoin dynamic filter pushdown.
3890+ //
3891+ // try_new (used by SQL IN expressions) always produces a non-Dictionary
3892+ // in_array because evaluate_list() flattens Dictionary scalars to their
3893+ // value type. try_new_from_array passes the array directly, so it is
3894+ // the only path that can produce a Dictionary in_array.
3895+ // -----------------------------------------------------------------------
3896+
3897+ fn wrap_in_dict ( array : ArrayRef ) -> ArrayRef {
3898+ let keys = Int32Array :: from ( ( 0 ..array. len ( ) as i32 ) . collect :: < Vec < _ > > ( ) ) ;
3899+ Arc :: new ( DictionaryArray :: new ( keys, array) )
3900+ }
3901+
3902+ fn eval_in_list_from_array (
3903+ needle_type : DataType ,
3904+ needle : ArrayRef ,
3905+ in_array : ArrayRef ,
3906+ ) -> Result < BooleanArray > {
3907+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , needle_type, false ) ] ) ;
3908+ let col_a = col ( "a" , & schema) ?;
3909+ let expr = Arc :: new ( InListExpr :: try_new_from_array (
3910+ col_a, in_array, false ,
3911+ ) ?) as Arc < dyn PhysicalExpr > ;
3912+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ needle] ) ?;
3913+ let result = expr. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ?;
3914+ Ok ( as_boolean_array ( & result) . clone ( ) )
3915+ }
3916+
// Runs `eval_in_list_from_array` over a matrix of (in_array, needle) type
// pairings: primitives, Utf8, dictionary-encoded variants, and structs.
// Each pairing uses data arranged so the expected answer is the same.
#[test]
fn test_in_list_from_array_type_combinations() -> Result<()> {
    use arrow::compute::cast;

    // Builds `Dictionary(Int32, value_type)`.
    let int32_dict_of = |value_type: DataType| -> DataType {
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type))
    };
    // Pairs `fields` with the two child columns and builds a `StructArray`.
    let struct_of = |fields: &Fields, c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
        let pairs: Vec<(FieldRef, ArrayRef)> =
            fields.iter().cloned().zip([c0, c1]).collect();
        Arc::new(StructArray::from(pairs))
    };

    // In every case below, needle rows 0 and 2 are in the list and row 1
    // is not.
    let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);

    // Int64 seed data, cast to each primitive type under test.
    let base_in: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 2, 3]));
    let base_needle: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 4, 2]));

    // Test all specializations in instantiate_static_filter
    let primitive_types = [
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Float32,
        DataType::Float64,
    ];

    for dt in &primitive_types {
        let in_array = cast(&base_in, dt)?;
        let needle = cast(&base_needle, dt)?;

        // T in_array, T needle
        let same_type = eval_in_list_from_array(
            dt.clone(),
            Arc::clone(&needle),
            Arc::clone(&in_array),
        )?;
        assert_eq!(expected, same_type, "same-type failed for {dt:?}");

        // T in_array, Dict(Int32, T) needle
        let dict_needle = eval_in_list_from_array(
            int32_dict_of(dt.clone()),
            wrap_in_dict(needle),
            in_array,
        )?;
        assert_eq!(expected, dict_needle, "dict-needle failed for {dt:?}");
    }

    // Utf8 (falls through to ArrayStaticFilter)
    let utf8_in: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let utf8_needle: ArrayRef = Arc::new(StringArray::from(vec!["a", "d", "b"]));
    let dict_utf8 = int32_dict_of(DataType::Utf8);

    // Utf8 in_array, Utf8 needle
    let plain_utf8 = eval_in_list_from_array(
        DataType::Utf8,
        Arc::clone(&utf8_needle),
        Arc::clone(&utf8_in),
    )?;
    assert_eq!(expected, plain_utf8);

    // Utf8 in_array, Dict(Utf8) needle
    let dict_needle_only = eval_in_list_from_array(
        dict_utf8.clone(),
        wrap_in_dict(Arc::clone(&utf8_needle)),
        Arc::clone(&utf8_in),
    )?;
    assert_eq!(expected, dict_needle_only);

    // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
    let dict_both_sides = eval_in_list_from_array(
        dict_utf8,
        wrap_in_dict(Arc::clone(&utf8_needle)),
        wrap_in_dict(Arc::clone(&utf8_in)),
    )?;
    assert_eq!(expected, dict_both_sides);

    // Struct in_array, Struct needle: multi-column join
    let struct_fields = Fields::from(vec![
        Field::new("c0", DataType::Utf8, true),
        Field::new("c1", DataType::Int64, true),
    ]);
    let struct_type = DataType::Struct(struct_fields.clone());
    let struct_needle = struct_of(
        &struct_fields,
        Arc::clone(&utf8_needle),
        Arc::new(Int64Array::from(vec![1, 4, 2])),
    );
    let struct_in = struct_of(
        &struct_fields,
        Arc::clone(&utf8_in),
        Arc::new(Int64Array::from(vec![1, 2, 3])),
    );
    assert_eq!(
        expected,
        eval_in_list_from_array(struct_type, struct_needle, struct_in)?
    );

    // Struct with Dict fields: multi-column Dict join
    let dict_struct_fields = Fields::from(vec![
        Field::new("c0", int32_dict_of(DataType::Utf8), true),
        Field::new("c1", DataType::Int64, true),
    ]);
    let dict_struct_type = DataType::Struct(dict_struct_fields.clone());
    let dict_struct_needle = struct_of(
        &dict_struct_fields,
        wrap_in_dict(Arc::clone(&utf8_needle)),
        Arc::new(Int64Array::from(vec![1, 4, 2])),
    );
    let dict_struct_in = struct_of(
        &dict_struct_fields,
        wrap_in_dict(Arc::clone(&utf8_in)),
        Arc::new(Int64Array::from(vec![1, 2, 3])),
    );
    assert_eq!(
        expected,
        eval_in_list_from_array(dict_struct_type, dict_struct_needle, dict_struct_in)?
    );

    Ok(())
}
4043+
// Checks that incompatible (needle, in_array) type pairings surface an
// error instead of a result. `eval_in_list_from_array` propagates failures
// from both expression construction and evaluation, so these assertions pin
// the error message only, not the stage at which it is raised.
#[test]
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
    // Utf8 needle, Dict(Utf8) in_array
    let err = eval_in_list_from_array(
        DataType::Utf8,
        Arc::new(StringArray::from(vec!["a", "d", "b"])),
        wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
    )
    .unwrap_err()
    .to_string();
    assert!(err.contains("Can't compare arrays of different types"), "{err}");

    // Dict(Utf8) needle, Int64 in_array: the Int64-specialized filter
    // cannot downcast the Utf8 dictionary values
    let err = eval_in_list_from_array(
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
        Arc::new(Int64Array::from(vec![1, 2, 3])),
    )
    .unwrap_err()
    .to_string();
    assert!(err.contains("Failed to downcast"), "{err}");

    // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
    // value types, make_comparator rejects the comparison
    let err = eval_in_list_from_array(
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int64)),
        wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
        wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
    )
    .unwrap_err()
    .to_string();
    assert!(err.contains("Can't compare arrays of different types"), "{err}");

    Ok(())
}
38814082}
0 commit comments