@@ -819,7 +819,8 @@ mod tests {
819819 ] ) ?;
820820
821821 let col_a = col ( "a" , & schema) ?;
822- let batch = RecordBatch :: try_new ( Arc :: new ( schema. clone ( ) ) , vec ! [ array. clone( ) ] ) ?;
822+ let batch =
823+ RecordBatch :: try_new ( Arc :: new ( schema. clone ( ) ) , vec ! [ array. clone( ) ] ) ?;
823824
824825 // Helper to format SQL-like representation for error messages
825826 let _format_sql = |negated : bool , with_null : bool | -> String {
@@ -939,7 +940,11 @@ mod tests {
939940
940941 run_test_cases ( vec ! [
941942 primitive_test_case( "utf8" , ScalarValue :: Utf8 , string_data. clone( ) ) ,
942- primitive_test_case( "large_utf8" , ScalarValue :: LargeUtf8 , string_data. clone( ) ) ,
943+ primitive_test_case(
944+ "large_utf8" ,
945+ ScalarValue :: LargeUtf8 ,
946+ string_data. clone( ) ,
947+ ) ,
943948 primitive_test_case( "utf8_view" , ScalarValue :: Utf8View , string_data) ,
944949 ] )
945950 }
@@ -957,7 +962,11 @@ mod tests {
957962
958963 run_test_cases ( vec ! [
959964 primitive_test_case( "binary" , ScalarValue :: Binary , binary_data. clone( ) ) ,
960- primitive_test_case( "large_binary" , ScalarValue :: LargeBinary , binary_data. clone( ) ) ,
965+ primitive_test_case(
966+ "large_binary" ,
967+ ScalarValue :: LargeBinary ,
968+ binary_data. clone( ) ,
969+ ) ,
961970 primitive_test_case( "binary_view" , ScalarValue :: BinaryView , binary_data) ,
962971 ] )
963972 }
@@ -2596,4 +2605,257 @@ mod tests {
25962605 assert_eq ! ( result, & BooleanArray :: from( vec![ true , true , false ] ) ) ;
25972606 Ok ( ( ) )
25982607 }
2608+
#[test]
fn test_in_list_dictionary_types() -> Result<()> {
    // Exercises IN / NOT IN against a dictionary-encoded column "a" for
    // Utf8, Int64 and Float64 dictionary value types, with and without a
    // NULL literal in the list, plus lists made of dictionary-typed literals
    // and a list containing a NaN literal.

    // Builds a dictionary-typed literal wrapping an Int64 value.
    fn dict_lit_int64(key_type: DataType, value: i64) -> Arc<dyn PhysicalExpr> {
        lit(ScalarValue::Dictionary(
            Box::new(key_type),
            Box::new(ScalarValue::Int64(Some(value))),
        ))
    }

    // Builds a dictionary-typed literal wrapping a Float64 value.
    fn dict_lit_float64(key_type: DataType, value: f64) -> Arc<dyn PhysicalExpr> {
        lit(ScalarValue::Dictionary(
            Box::new(key_type),
            Box::new(ScalarValue::Float64(Some(value))),
        ))
    }

    // Builds the single-column schema (nullable column "a" of `dict_type`)
    // and a one-column batch holding the dictionary array assembled from
    // Int8 `dict_keys` and `dict_values`.
    fn dictionary_batch(
        dict_type: DataType,
        dict_keys: Vec<Option<i8>>,
        dict_values: ArrayRef,
    ) -> Result<(Schema, RecordBatch)> {
        let schema = Schema::new(vec![Field::new("a", dict_type, true)]);
        let keys = Int8Array::from(dict_keys);
        let dict_array: ArrayRef =
            Arc::new(DictionaryArray::try_new(keys, dict_values)?);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![dict_array])?;
        Ok((schema, batch))
    }

    // Re-wraps an error returned by a test case so the message says which
    // dictionary value type was being exercised.
    fn with_case_context(label: &str, outcome: Result<()>) -> Result<()> {
        outcome.map_err(|e| {
            datafusion_common::DataFusionError::Execution(format!(
                "Dictionary test failed for {}: {}",
                label, e
            ))
        })
    }

    // Extra check where the IN list itself consists of dictionary-typed
    // literals; it is evaluated through `in_list_raw!` rather than `in_list!`.
    struct DictNeedleTest {
        list_values: Vec<Arc<dyn PhysicalExpr>>,
        expected: Vec<Option<bool>>,
    }

    // One dictionary value type under test. `expected_1..4` line up with the
    // four checks run by the harness, in order:
    //   1: a IN (list_values_no_null)
    //   2: a NOT IN (list_values_no_null)
    //   3: a IN (list_values_with_null)
    //   4: a NOT IN (list_values_with_null)
    struct DictionaryInListTestCase {
        // Descriptive label only; the harness never reads it.
        _name: &'static str,
        dict_type: DataType,
        dict_keys: Vec<Option<i8>>,
        dict_values: ArrayRef,
        list_values_no_null: Vec<Arc<dyn PhysicalExpr>>,
        list_values_with_null: Vec<Arc<dyn PhysicalExpr>>,
        expected_1: Vec<Option<bool>>,
        expected_2: Vec<Option<bool>>,
        expected_3: Vec<Option<bool>>,
        expected_4: Vec<Option<bool>>,
        dict_needle_test: Option<DictNeedleTest>,
    }

    // Test harness: runs the four IN / NOT IN checks (and the optional
    // dictionary-needle check) for one case. The case is consumed, so every
    // field is moved out on its last use instead of being cloned.
    fn run_dictionary_in_list_test(test_case: DictionaryInListTestCase) -> Result<()> {
        let (schema, batch) = dictionary_batch(
            test_case.dict_type,
            test_case.dict_keys,
            test_case.dict_values,
        )?;
        let col_a = col("a", &schema)?;

        let exp1 = test_case.expected_1;
        let exp2 = test_case.expected_2;
        let exp3 = test_case.expected_3;
        let exp4 = test_case.expected_4;

        // Test 1: a IN (values_no_null)
        // The list is needed again by test 2, hence the clone here.
        in_list!(
            batch,
            test_case.list_values_no_null.clone(),
            &false,
            exp1,
            Arc::clone(&col_a),
            &schema
        );

        // Test 2: a NOT IN (values_no_null)
        in_list!(
            batch,
            test_case.list_values_no_null,
            &true,
            exp2,
            Arc::clone(&col_a),
            &schema
        );

        // Test 3: a IN (values_with_null)
        // The list is needed again by test 4, hence the clone here.
        in_list!(
            batch,
            test_case.list_values_with_null.clone(),
            &false,
            exp3,
            Arc::clone(&col_a),
            &schema
        );

        // Test 4: a NOT IN (values_with_null)
        in_list!(
            batch,
            test_case.list_values_with_null,
            &true,
            exp4,
            Arc::clone(&col_a),
            &schema
        );

        // Optional: dictionary needle test (if provided)
        if let Some(needle_test) = test_case.dict_needle_test {
            in_list_raw!(
                batch,
                needle_test.list_values,
                &false,
                needle_test.expected,
                Arc::clone(&col_a),
                &schema
            );
        }

        Ok(())
    }

    // Test case 1: UTF8
    // keys:   [0, 1, null]
    // values: ["a", "d"]
    // rows:   ["a", "d", null]
    let utf8_case = DictionaryInListTestCase {
        _name: "dictionary_utf8",
        dict_type: DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Utf8),
        ),
        dict_keys: vec![Some(0), Some(1), None],
        dict_values: Arc::new(StringArray::from(vec![Some("a"), Some("d")])),
        list_values_no_null: vec![lit("a"), lit("b")],
        list_values_with_null: vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))],
        expected_1: vec![Some(true), Some(false), None],
        expected_2: vec![Some(false), Some(true), None],
        expected_3: vec![Some(true), None, None],
        expected_4: vec![Some(false), None, None],
        dict_needle_test: None,
    };

    // Test case 2: Int64 with dictionary needles
    // keys:   [0, 1, null]
    // values: [10, 20]
    // rows:   [10, 20, null]
    let int64_case = DictionaryInListTestCase {
        _name: "dictionary_int64",
        dict_type: DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Int64),
        ),
        dict_keys: vec![Some(0), Some(1), None],
        dict_values: Arc::new(Int64Array::from(vec![Some(10), Some(20)])),
        list_values_no_null: vec![lit(10i64), lit(15i64)],
        list_values_with_null: vec![
            lit(10i64),
            lit(15i64),
            lit(ScalarValue::Int64(None)),
        ],
        expected_1: vec![Some(true), Some(false), None],
        expected_2: vec![Some(false), Some(true), None],
        expected_3: vec![Some(true), None, None],
        expected_4: vec![Some(false), None, None],
        dict_needle_test: Some(DictNeedleTest {
            // Needles use an Int16 key type while the column uses Int8 keys.
            list_values: vec![
                dict_lit_int64(DataType::Int16, 10),
                dict_lit_int64(DataType::Int16, 15),
            ],
            expected: vec![Some(true), Some(false), None],
        }),
    };

    // Test case 3: Float64 with NaN and dictionary needles
    // keys:   [0, 1, null, 2]
    // values: [1.5, 3.7, NaN]
    // rows:   [1.5, 3.7, null, NaN]
    // Note: NaN is a value (not null), so it goes in the values array
    let float64_case = DictionaryInListTestCase {
        _name: "dictionary_float64",
        dict_type: DataType::Dictionary(
            Box::new(DataType::Int8),
            Box::new(DataType::Float64),
        ),
        dict_keys: vec![Some(0), Some(1), None, Some(2)],
        dict_values: Arc::new(Float64Array::from(vec![
            Some(1.5),      // index 0
            Some(3.7),      // index 1
            Some(f64::NAN), // index 2
        ])),
        list_values_no_null: vec![lit(1.5f64), lit(2.0f64)],
        list_values_with_null: vec![
            lit(1.5f64),
            lit(2.0f64),
            lit(ScalarValue::Float64(None)),
        ],
        // Test 1: a IN (1.5, 2.0) → [true, false, null, false]
        // NaN is false because NaN not in list and no NULL in list
        expected_1: vec![Some(true), Some(false), None, Some(false)],
        // Test 2: a NOT IN (1.5, 2.0) → [false, true, null, true]
        // NaN is true because NaN not in list
        expected_2: vec![Some(false), Some(true), None, Some(true)],
        // Test 3: a IN (1.5, 2.0, NULL) → [true, null, null, null]
        // 3.7 and NaN become null due to NULL in list (three-valued logic)
        expected_3: vec![Some(true), None, None, None],
        // Test 4: a NOT IN (1.5, 2.0, NULL) → [false, null, null, null]
        // 3.7 and NaN become null due to NULL in list
        expected_4: vec![Some(false), None, None, None],
        dict_needle_test: Some(DictNeedleTest {
            // Needles use a UInt16 key type while the column uses Int8 keys.
            list_values: vec![
                dict_lit_float64(DataType::UInt16, 1.5),
                dict_lit_float64(DataType::UInt16, 2.0),
            ],
            expected: vec![Some(true), Some(false), None, Some(false)],
        }),
    };

    // Execute all test cases
    with_case_context("UTF8", run_dictionary_in_list_test(utf8_case))?;
    with_case_context("Int64", run_dictionary_in_list_test(int64_case))?;
    with_case_context("Float64", run_dictionary_in_list_test(float64_case))?;

    // Additional test for Float64 NaN in IN list
    // Same Float64 dictionary layout as test case 3: rows [1.5, 3.7, null, NaN].
    let (schema, batch) = dictionary_batch(
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Float64)),
        vec![Some(0), Some(1), None, Some(2)],
        Arc::new(Float64Array::from(vec![
            Some(1.5),
            Some(3.7),
            Some(f64::NAN),
        ])),
    )?;
    let col_a = col("a", &schema)?;

    // Test: a IN (1.5, 2.0, NaN)
    // The NaN row is expected to match the NaN literal (last entry is true).
    let list_with_nan = vec![lit(1.5f64), lit(2.0f64), lit(f64::NAN)];
    in_list!(
        batch,
        list_with_nan,
        &false,
        vec![Some(true), Some(false), None, Some(true)],
        col_a,
        &schema
    );

    Ok(())
}
25992861}
0 commit comments