Skip to content

Commit 9996215

Browse files
committed
add dictionary tests
1 parent d7f4ae6 commit 9996215

File tree

1 file changed

+265
-3
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+265
-3
lines changed

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 265 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,8 @@ mod tests {
819819
])?;
820820

821821
let col_a = col("a", &schema)?;
822-
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![array.clone()])?;
822+
let batch =
823+
RecordBatch::try_new(Arc::new(schema.clone()), vec![array.clone()])?;
823824

824825
// Helper to format SQL-like representation for error messages
825826
let _format_sql = |negated: bool, with_null: bool| -> String {
@@ -939,7 +940,11 @@ mod tests {
939940

940941
run_test_cases(vec![
941942
primitive_test_case("utf8", ScalarValue::Utf8, string_data.clone()),
942-
primitive_test_case("large_utf8", ScalarValue::LargeUtf8, string_data.clone()),
943+
primitive_test_case(
944+
"large_utf8",
945+
ScalarValue::LargeUtf8,
946+
string_data.clone(),
947+
),
943948
primitive_test_case("utf8_view", ScalarValue::Utf8View, string_data),
944949
])
945950
}
@@ -957,7 +962,11 @@ mod tests {
957962

958963
run_test_cases(vec![
959964
primitive_test_case("binary", ScalarValue::Binary, binary_data.clone()),
960-
primitive_test_case("large_binary", ScalarValue::LargeBinary, binary_data.clone()),
965+
primitive_test_case(
966+
"large_binary",
967+
ScalarValue::LargeBinary,
968+
binary_data.clone(),
969+
),
961970
primitive_test_case("binary_view", ScalarValue::BinaryView, binary_data),
962971
])
963972
}
@@ -2596,4 +2605,257 @@ mod tests {
25962605
assert_eq!(result, &BooleanArray::from(vec![true, true, false]));
25972606
Ok(())
25982607
}
2608+
2609+
#[test]
2610+
fn test_in_list_dictionary_types() -> Result<()> {
2611+
// Helper functions for creating dictionary literals
2612+
fn dict_lit_int64(key_type: DataType, value: i64) -> Arc<dyn PhysicalExpr> {
2613+
lit(ScalarValue::Dictionary(
2614+
Box::new(key_type),
2615+
Box::new(ScalarValue::Int64(Some(value))),
2616+
))
2617+
}
2618+
2619+
fn dict_lit_float64(key_type: DataType, value: f64) -> Arc<dyn PhysicalExpr> {
2620+
lit(ScalarValue::Dictionary(
2621+
Box::new(key_type),
2622+
Box::new(ScalarValue::Float64(Some(value))),
2623+
))
2624+
}
2625+
2626+
// Test case structures
2627+
struct DictNeedleTest {
2628+
list_values: Vec<Arc<dyn PhysicalExpr>>,
2629+
expected: Vec<Option<bool>>,
2630+
}
2631+
2632+
struct DictionaryInListTestCase {
2633+
_name: &'static str,
2634+
dict_type: DataType,
2635+
dict_keys: Vec<Option<i8>>,
2636+
dict_values: ArrayRef,
2637+
list_values_no_null: Vec<Arc<dyn PhysicalExpr>>,
2638+
list_values_with_null: Vec<Arc<dyn PhysicalExpr>>,
2639+
expected_1: Vec<Option<bool>>,
2640+
expected_2: Vec<Option<bool>>,
2641+
expected_3: Vec<Option<bool>>,
2642+
expected_4: Vec<Option<bool>>,
2643+
dict_needle_test: Option<DictNeedleTest>,
2644+
}
2645+
2646+
// Test harness function
2647+
fn run_dictionary_in_list_test(
2648+
test_case: DictionaryInListTestCase,
2649+
) -> Result<()> {
2650+
// Create schema with dictionary type
2651+
let schema =
2652+
Schema::new(vec![Field::new("a", test_case.dict_type.clone(), true)]);
2653+
let col_a = col("a", &schema)?;
2654+
2655+
// Create dictionary array from keys and values
2656+
let keys = Int8Array::from(test_case.dict_keys.clone());
2657+
let dict_array: ArrayRef =
2658+
Arc::new(DictionaryArray::try_new(keys, test_case.dict_values)?);
2659+
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![dict_array])?;
2660+
2661+
let exp1 = test_case.expected_1.clone();
2662+
let exp2 = test_case.expected_2.clone();
2663+
let exp3 = test_case.expected_3.clone();
2664+
let exp4 = test_case.expected_4;
2665+
2666+
// Test 1: a IN (values_no_null)
2667+
in_list!(
2668+
batch,
2669+
test_case.list_values_no_null.clone(),
2670+
&false,
2671+
exp1,
2672+
Arc::clone(&col_a),
2673+
&schema
2674+
);
2675+
2676+
// Test 2: a NOT IN (values_no_null)
2677+
in_list!(
2678+
batch,
2679+
test_case.list_values_no_null.clone(),
2680+
&true,
2681+
exp2,
2682+
Arc::clone(&col_a),
2683+
&schema
2684+
);
2685+
2686+
// Test 3: a IN (values_with_null)
2687+
in_list!(
2688+
batch,
2689+
test_case.list_values_with_null.clone(),
2690+
&false,
2691+
exp3,
2692+
Arc::clone(&col_a),
2693+
&schema
2694+
);
2695+
2696+
// Test 4: a NOT IN (values_with_null)
2697+
in_list!(
2698+
batch,
2699+
test_case.list_values_with_null,
2700+
&true,
2701+
exp4,
2702+
Arc::clone(&col_a),
2703+
&schema
2704+
);
2705+
2706+
// Optional: Dictionary needle test (if provided)
2707+
if let Some(needle_test) = test_case.dict_needle_test {
2708+
in_list_raw!(
2709+
batch,
2710+
needle_test.list_values,
2711+
&false,
2712+
needle_test.expected,
2713+
Arc::clone(&col_a),
2714+
&schema
2715+
);
2716+
}
2717+
2718+
Ok(())
2719+
}
2720+
2721+
// Test case 1: UTF8
2722+
// Dictionary: keys [0, 1, null] → values ["a", "d", -]
2723+
// Rows: ["a", "d", null]
2724+
let utf8_case = DictionaryInListTestCase {
2725+
_name: "dictionary_utf8",
2726+
dict_type: DataType::Dictionary(
2727+
Box::new(DataType::Int8),
2728+
Box::new(DataType::Utf8),
2729+
),
2730+
dict_keys: vec![Some(0), Some(1), None],
2731+
dict_values: Arc::new(StringArray::from(vec![Some("a"), Some("d")])),
2732+
list_values_no_null: vec![lit("a"), lit("b")],
2733+
list_values_with_null: vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))],
2734+
expected_1: vec![Some(true), Some(false), None],
2735+
expected_2: vec![Some(false), Some(true), None],
2736+
expected_3: vec![Some(true), None, None],
2737+
expected_4: vec![Some(false), None, None],
2738+
dict_needle_test: None,
2739+
};
2740+
2741+
// Test case 2: Int64 with dictionary needles
2742+
// Dictionary: keys [0, 1, null] → values [10, 20, -]
2743+
// Rows: [10, 20, null]
2744+
let int64_case = DictionaryInListTestCase {
2745+
_name: "dictionary_int64",
2746+
dict_type: DataType::Dictionary(
2747+
Box::new(DataType::Int8),
2748+
Box::new(DataType::Int64),
2749+
),
2750+
dict_keys: vec![Some(0), Some(1), None],
2751+
dict_values: Arc::new(Int64Array::from(vec![Some(10), Some(20)])),
2752+
list_values_no_null: vec![lit(10i64), lit(15i64)],
2753+
list_values_with_null: vec![
2754+
lit(10i64),
2755+
lit(15i64),
2756+
lit(ScalarValue::Int64(None)),
2757+
],
2758+
expected_1: vec![Some(true), Some(false), None],
2759+
expected_2: vec![Some(false), Some(true), None],
2760+
expected_3: vec![Some(true), None, None],
2761+
expected_4: vec![Some(false), None, None],
2762+
dict_needle_test: Some(DictNeedleTest {
2763+
list_values: vec![
2764+
dict_lit_int64(DataType::Int16, 10),
2765+
dict_lit_int64(DataType::Int16, 15),
2766+
],
2767+
expected: vec![Some(true), Some(false), None],
2768+
}),
2769+
};
2770+
2771+
// Test case 3: Float64 with NaN and dictionary needles
2772+
// Dictionary: keys [0, 1, null, 2] → values [1.5, 3.7, NaN, -]
2773+
// Rows: [1.5, 3.7, null, NaN]
2774+
// Note: NaN is a value (not null), so it goes in the values array
2775+
let float64_case = DictionaryInListTestCase {
2776+
_name: "dictionary_float64",
2777+
dict_type: DataType::Dictionary(
2778+
Box::new(DataType::Int8),
2779+
Box::new(DataType::Float64),
2780+
),
2781+
dict_keys: vec![Some(0), Some(1), None, Some(2)],
2782+
dict_values: Arc::new(Float64Array::from(vec![
2783+
Some(1.5), // index 0
2784+
Some(3.7), // index 1
2785+
Some(f64::NAN), // index 2
2786+
])),
2787+
list_values_no_null: vec![lit(1.5f64), lit(2.0f64)],
2788+
list_values_with_null: vec![
2789+
lit(1.5f64),
2790+
lit(2.0f64),
2791+
lit(ScalarValue::Float64(None)),
2792+
],
2793+
// Test 1: a IN (1.5, 2.0) → [true, false, null, false]
2794+
// NaN is false because NaN not in list and no NULL in list
2795+
expected_1: vec![Some(true), Some(false), None, Some(false)],
2796+
// Test 2: a NOT IN (1.5, 2.0) → [false, true, null, true]
2797+
// NaN is true because NaN not in list
2798+
expected_2: vec![Some(false), Some(true), None, Some(true)],
2799+
// Test 3: a IN (1.5, 2.0, NULL) → [true, null, null, null]
2800+
// 3.7 and NaN become null due to NULL in list (three-valued logic)
2801+
expected_3: vec![Some(true), None, None, None],
2802+
// Test 4: a NOT IN (1.5, 2.0, NULL) → [false, null, null, null]
2803+
// 3.7 and NaN become null due to NULL in list
2804+
expected_4: vec![Some(false), None, None, None],
2805+
dict_needle_test: Some(DictNeedleTest {
2806+
list_values: vec![
2807+
dict_lit_float64(DataType::UInt16, 1.5),
2808+
dict_lit_float64(DataType::UInt16, 2.0),
2809+
],
2810+
expected: vec![Some(true), Some(false), None, Some(false)],
2811+
}),
2812+
};
2813+
2814+
// Execute all test cases
2815+
run_dictionary_in_list_test(utf8_case).map_err(|e| {
2816+
datafusion_common::DataFusionError::Execution(format!(
2817+
"Dictionary test failed for UTF8: {}",
2818+
e
2819+
))
2820+
})?;
2821+
2822+
run_dictionary_in_list_test(int64_case).map_err(|e| {
2823+
datafusion_common::DataFusionError::Execution(format!(
2824+
"Dictionary test failed for Int64: {}",
2825+
e
2826+
))
2827+
})?;
2828+
2829+
run_dictionary_in_list_test(float64_case).map_err(|e| {
2830+
datafusion_common::DataFusionError::Execution(format!(
2831+
"Dictionary test failed for Float64: {}",
2832+
e
2833+
))
2834+
})?;
2835+
2836+
// Additional test for Float64 NaN in IN list
2837+
let dict_type =
2838+
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Float64));
2839+
let schema = Schema::new(vec![Field::new("a", dict_type.clone(), true)]);
2840+
let col_a = col("a", &schema)?;
2841+
2842+
let keys = Int8Array::from(vec![Some(0), Some(1), None, Some(2)]);
2843+
let values = Float64Array::from(vec![Some(1.5), Some(3.7), Some(f64::NAN)]);
2844+
let dict_array: ArrayRef =
2845+
Arc::new(DictionaryArray::try_new(keys, Arc::new(values))?);
2846+
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![dict_array])?;
2847+
2848+
// Test: a IN (1.5, 2.0, NaN)
2849+
let list_with_nan = vec![lit(1.5f64), lit(2.0f64), lit(f64::NAN)];
2850+
in_list!(
2851+
batch,
2852+
list_with_nan,
2853+
&false,
2854+
vec![Some(true), Some(false), None, Some(true)],
2855+
col_a,
2856+
&schema
2857+
);
2858+
2859+
Ok(())
2860+
}
25992861
}

0 commit comments

Comments
 (0)